Newer
Older

Matthew Hodgson
committed
# -*- coding: utf-8 -*-

Richard van der Hoff
committed
# Copyright 2018-2019 New Vector Ltd

Matthew Hodgson
committed
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from inspect import getcallargs
from mock import Mock, patch
from twisted.internet import defer, reactor
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.federation.transport import server as federation_server
from synapse.http.server import HttpServer
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.

Richard van der Hoff
committed
#
# When running under postgres, we first create a base database with the name
# POSTGRES_BASE_DB and update it to the current schema. Then, for each test case, we
# create another unique database, using the base database as a template.
USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
# When set, the cleanup hook that drops each per-test database is not
# registered, so the databases are left behind for inspection.
LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
# Postgres connection details; each one is None when its variable is unset.
POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER")
POSTGRES_HOST = os.environ.get("SYNAPSE_POSTGRES_HOST")
POSTGRES_PASSWORD = os.environ.get("SYNAPSE_POSTGRES_PASSWORD")
# Name of the template database; the pid suffix keeps it unique per process.
POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)

Richard van der Hoff
committed
# the dbname we will connect to in order to create the base database.
POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"

Richard van der Hoff
committed
def setupdb():
    """Create the postgres template database used by the test suite.

    Does nothing unless USE_POSTGRES_FOR_TESTS is set. Otherwise it
    (re)creates POSTGRES_BASE_DB, runs prepare_database() against it, and
    registers an atexit hook which drops it again.
    """
    # If we're using PostgreSQL, set up the db once
    if not USE_POSTGRES_FOR_TESTS:
        return

    # create a PostgresEngine
    db_engine = create_engine({"name": "psycopg2", "args": {}})

    def _execute_on_initial_db(*statements):
        # Run each statement, in autocommit mode, on a short-lived connection
        # to POSTGRES_DBNAME_FOR_INITIAL_CREATE. The cursor and connection
        # are closed even if a statement fails.
        db_conn = db_engine.module.connect(
            user=POSTGRES_USER,
            host=POSTGRES_HOST,
            password=POSTGRES_PASSWORD,
            dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
        )
        try:
            db_conn.autocommit = True
            cur = db_conn.cursor()
            try:
                for statement in statements:
                    cur.execute(statement)
            finally:
                cur.close()
        finally:
            db_conn.close()

    # connect to postgres to create the base database.
    _execute_on_initial_db(
        "DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,),
        "CREATE DATABASE %s;" % (POSTGRES_BASE_DB,),
    )

    # Set up in the db
    db_conn = db_engine.module.connect(
        database=POSTGRES_BASE_DB,
        user=POSTGRES_USER,
        host=POSTGRES_HOST,
        password=POSTGRES_PASSWORD,
    )
    try:
        prepare_database(db_conn, db_engine, None)
    finally:
        # Don't leave a session attached to the base database: it is used as
        # a template for the per-test databases.
        # NOTE(review): assumes prepare_database() commits its own work
        # before returning - confirm.
        db_conn.close()

    def _cleanup():
        # Drop the base database again when the test process exits.
        _execute_on_initial_db(
            "DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)
        )

    atexit.register(_cleanup)

Amber Brown
committed
def default_config(name, parse=False):
    """
    Create a reasonable test config.

    Args:
        name: value used for the "server_name" setting.
        parse (bool): if True, load the settings into a HomeServerConfig via
            parse_config_dict() and return that object instead of the dict.

    Returns:
        The raw settings dict, or the populated HomeServerConfig when
        `parse` is true.
    """
    config_dict = {
        "server_name": name,
        "media_store_path": "media",
        "uploads_path": "uploads",
        # the test signing key is just an arbitrary ed25519 key to keep the config
        # parser happy
        "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg",
        "event_cache_size": 1,
        "enable_registration": True,
        "enable_registration_captcha": False,
        "macaroon_secret_key": "not even a little secret",
        "trusted_third_party_id_servers": [],
        "room_invite_state_types": [],
        "password_providers": [],
        "worker_replication_url": "",
        "worker_app": None,
        "block_non_admin_invites": False,
        "federation_domain_whitelist": None,
        "filter_timeline_limit": 5000,
        "user_directory_search_all_users": False,
        "user_consent_server_notice_content": None,
        "block_events_without_consent_error": None,
        "user_consent_at_registration": False,
        "user_consent_policy_name": "Privacy Policy",
        "media_storage_providers": [],
        "autocreate_auto_join_rooms": True,
        "auto_join_rooms": [],
        "limit_usage_by_mau": False,
        "hs_disabled": False,
        "hs_disabled_message": "",
        "max_mau_value": 50,
        "mau_trial_days": 0,
        "mau_stats_only": False,
        "mau_limits_reserved_threepids": [],
        "admin_contact": None,
        "rc_message": {"per_second": 10000, "burst_count": 10000},
        "rc_registration": {"per_second": 10000, "burst_count": 10000},
        "rc_login": {
            "address": {"per_second": 10000, "burst_count": 10000},
            "account": {"per_second": 10000, "burst_count": 10000},
            "failed_attempts": {"per_second": 10000, "burst_count": 10000},
        },
        "saml2_enabled": False,
        "public_baseurl": None,
        "default_identity_server": None,
        "key_refresh_interval": 24 * 60 * 60 * 1000,
        "old_signing_keys": {},
        "tls_fingerprints": [],
        "use_frozen_dicts": False,
        # We need a sane default_room_version, otherwise attempts to create
        # rooms will fail.
        "default_room_version": DEFAULT_ROOM_VERSION,
        # disable user directory updates, because they get done in the
        # background, which upsets the test runner.
        "update_user_directory": False,
    }

    if not parse:
        return config_dict

    config = HomeServerConfig()
    config.parse_config_dict(config_dict, "", "")
    return config
class TestHomeServer(HomeServer):
    """HomeServer subclass for the tests which sets DATASTORE_CLASS to DataStore."""

    DATASTORE_CLASS = DataStore
cleanup_func,
name="test",
datastore=None,
config=None,
reactor=None,
homeserverToUse=TestHomeServer,
**kargs
"""
Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor.
If no datastore is supplied, one is created and given to the homeserver.
Args:
cleanup_func : The function used to register a cleanup routine for
after the test.
Calling this method directly is deprecated: you should instead derive from
HomeserverTestCase.
if reactor is None:
from twisted.internet import reactor
if config is None:

Amber Brown
committed
config = default_config(name, parse=True)
if "clock" not in kargs:
kargs["clock"] = MockClock()
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
"name": "psycopg2",
"args": {
"database": test_db,
"host": POSTGRES_HOST,
"password": POSTGRES_PASSWORD,
"user": POSTGRES_USER,
"cp_min": 1,
"cp_max": 5,
},
"name": "sqlite3",
database = DatabaseConnectionConfig("master", database_config)
config.database.databases = [database]
db_engine = create_engine(database.config)
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
if datastore is None and isinstance(db_engine, PostgresEngine):
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
cur.execute(
"CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
)
cur.close()
db_conn.close()
if datastore is None:
version_string="Synapse/tests",
**kargs
)
hs.setup()
if homeserverToUse.__name__ == "TestHomeServer":
hs.setup_master()
if isinstance(db_engine, PostgresEngine):
database = hs.get_datastores().databases[0]
# We need to do cleanup on PostgreSQL
def cleanup():
# Drop the test database
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
db_conn.autocommit = True
cur = db_conn.cursor()
# Try a few times to drop the DB. Some things may hold on to the
# database for a few more seconds due to flakiness, preventing
# us from dropping it when the test is over. If we can't drop
# it, warn and move on.
for x in range(5):
try:
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
db_conn.commit()
dropped = True
except psycopg2.OperationalError as e:
warnings.warn(
"Couldn't drop old db: " + str(e), category=UserWarning
)
time.sleep(0.5)
if not dropped:
warnings.warn("Failed to drop old DB.", category=UserWarning)
if not LEAVE_DB:
# Register the cleanup hook
cleanup_func(cleanup)
version_string="Synapse/tests",
# bcrypt is far too slow to be doing in unit tests
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
async def hash(p):
return hashlib.md5(p.encode("utf8")).hexdigest()
hs.get_auth_handler().hash = hash
async def validate_hash(p, h):
return hashlib.md5(p.encode("utf8")).hexdigest() == h
hs.get_auth_handler().validate_hash = validate_hash
fed = kargs.get("resource_for_federation", None)
if fed:
register_federation_servlets(hs, fed)
def register_federation_servlets(hs, resource):
    """Register the federation transport servlets of `hs` on `resource`.

    Args:
        hs: the homeserver; supplies the clock and the `rc_federation`
            config used to build the rate limiter.
        resource: passed through as the `resource` argument of
            `federation_server.register_servlets`.
    """
    authenticator = federation_server.Authenticator(hs)
    ratelimiter = FederationRateLimiter(
        hs.get_clock(), config=hs.config.rc_federation
    )
    federation_server.register_servlets(
        hs,
        resource=resource,
        authenticator=authenticator,
        ratelimiter=ratelimiter,
    )
def get_mock_call_args(pattern_func, mock_func):
    """Return the arguments the mock function was called with, interpreted
    by the pattern function's argument list.

    Args:
        pattern_func: callable whose signature is used to name the arguments.
        mock_func: object exposing `call_args` as an `(args, kwargs)` pair.

    Returns:
        dict: parameter name -> value, as computed by `inspect.getcallargs`.

    Raises:
        TypeError: if `mock_func.call_args` is None (no call recorded), or
            the recorded arguments do not fit `pattern_func`'s signature.
    """
    recorded = mock_func.call_args
    if recorded is None:
        # Unpacking None would raise an opaque "cannot unpack" TypeError;
        # keep the exception type but say what actually went wrong.
        raise TypeError(
            "%r has not been called, so it has no call args" % (mock_func,)
        )
    invoked_args, invoked_kargs = recorded
    return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
def mock_getRawHeaders(headers=None):
    """Build a stand-in for a request's `getRawHeaders`, backed by a mapping.

    Args:
        headers: mapping of header name to value; an empty dict is used
            when omitted or None.

    Returns:
        A function `getRawHeaders(name, default=None)` which returns
        `headers.get(name, default)`.
    """
    header_map = {} if headers is None else headers

    def getRawHeaders(name, default=None):
        return header_map.get(name, default)

    return getRawHeaders

Paul "LeoNerd" Evans
committed
# This is a mock /resource/ not an entire server
class MockHttpResource(HttpServer):
def __init__(self, prefix=""):
self.callbacks = [] # 3-tuple of method/pattern/function
self.prefix = prefix
def trigger_get(self, path):
return self.trigger(b"GET", path, None)
self, http_method, path, content, mock_request, federation_auth_origin=None
""" Fire an HTTP event.
Args:
http_method : The HTTP method
path : The HTTP path
content : The HTTP body
mock_request : Mocked request to pass to the event so it can get
content.
federation_auth_origin (bytes|None): domain to authenticate as, for federation
Returns:
A tuple of (code, response)
Raises:
KeyError If no event is found which will handle the path.
"""
path = self.prefix + path
# annoyingly we return a twisted http request which has chained calls
# to get at the http content, hence mock it here.
mock_content = Mock()
mock_content.configure_mock(**config)
mock_request.content = mock_content
mock_request.method = http_method.encode("ascii")
mock_request.uri = path.encode("ascii")
if federation_auth_origin is not None:
headers[b"Authorization"] = [
b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
# return the right path if the event requires it
mock_request.path = path
# add in query params to the right place
try:
mock_request.args = urlparse.parse_qs(path.split("?")[1])
mock_request.path = path.split("?")[0]
if isinstance(path, bytes):
for (method, pattern, func) in self.callbacks:
if http_method != method:
continue
matcher = pattern.match(path)
if matcher:
try:
(code, response) = yield defer.ensureDeferred(
func(mock_request, *args)
)
return code, response
return (e.code, cs_error(e.msg, code=e.errcode))
raise KeyError("No event can handle %s" % path)
def register_paths(self, method, path_patterns, callback, servlet_name):
    """Record `callback` as the handler for `method` on each given pattern.

    Method of MockHttpResource. One `(method, pattern, callback)` tuple is
    appended to `self.callbacks` per entry of `path_patterns`, in order.
    `servlet_name` is accepted but not used.
    """
    self.callbacks.extend(
        (method, path_pattern, callback) for path_pattern in path_patterns
    )
class MockKey(object):
alg = "mock_alg"
version = "mock_version"
@property
def verify_key(self):
return self
def sign(self, message):
def verify(self, message, sig):
assert sig == b"\x9a\x87$"
class MockClock(object):
now = 1000

Paul "LeoNerd" Evans
committed
# list of lists of [absolute_time, callback, expired] in no particular
# order
def time(self):
    """Return the mock clock's current time, i.e. `self.now`.

    Method of MockClock.
    """
    return self.now
def time_msec(self):
    """Return `self.time()` multiplied by 1000.

    Method of MockClock.
    """
    return self.time() * 1000
ctx = current_context()
set_current_context(ctx)

Paul "LeoNerd" Evans
committed
t = [self.now + delay, wrapped_callback, False]
self.timers.append(t)

Paul "LeoNerd" Evans
committed
def looping_call(self, function, interval):
    """Register `function` as a periodic call on the mock clock.

    Method of MockClock. Appends `[function, interval / 1000.0, self.now]`
    to `self.loopers`; the entries are the callable, the period rescaled
    by 1/1000 (presumably milliseconds to seconds - confirm with callers),
    and the time of registration.
    """
    self.loopers.append([function, interval / 1000.0, self.now])
def cancel_call_later(self, timer, ignore_errs=False):
    """Cancel a pending timer on the mock clock.

    Method of MockClock.

    Args:
        timer: a `[absolute_time, callback, expired]` list taken from
            `self.timers`.
        ignore_errs (bool): when true, cancelling a timer whose expired
            flag is already set is silently accepted.

    Raises:
        Exception: if the timer is already marked expired and
            `ignore_errs` is false.
    """
    already_expired = timer[2]
    if already_expired and not ignore_errs:
        raise Exception("Cannot cancel an expired timer")

    # Mark the timer as expired so that a second cancellation is detected,
    # then drop it from the pending list.
    timer[2] = True
    self.timers = [t for t in self.timers if t != timer]
# For unit testing
def advance_time(self, secs):
self.now += secs
timers = self.timers
self.timers = []

Paul "LeoNerd" Evans
committed
for t in timers:
time, callback, expired = t
if expired:
raise Exception("Timer already expired")

Paul "LeoNerd" Evans
committed
t[2] = True

Paul "LeoNerd" Evans
committed
self.timers.append(t)
for looped in self.loopers:
func, interval, last = looped
if last + interval < self.now:
func()
looped[2] = self.now
def time_bound_deferred(self, d, *args, **kwargs):
    """Return `d` unchanged; the extra arguments are ignored.

    Method of MockClock.
    """
    # We don't bother timing things out for now.
    return d

Paul "LeoNerd" Evans
committed
def _format_call(args, kwargs):
    """Render call arguments as they would appear between the parentheses.

    Args:
        args: iterable of positional arguments.
        kwargs (dict): keyword arguments.

    Returns:
        str: the reprs of the positional arguments followed by `name=repr`
        for each keyword argument, joined with ", ".
    """
    # repr() rather than "%r" % (a): the latter is not a 1-tuple, so a tuple
    # argument would be unpacked as the format arguments.
    rendered = [repr(a) for a in args]
    rendered.extend("%s=%r" % (k, v) for k, v in kwargs.items())
    return ", ".join(rendered)
class DeferredMockCallable(object):
"""A callable instance that stores a set of pending call expectations and
return values for them. It allows a unit test to assert that the given set
of function calls are eventually made, by awaiting on them to be called.
"""
def __init__(self):
self.expectations = []
self.calls = []

Paul "LeoNerd" Evans
committed
def __call__(self, *args, **kwargs):
self.calls.append((args, kwargs))

Paul "LeoNerd" Evans
committed
if not self.expectations:
raise ValueError(
"%r has no pending calls to handle call(%s)"
% (self, _format_call(args, kwargs))

Paul "LeoNerd" Evans
committed
)
for (call, result, d) in self.expectations:
if args == call[1] and kwargs == call[2]:
d.callback(None)
return result
failure = AssertionError(
"Was not expecting call(%s)" % (_format_call(args, kwargs))
)

Paul "LeoNerd" Evans
committed

Erik Johnston
committed
for _, _, d in self.expectations:
try:
d.errback(failure)

Erik Johnston
committed
pass
raise failure

Paul "LeoNerd" Evans
committed
def expect_call_and_return(self, call, result):
    """Queue an expected call together with the value to return for it.

    Method of DeferredMockCallable. Appends `(call, result, deferred)` to
    `self.expectations`, where `deferred` is a fresh `defer.Deferred`
    created for this expectation.
    """
    pending = defer.Deferred()
    self.expectations.append((call, result, pending))
@defer.inlineCallbacks

Erik Johnston
committed
def await_calls(self, timeout=1000):
deferred = defer.DeferredList(

Erik Johnston
committed
)
timer = reactor.callLater(

Erik Johnston
committed
deferred.errback,
AssertionError(
"%d pending calls left: %s"
% (
len([e for e in self.expectations if not e[2].called]),
[e for e in self.expectations if not e[2].called],
)
),

Erik Johnston
committed
)
yield deferred
timer.cancel()
self.calls = []
def assert_had_no_calls(self):
if self.calls:
calls = self.calls
self.calls = []
"Expected not to received any calls, got:\n"
+ "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
@defer.inlineCallbacks
def create_room(hs, room_id, creator_id):
"""Creates and persist a creation event for the given room
Args:
hs
room_id (str)
creator_id (str)
"""
persistence_store = hs.get_storage().persistence
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
yield store.store_room(
room_id=room_id,
room_creator_user_id=creator_id,
is_public=False,
room_version=RoomVersions.V1,
)
builder = event_builder_factory.for_room_version(
{
"type": EventTypes.Create,
"state_key": "",
"sender": creator_id,
"room_id": room_id,
"content": {},
event, context = yield event_creation_handler.create_new_client_event(builder)
yield persistence_store.persist_event(event, context)