Skip to content
Snippets Groups Projects
server.py 20.4 KiB
Newer Older
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Amber Brown's avatar
Amber Brown committed
import json
import logging
from collections import deque
from io import SEEK_END, BytesIO
from typing import (
    AnyStr,
    Callable,
    Dict,
    Iterable,
    MutableMapping,
    Optional,
    Tuple,
    Type,
    Union,
)

import attr
from typing_extensions import Deque
from zope.interface import implementer

from twisted.internet import address, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
    IAddress,
    IHostnameResolver,
    IProtocol,
    IPullProducer,
    IPushProducer,
    IReactorPluggableNameResolver,
    IReactorTime,
    IResolverSimple,
)
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http_headers import Headers
from twisted.web.resource import IResource
from twisted.web.server import Request, Site

from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from synapse.util import Clock

from tests.utils import setup_test_homeserver as _sth

# Module-level logger. FakeTransport uses it to trace connection teardown and
# each flush of buffered data to the peer protocol.
logger = logging.getLogger(__name__)

class TimedOutException(Exception):
    """Raised when a fake web request does not finish before its deadline."""


@attr.s
class FakeChannel:
    """
    A fake Twisted Web Channel (the part that interfaces with the
    wire).

    The response is recorded into `result`:
    - `writeHeaders` stores version/code/reason/headers.
    - `write` appends to the body.
    - `requestDone` marks the response as complete.
    """

    site = attr.ib(type=Union[Site, "FakeSite"])
    _reactor = attr.ib()
    result = attr.ib(type=dict, default=attr.Factory(dict))
    _ip = attr.ib(type=str, default="127.0.0.1")
    _producer: Optional[Union[IPullProducer, IPushProducer]] = None

    @property
    def json_body(self):
        """The body of the result, parsed as JSON.

        Raises an exception if the request has not yet completed.
        """
        return json.loads(self.text_body)

    @property
    def text_body(self) -> str:
        """The body of the result, utf-8-decoded.

        Raises an exception if the request has not yet completed.
        """
        # `is_finished` is a method and must be *called*: the bound method
        # object itself is always truthy, so the guard would never fire.
        if not self.is_finished():
            raise Exception("Request not yet completed")
        return self.result["body"].decode("utf8")

    def is_finished(self) -> bool:
        """check if the response has been completely received"""
        return self.result.get("done", False)

    @property
    def code(self) -> int:
        """The HTTP status code of the response.

        Raises an exception if nothing has been recorded yet.
        """
        if not self.result:
            raise Exception("No result yet.")
        return int(self.result["code"])

    @property
    def headers(self) -> Headers:
        """The response headers, rebuilt as a twisted `Headers` object.

        Raises an exception if nothing has been recorded yet.
        """
        if not self.result:
            raise Exception("No result yet.")
        h = Headers()
        for i in self.result["headers"]:
            h.addRawHeader(*i)
        return h

    def writeHeaders(self, version, code, reason, headers):
        self.result["version"] = version
        self.result["code"] = code
        self.result["reason"] = reason
        self.result["headers"] = headers

    def write(self, content):
        assert isinstance(content, bytes), "Should be bytes! " + repr(content)

        if "body" not in self.result:
            self.result["body"] = b""

        self.result["body"] += content

    def registerProducer(self, producer, streaming):
        self._producer = producer
        self.producerStreaming = streaming

        def _produce():
            # Keep pulling from a non-streaming producer every 0.1s until it
            # is unregistered.
            if self._producer:
                self._producer.resumeProducing()
                self._reactor.callLater(0.1, _produce)

        if not streaming:
            self._reactor.callLater(0.0, _produce)

    def unregisterProducer(self):
        if self._producer is None:
            return

        self._producer = None

    def requestDone(self, _self):
        self.result["done"] = True

    def getPeer(self):
        # We give an address so that getClientIP returns a non null entry,
        # causing us to record the MAU
        return address.IPv4Address("TCP", self._ip, 3423)

    def getHost(self):
        # this is called by Request.__init__ to configure Request.host.
        return address.IPv4Address("TCP", "127.0.0.1", 8888)

    def isSecure(self):
        return False

    def await_result(self, timeout_ms: int = 1000) -> None:
        """
        Wait until the request is finished.

        Advances the fake reactor in 0.1s steps until `requestDone` has been
        called. On each step, any registered producer is prodded.

        Args:
            timeout_ms: how much (fake) reactor time to allow, in milliseconds.

        Raises:
            TimedOutException: if the request has not finished in time.
        """
        end_time = self._reactor.seconds() + timeout_ms / 1000.0
        while not self.is_finished():
            # If there's a producer, tell it to resume producing so we get content
            if self._producer:
                self._producer.resumeProducing()

            if self._reactor.seconds() > end_time:
                raise TimedOutException("Timed out waiting for request to finish.")

            self._reactor.advance(0.1)

    def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
        """Process the contents of any Set-Cookie headers in the response

        Any cookies found are added to the given dict
        """
        headers = self.headers.getRawHeaders("Set-Cookie")
        if not headers:
            return

        for h in headers:
            # Only the leading "name=value" pair is kept; attributes such as
            # Path or Expires after the first ';' are ignored.
            parts = h.split(";")
            k, v = parts[0].split("=", maxsplit=1)
            cookies[k] = v


class FakeSite:
    """
    A fake Twisted Web Site, with mocks of the extra things that
    Synapse adds.
    """

    server_version_string = b"1"
    site_tag = "test"
    access_logger = logging.getLogger("synapse.access.http.fake")

    def __init__(self, resource: IResource, reactor: IReactorTime):
        """

        Args:
            resource: the resource to be used for rendering all requests
            reactor: the reactor driving this site. It is exposed as
                `self.reactor` rather than being silently discarded.
        """
        self._resource = resource
        self.reactor = reactor

    def getResourceFor(self, request):
        """Return the single configured resource, whatever the request."""
        return self._resource


def make_request(
    site: Union[Site, FakeSite],
    method: Union[bytes, str],
    path: Union[bytes, str],
    content: Union[bytes, str, JsonDict] = b"",
    access_token: Optional[str] = None,
    request: Type[Request] = SynapseRequest,
    shorthand: bool = True,
    federation_auth_origin: Optional[bytes] = None,
    content_is_form: bool = False,
    await_result: bool = True,
    custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
    client_ip: str = "127.0.0.1",
) -> FakeChannel:
    """
    Make a web request using the given method, path and content, and render it

    Returns the fake Channel object which records the response to the request.

    Args:
        site: The twisted Site to use to render the request
        method: The HTTP request method ("verb").
        path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such).
        content: The body of the request. A dict is JSON-encoded and a str is
            UTF-8-encoded; bytes are sent as-is.
        access_token: The access token to add as authorization for the request.
        request: The request class to create.
        shorthand: Whether to try and be helpful and prefix the given URL
            with the usual REST API path, if it doesn't contain it.
        federation_auth_origin: if set to not-None, we will add a fake
            Authorization header pretending to be the given server name.
        content_is_form: Whether the content is URL encoded form data. Adds the
            'Content-Type': 'application/x-www-form-urlencoded' header.
            Otherwise a non-empty body is labelled 'application/json'.
        await_result: whether to wait for the request to complete rendering. If true,
             will pump the reactor until the renderer tells the channel the request
             is finished.
        custom_headers: (name, value) pairs to add as request headers
        client_ip: The IP to use as the requesting IP. Useful for testing
            ratelimiting.

    Returns:
        The FakeChannel which records the response to the request.
    """
    if not isinstance(method, bytes):
        method = method.encode("ascii")

    if not isinstance(path, bytes):
        path = path.encode("ascii")

    # Decorate it to be the full path, if we're using shorthand
    if (
        shorthand
        and not path.startswith(b"/_matrix")
        and not path.startswith(b"/_synapse")
    ):
        if path.startswith(b"/"):
            path = path[1:]
        path = b"/_matrix/client/r0/" + path

    if not path.startswith(b"/"):
        path = b"/" + path

    # Normalise the body to bytes. The branches must be exclusive: encoding
    # the already-encoded JSON bytes again would raise AttributeError.
    if isinstance(content, dict):
        content = json.dumps(content).encode("utf8")
    elif isinstance(content, str):
        content = content.encode("utf8")

    # NOTE(review): `reactor` is neither a parameter of this function nor
    # defined at module level in this file as shown. It looks like a `reactor`
    # argument was lost from the signature; confirm against callers.
    channel = FakeChannel(site, reactor, ip=client_ip)
    req = request(channel, site)
    req.content = BytesIO(content)
    # Twisted expects to be at the end of the content when parsing the request.
    req.content.seek(SEEK_END)

    if access_token:
        req.requestHeaders.addRawHeader(
            b"Authorization", b"Bearer " + access_token.encode("ascii")
        )

    if federation_auth_origin is not None:
        req.requestHeaders.addRawHeader(
            b"Authorization",
            b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,),
        )

    # Only describe the body's type when there actually is a body.
    if content:
        if content_is_form:
            req.requestHeaders.addRawHeader(
                b"Content-Type", b"application/x-www-form-urlencoded"
            )
        else:
            # Assume the body is JSON
            req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")

    if custom_headers:
        for k, v in custom_headers:
            req.requestHeaders.addRawHeader(k, v)

    req.parseCookies()
    req.requestReceived(method, path, b"1.1")

    if await_result:
        channel.await_result()

    return channel


@implementer(IReactorPluggableNameResolver)
class ThreadedMemoryReactorClock(MemoryReactorClock):
    """
    A MemoryReactorClock that supports callFromThread.

    It also provides:
    - a fake name resolver answering from the `lookups` dict;
    - real UDP listening via `listenUDP`;
    - per-(host, port) callbacks that fire when `connectTCP` is called.
    """

    def __init__(self):
        # Initialise the base reactor first, so that nothing it sets up can
        # overwrite the attributes assigned below.
        super().__init__()

        self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
        # Ports opened via listenUDP; must exist before listenUDP appends to it.
        self._udp = []
        self.lookups: Dict[str, str] = {}
        self._thread_callbacks: Deque[Callable[[], None]] = deque()

        # Local alias so the nested resolver class can see the dict.
        lookups = self.lookups

        @implementer(IResolverSimple)
        class FakeResolver:
            def getHostByName(self, name, timeout=None):
                if name not in lookups:
                    return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
                return succeed(lookups[name])

        self.nameResolver = SimpleResolverComplexifier(FakeResolver())

    def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
        raise NotImplementedError()

    def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
        p = udp.Port(port, protocol, interface, maxPacketSize, self)
        p.startListening()
        self._udp.append(p)
        return p

    def callFromThread(self, callback, *args, **kwargs):
        """
        Make the callback fire in the next reactor iteration.
        """
        cb = lambda: callback(*args, **kwargs)
        # it's not safe to call callLater() here, so we append the callback to a
        # separate queue.
        self._thread_callbacks.append(cb)

    def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
        """Add a callback that will be invoked when we receive a connection
        attempt to the given IP/port using `connectTCP`.

        Note that the callback gets run before we return the connection to the
        client, which means callbacks cannot block while waiting for writes.
        """
        self._tcp_callbacks[(host, port)] = callback

    def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
        """Fake L{IReactorTCP.connectTCP}."""

        # Pass the caller's bindAddress through so it is recorded alongside the
        # connection attempt, rather than being silently replaced by None.
        conn = super().connectTCP(
            host, port, factory, timeout=timeout, bindAddress=bindAddress
        )

        callback = self._tcp_callbacks.get((host, port))
        if callback:
            callback()

        return conn

    def advance(self, amount):
        # first advance our reactor's time, and run any "callLater" callbacks that
        # makes ready
        super().advance(amount)

        # now run any "callFromThread" callbacks
        while True:
            try:
                callback = self._thread_callbacks.popleft()
            except IndexError:
                break
            callback()

            # check for more "callLater" callbacks added by the thread callback
            # This isn't required in a regular reactor, but it ends up meaning that
            # our database queries can complete in a single call to `advance` [1] which
            # simplifies tests.
            #
            # [1]: we replace the threadpool backing the db connection pool with a
            # mock ThreadPool which doesn't really use threads; but we still use
            # reactor.callFromThread to feed results back from the db functions to the
            # main thread.
            super().advance(0)


class ThreadPool:
    """
    Threadless thread pool.

    Work handed to `callInThreadWithCallback` is not run on a worker thread.
    Instead it is scheduled on the wrapped reactor with `callLater(0, ...)`.
    """

    def __init__(self, reactor):
        self._reactor = reactor

    def start(self):
        # Nothing to spin up: there are no worker threads.
        pass

    def stop(self):
        # Nothing to tear down either.
        pass

    def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
        def _report(outcome):
            # onResult(success_flag, result_or_failure)
            succeeded = not isinstance(outcome, Failure)
            onResult(succeeded, outcome)

        deferred = Deferred()
        deferred.addCallback(lambda _ignored: function(*args, **kwargs))
        deferred.addBoth(_report)
        # Fire on the next reactor tick, so the work still runs "later".
        self._reactor.callLater(0, deferred.callback, True)
        return deferred


def setup_test_homeserver(cleanup_func, *args, **kwargs):
    """
    Set up a synchronous test server, driven by the reactor used by
    the homeserver.

    Each database connection pool is patched so that `runWithConnection` and
    `runInteraction` are dispatched through a threadless `ThreadPool` on the
    homeserver clock's reactor, instead of on real threads.

    Args:
        cleanup_func: passed through to the underlying test-homeserver factory.
        *args, **kwargs: passed through to the underlying test-homeserver factory.

    Returns:
        The homeserver built by the underlying factory, with its pools patched.
    """
    server = _sth(cleanup_func, *args, **kwargs)
    # Make the thread pool synchronous.
    clock = server.get_clock()

    for database in server.get_datastores().databases:
        _make_pool_synchronous(database._db_pool, clock._reactor)

    # We've just changed the Databases to run DB transactions on the same
    # thread, so we need to disable the dedicated thread behaviour.
    server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False

    return server


def _make_pool_synchronous(pool, reactor) -> None:
    """Route `pool`'s DB work through a threadless ThreadPool on `reactor`.

    This is a separate function so that the replacement methods close over
    *this* pool. If they were defined directly in a loop, every pool's methods
    would see whichever pool the loop variable held last.
    """

    def runWithConnection(func, *args, **kwargs):
        return threads.deferToThreadPool(
            pool._reactor,
            pool.threadpool,
            pool._runWithConnection,
            func,
            *args,
            **kwargs,
        )

    def runInteraction(interaction, *args, **kwargs):
        return threads.deferToThreadPool(
            pool._reactor,
            pool.threadpool,
            pool._runInteraction,
            interaction,
            *args,
            **kwargs,
        )

    pool.runWithConnection = runWithConnection
    pool.runInteraction = runInteraction
    # Read at call time by the closures above, so they pick up this pool.
    pool.threadpool = ThreadPool(reactor)


def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
    """Create a fake reactor and a synapse `Clock` driven by it.

    Returns:
        A `(reactor, clock)` tuple: the ThreadedMemoryReactorClock and the
        Clock wrapping it. Previously nothing was returned, despite the
        annotation.
    """
    clock = ThreadedMemoryReactorClock()
    hs_clock = Clock(clock)
    return clock, hs_clock


@attr.s(cmp=False)
class FakeTransport:
    """
    A twisted.internet.interfaces.ITransport implementation which sends all its data
    straight into an IProtocol object: it exists to connect two IProtocols together.

    To use it, instantiate it with the receiving IProtocol, and then pass it to the
    sending IProtocol's makeConnection method:

        server = HTTPChannel()
        client.makeConnection(FakeTransport(server, self.reactor))

    If you want bidirectional communication, you'll need two instances.
    """

    other = attr.ib()
    """The Protocol object which will receive any data written to this transport.

    :type: twisted.internet.interfaces.IProtocol
    """

    _reactor = attr.ib()
    """Test reactor

    :type: twisted.internet.interfaces.IReactorTime
    """

    _protocol = attr.ib(default=None)
    """The Protocol which is producing data for this transport. Optional, but if set
    will get called back for connectionLost() notifications etc.
    """

    _peer_address: Optional[IAddress] = attr.ib(default=None)
    """The value to be returned by getPeer"""

    # Connection state. These are plain class-level defaults, not attrs fields;
    # loseConnection, abortConnection, write and flush read and update them.
    disconnecting = False
    disconnected = False
    connected = True
    buffer = attr.ib(default=b"")
    producer = attr.ib(default=None)
    autoflush = attr.ib(default=True)

    def getPeer(self):
        return self._peer_address

    def loseConnection(self, reason=None):
        if not self.disconnecting:
            logger.info("FakeTransport: loseConnection(%s)", reason)
            self.disconnecting = True
            if self._protocol:
                self._protocol.connectionLost(reason)

            # if we still have data to write, delay until that is done
            if self.buffer:
                logger.info(
                    "FakeTransport: Delaying disconnect until buffer is flushed"
                )
            else:
                # Nothing left to flush: complete the disconnect now, as
                # flush() does once a pending buffer drains.
                self.connected = False
                self.disconnected = True

    def abortConnection(self):
        logger.info("FakeTransport: abortConnection()")

        if not self.disconnecting:
            self.disconnecting = True
            if self._protocol:
                self._protocol.connectionLost(None)

        self.disconnected = True

    def pauseProducing(self):
        if not self.producer:
            return

        self.producer.pauseProducing()

    def resumeProducing(self):
        if not self.producer:
            return
        self.producer.resumeProducing()

    def unregisterProducer(self):
        if not self.producer:
            return

        self.producer = None

    def registerProducer(self, producer, streaming):
        self.producer = producer
        self.producerStreaming = streaming

        def _produce():
            if not self.producer:
                # we've been unregistered
                return
            # some implementations of IProducer (for example, FileSender)
            # don't return a deferred.
            d = maybeDeferred(self.producer.resumeProducing)
            d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))

        if not streaming:
            self._reactor.callLater(0.0, _produce)

    def write(self, byt):
        if self.disconnecting:
            raise Exception("Writing to disconnecting FakeTransport")

        # Queue the data; flush() delivers it to the other protocol.
        self.buffer = self.buffer + byt

        # always actually do the write asynchronously. Some protocols (notably the
        # TLSMemoryBIOProtocol) get very confused if a read comes back while they are
        # still doing a write. Doing a callLater here breaks the cycle.
        if self.autoflush:
            self._reactor.callLater(0.0, self.flush)

    def writeSequence(self, seq):
        for x in seq:
            self.write(x)

    def flush(self, maxbytes=None):
        if not self.buffer:
            # nothing to do. Don't write empty buffers: it upsets the
            # TLSMemoryBIOProtocol
            return

        if self.disconnected:
            return

        if maxbytes is not None:
            to_write = self.buffer[:maxbytes]
        else:
            to_write = self.buffer

        logger.info("%s->%s: %s", self._protocol, self.other, to_write)

        try:
            self.other.dataReceived(to_write)
        except Exception as e:
            logger.exception("Exception writing to protocol: %s", e)
        self.buffer = self.buffer[len(to_write) :]
        if self.buffer and self.autoflush:
            self._reactor.callLater(0.0, self.flush)
        if not self.buffer and self.disconnecting:
            logger.info("FakeTransport: Buffer now empty, completing disconnect")
            self.connected = False
            self.disconnected = True


def connect_client(
    reactor: ThreadedMemoryReactorClock, client_id: int
) -> Tuple[IProtocol, AccumulatingProtocol]:
    """
    Connect a client to a fake TCP transport.

    Takes a pending `connectTCP` attempt recorded on the reactor and builds a
    protocol from its factory. That protocol is then joined to an
    `AccumulatingProtocol` "server" by a pair of FakeTransports, one for each
    direction.

    Args:
        reactor: the reactor on which the client called `connectTCP`.
        client_id: index into `reactor.tcpClients` of the connection attempt to
            complete. That entry is removed from the list.

    Returns:
        A `(client, server)` tuple:
        - client: the protocol built by the connecting factory;
        - server: the AccumulatingProtocol at the other end.
    """
    factory = reactor.tcpClients.pop(client_id)[2]
    client = factory.buildProtocol(None)
    server = AccumulatingProtocol()
    server.makeConnection(FakeTransport(client, reactor))
    client.makeConnection(FakeTransport(server, reactor))

    return client, server