diff --git a/changelog.d/11095.misc b/changelog.d/11095.misc
new file mode 100644
index 0000000000000000000000000000000000000000..786e90b59526070e8ff3d7e027919fa373eecdd3
--- /dev/null
+++ b/changelog.d/11095.misc
@@ -0,0 +1 @@
+Add type hints to most `HomeServer` parameters.
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index bb4d53d778917b7298f8b6360a9e399ea0fe8d6d..2ca2e051e43aac1a9d52a01bb7cd34b2b7bbe4db 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -294,7 +294,7 @@ def listen_ssl(
     return r
 
 
-def refresh_certificate(hs):
+def refresh_certificate(hs: "HomeServer"):
     """
     Refresh the TLS certificates that Synapse is using by re-reading them from
     disk and updating the TLS context factories to use them.
@@ -419,11 +419,11 @@ async def start(hs: "HomeServer"):
         atexit.register(gc.freeze)
 
 
-def setup_sentry(hs):
+def setup_sentry(hs: "HomeServer"):
     """Enable sentry integration, if enabled in configuration
 
     Args:
-        hs (synapse.server.HomeServer)
+        hs
     """
 
     if not hs.config.metrics.sentry_enabled:
@@ -449,7 +449,7 @@ def setup_sentry(hs):
         scope.set_tag("worker_name", name)
 
 
-def setup_sdnotify(hs):
+def setup_sdnotify(hs: "HomeServer"):
     """Adds process state hooks to tell systemd what we are up to."""
 
     # Tell systemd our state, if we're using it. This will silently fail if
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index b156b93bf3d15c42f4cfc7f20885308236bc9dc3..2fc848596d619ed812adfbd7447a224ad7643f59 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -68,11 +68,11 @@ class AdminCmdServer(HomeServer):
     DATASTORE_CLASS = AdminCmdSlavedStore
 
 
-async def export_data_command(hs, args):
+async def export_data_command(hs: HomeServer, args):
     """Export data for a user.
 
     Args:
-        hs (HomeServer)
+        hs
         args (argparse.Namespace)
     """
 
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 7489f31d9addc5d7a85d3d8360902d7214d2badf..51eadf122dbaa987efc74f549414c1af701a0303 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -131,10 +131,10 @@ class KeyUploadServlet(RestServlet):
 
     PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: HomeServer):
         """
         Args:
-            hs (synapse.server.HomeServer): server
+            hs: server
         """
         super().__init__()
         self.auth = hs.get_auth()
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 422f03cc0464928359f77b3a1f16202797fa1b2b..93e22992661cd6ac9f3fecdc7122df85a3f3a208 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -412,7 +412,7 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
         e = e.__cause__
 
 
-def run(hs):
+def run(hs: HomeServer):
     PROFILE_SYNAPSE = False
     if PROFILE_SYNAPSE:
 
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index fcd01e833c8489ae1f5e630f5c83eae30ea7a4fc..126450e17a46a99f8ae94e8f6ca02713f8865d48 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -15,11 +15,15 @@ import logging
 import math
 import resource
 import sys
+from typing import TYPE_CHECKING
 
 from prometheus_client import Gauge
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger("synapse.app.homeserver")
 
 # Contains the list of processes we will be monitoring
@@ -41,7 +45,7 @@ registered_reserved_users_mau_gauge = Gauge(
 
 
 @wrap_as_background_process("phone_stats_home")
-async def phone_stats_home(hs, stats, stats_process=_stats_process):
+async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process):
     logger.info("Gathering stats for reporting")
     now = int(hs.get_clock().time())
     uptime = int(now - hs.start_time)
@@ -142,7 +146,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
         logger.warning("Error reporting stats: %s", e)
 
 
-def start_phone_stats_home(hs):
+def start_phone_stats_home(hs: "HomeServer"):
     """
     Start the background tasks which report phone home stats.
     """
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 935f24263c981bb8fb7ce2eec178c5684ccf700a..d08f6bbd7f2e2ecc186107ea91f66114e52483f8 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -27,6 +27,7 @@ from synapse.util.caches.response_cache import ResponseCache
 
 if TYPE_CHECKING:
     from synapse.appservice import ApplicationService
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -84,7 +85,7 @@ class ApplicationServiceApi(SimpleHttpClient):
     pushing.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.clock = hs.get_clock()
 
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 0a08231e5a2db8c9bc58d513d89f8bd07f870f37..5252e61a99a0fe5085ff7daf017fc25044ebb814 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -18,6 +18,7 @@ import os
 import sys
 import threading
 from string import Template
+from typing import TYPE_CHECKING
 
 import yaml
 from zope.interface import implementer
@@ -38,6 +39,9 @@ from synapse.util.versionstring import get_version_string
 
 from ._base import Config, ConfigError
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 DEFAULT_LOG_CONFIG = Template(
     """\
 # Log configuration for Synapse.
@@ -306,7 +310,10 @@ def _reload_logging_config(log_config_path):
 
 
 def setup_logging(
-    hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner
+    hs: "HomeServer",
+    config,
+    use_worker_options=False,
+    logBeginner: LogBeginner = globalLogBeginner,
 ) -> None:
     """
     Set up the logging subsystem.
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 0cd424e12aa153dda77d26624f377db1caeece34..f56344a3b94f82ccc2941f93cfb65e4fd032cee0 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 import logging
 from collections import namedtuple
+from typing import TYPE_CHECKING
 
 from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
 from synapse.api.errors import Codes, SynapseError
@@ -25,11 +26,15 @@ from synapse.events.utils import prune_event, validate_canonicaljson
 from synapse.http.servlet import assert_params_in_dict
 from synapse.types import JsonDict, get_domain_from_id
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
 class FederationBase:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
 
         self.server_name = hs.hostname
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index d8c0b86f230167d643c8812c8ea02b27124ca1d3..0d66034f44e9f5edfe6d8d5e33008051f24b2d22 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -467,7 +467,7 @@ class FederationServer(FederationBase):
 
     async def on_room_state_request(
         self, origin: str, room_id: str, event_id: Optional[str]
-    ) -> Tuple[int, Dict[str, Any]]:
+    ) -> Tuple[int, JsonDict]:
         origin_host, _ = parse_server_name(origin)
         await self.check_server_matches_acl(origin_host, room_id)
 
@@ -481,7 +481,7 @@ class FederationServer(FederationBase):
         # - but that's non-trivial to get right, and anyway somewhat defeats
         # the point of the linearizer.
         with (await self._server_linearizer.queue((origin, room_id))):
-            resp = dict(
+            resp: JsonDict = dict(
                 await self._state_resp_cache.wrap(
                     (room_id, event_id),
                     self._on_context_state_request_compute,
@@ -1061,11 +1061,12 @@ class FederationServer(FederationBase):
 
                 origin, event = next
 
-            lock = await self.store.try_acquire_lock(
+            new_lock = await self.store.try_acquire_lock(
                 _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
             )
-            if not lock:
+            if not new_lock:
                 return
+            lock = new_lock
 
     def __str__(self) -> str:
         return "<ReplicationLayer(%s)>" % self.server_name
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 4f592246860bf509e4b287b57a5c549969827c31..203d723d412059b9fc05d25b3465973a15ad49ac 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -21,6 +21,7 @@ import typing
 import urllib.parse
 from io import BytesIO, StringIO
 from typing import (
+    TYPE_CHECKING,
     Callable,
     Dict,
     Generic,
@@ -73,6 +74,9 @@ from synapse.util import json_decoder
 from synapse.util.async_helpers import timeout_deferred
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 outgoing_requests_counter = Counter(
@@ -319,7 +323,7 @@ class MatrixFederationHttpClient:
             requests.
     """
 
-    def __init__(self, hs, tls_client_options_factory):
+    def __init__(self, hs: "HomeServer", tls_client_options_factory):
         self.hs = hs
         self.signing_key = hs.signing_key
         self.server_name = hs.hostname
@@ -711,7 +715,7 @@ class MatrixFederationHttpClient:
         Returns:
             A list of headers to be added as "Authorization:" headers
         """
-        request = {
+        request: JsonDict = {
             "method": method.decode("ascii"),
             "uri": url_bytes.decode("ascii"),
             "origin": self.server_name,
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 897ba5e4531bcab7380392218d3a292062c968f6..1af0d9a31d1f39ce0aa5e6fdbf25a45ec7dff3c8 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -22,6 +22,7 @@ import urllib
 from http import HTTPStatus
 from inspect import isawaitable
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Callable,
@@ -61,6 +62,9 @@ from synapse.util import json_encoder
 from synapse.util.caches import intern_dict
 from synapse.util.iterutils import chunk_seq
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
@@ -343,6 +347,11 @@ class DirectServeJsonResource(_AsyncResource):
         return_json_error(f, request)
 
 
+_PathEntry = collections.namedtuple(
+    "_PathEntry", ["pattern", "callback", "servlet_classname"]
+)
+
+
 class JsonResource(DirectServeJsonResource):
     """This implements the HttpServer interface and provides JSON support for
     Resources.
@@ -359,14 +368,10 @@ class JsonResource(DirectServeJsonResource):
 
     isLeaf = True
 
-    _PathEntry = collections.namedtuple(
-        "_PathEntry", ["pattern", "callback", "servlet_classname"]
-    )
-
-    def __init__(self, hs, canonical_json=True, extract_context=False):
+    def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False):
         super().__init__(canonical_json, extract_context)
         self.clock = hs.get_clock()
-        self.path_regexs = {}
+        self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
         self.hs = hs
 
     def register_paths(self, method, path_patterns, callback, servlet_classname):
@@ -391,7 +396,7 @@ class JsonResource(DirectServeJsonResource):
         for path_pattern in path_patterns:
             logger.debug("Registering for %s %s", method, path_pattern.pattern)
             self.path_regexs.setdefault(method, []).append(
-                self._PathEntry(path_pattern, callback, servlet_classname)
+                _PathEntry(path_pattern, callback, servlet_classname)
             )
 
     def _get_handler_for_request(
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index ba8114ac9e13bcaf09c2ee5cbf37dfc76c21d6aa..1457d9d59b1f83abc74bc1b3aa109c66e8931681 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import TYPE_CHECKING
+
 from synapse.http.server import JsonResource
 from synapse.replication.http import (
     account_data,
@@ -26,16 +28,19 @@ from synapse.replication.http import (
     streams,
 )
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 REPLICATION_PREFIX = "/_synapse/replication"
 
 
 class ReplicationRestResource(JsonResource):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         # We enable extracting jaeger contexts here as these are internal APIs.
         super().__init__(hs, canonical_json=False, extract_context=True)
         self.register_servlets(hs)
 
-    def register_servlets(self, hs):
+    def register_servlets(self, hs: "HomeServer"):
         send_event.register_servlets(hs, self)
         federation.register_servlets(hs, self)
         presence.register_servlets(hs, self)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index e047ec74d85f64c24466a97a8a81697993ef9a36..585332b244a4742b3254a8cf446bb414f11a0b93 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -17,7 +17,7 @@ import logging
 import re
 import urllib
 from inspect import signature
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
 
 from prometheus_client import Counter, Gauge
 
@@ -156,7 +156,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         pass
 
     @classmethod
-    def make_client(cls, hs):
+    def make_client(cls, hs: "HomeServer"):
         """Create a client that makes requests.
 
         Returns a callable that accepts the same parameters as
@@ -208,7 +208,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                     url_args.append(txn_id)
 
                 if cls.METHOD == "POST":
-                    request_func = client.post_json_get_json
+                    request_func: Callable[
+                        ..., Awaitable[Any]
+                    ] = client.post_json_get_json
                 elif cls.METHOD == "PUT":
                     request_func = client.put_json
                 elif cls.METHOD == "GET":
diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
index 70e951af63767cfafbb0746afe12d74e91b6e551..5f0f225aa953889c12f48463706329123bb4db56 100644
--- a/synapse/replication/http/account_data.py
+++ b/synapse/replication/http/account_data.py
@@ -13,10 +13,14 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -37,7 +41,7 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("user_id", "account_data_type")
     CACHE = False
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.handler = hs.get_account_data_handler()
@@ -78,7 +82,7 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("user_id", "room_id", "account_data_type")
     CACHE = False
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.handler = hs.get_account_data_handler()
@@ -119,7 +123,7 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("user_id", "room_id", "tag")
     CACHE = False
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.handler = hs.get_account_data_handler()
@@ -162,7 +166,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
     )
     CACHE = False
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.handler = hs.get_account_data_handler()
@@ -183,7 +187,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
         return 200, {"max_stream_id": max_stream_id}
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
     ReplicationUserAccountDataRestServlet(hs).register(http_server)
     ReplicationRoomAccountDataRestServlet(hs).register(http_server)
     ReplicationAddTagRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index 5a5818ef61e26cfbc1bbdeeffaf8114e3e2b8504..42dffb39cbef88a122ab1acb27728d36fc6287f2 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -13,9 +13,13 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.replication.http._base import ReplicationEndpoint
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -51,7 +55,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("user_id",)
     CACHE = False
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.device_list_updater = hs.get_device_handler().device_list_updater
@@ -68,5 +72,5 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
         return 200, user_devices
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
     ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index a0b3145f4e3206285f9da8fae08899af0890a775..5ed535c90dea39b6d6aa45ea0c9710cf112bab2b 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import make_event_from_dict
@@ -21,6 +22,9 @@ from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -56,7 +60,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
     NAME = "fed_send_events"
     PATH_ARGS = ()
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.store = hs.get_datastore()
@@ -151,7 +155,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
     NAME = "fed_send_edu"
     PATH_ARGS = ("edu_type",)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.store = hs.get_datastore()
@@ -194,7 +198,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
     # This is a query, so let's not bother caching
     CACHE = False
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.store = hs.get_datastore()
@@ -238,7 +242,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
     NAME = "fed_cleanup_room"
     PATH_ARGS = ("room_id",)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.store = hs.get_datastore()
@@ -273,7 +277,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
     NAME = "store_room_on_outlier_membership"
     PATH_ARGS = ("room_id",)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.store = hs.get_datastore()
@@ -289,7 +293,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
         return 200, {}
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
     ReplicationFederationSendEventsRestServlet(hs).register(http_server)
     ReplicationFederationSendEduRestServlet(hs).register(http_server)
     ReplicationGetQueryRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 550bd5c95f8dbd969eebac149ab5baa80cf37c0e..0db419ea57fb393d5329255697511370078ea784 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -13,10 +13,14 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -30,7 +34,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
     NAME = "device_check_registered"
     PATH_ARGS = ("user_id",)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.registration_handler = hs.get_registration_handler()
 
@@ -82,5 +86,5 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         return 200, res
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
     RegisterDeviceReplicationServlet(hs).register(http_server)
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 34206c5060664eabb7a004e2961b8aeb33d95140..7371c240b2744a70eb74ac3ef7af4382782a8e5b 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -45,7 +45,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
     NAME = "remote_join"
     PATH_ARGS = ("room_id", "user_id")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.federation_handler = hs.get_federation_handler()
@@ -320,7 +320,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("room_id", "user_id", "change")
     CACHE = False  # No point caching as should return instantly.
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.registeration_handler = hs.get_registration_handler()
@@ -360,7 +360,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
         return 200, {}
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
     ReplicationRemoteJoinRestServlet(hs).register(http_server)
     ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
     ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
index bb00247953495ad50f705d0a54b28b15da06ae10..63143085d5213ac314991d58ea8e2667b394be51 100644
--- a/synapse/replication/http/presence.py
+++ b/synapse/replication/http/presence.py
@@ -117,6 +117,6 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
         )
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
     ReplicationBumpPresenceActiveTime(hs).register(http_server)
     ReplicationPresenceSetState(hs).register(http_server)
diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py
index 139427cb1f29fa50ceb8e7f7972f2434782d5766..6c8db3061ee2bcdf762a15fd4d951af112934992 100644
--- a/synapse/replication/http/push.py
+++ b/synapse/replication/http/push.py
@@ -67,5 +67,5 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
         return 200, {}
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
     ReplicationRemovePusherRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index d6dd7242eb204f0cd7d61a427415c2675b70cce4..7adfbb666f392c4b429ff13591de647e2d2ff7b8 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -13,10 +13,14 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -26,7 +30,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
     NAME = "register_user"
     PATH_ARGS = ("user_id",)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.store = hs.get_datastore()
         self.registration_handler = hs.get_registration_handler()
@@ -100,7 +104,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
     NAME = "post_register"
     PATH_ARGS = ("user_id",)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.store = hs.get_datastore()
         self.registration_handler = hs.get_registration_handler()
@@ -130,6 +134,6 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
         return 200, {}
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
     ReplicationRegisterServlet(hs).register(http_server)
     ReplicationPostRegisterActionsServlet(hs).register(http_server)
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index fae5ffa451d37d0e8964fa565ebe07996b339ea8..9f6851d0592eb7cdc4fd35416662e34583637a36 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import make_event_from_dict
@@ -22,6 +23,9 @@ from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import Requester, UserID
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -57,7 +61,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
     NAME = "send_event"
     PATH_ARGS = ("event_id",)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.event_creation_handler = hs.get_event_creation_handler()
@@ -135,5 +139,5 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
         )
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
     ReplicationSendEventRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index 9afa147d00c109c802a062660d1a5a87323bcad4..3223bc2432b293658042613fdfd075209f19fe01 100644
--- a/synapse/replication/http/streams.py
+++ b/synapse/replication/http/streams.py
@@ -13,11 +13,15 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import parse_integer
 from synapse.replication.http._base import ReplicationEndpoint
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -46,7 +50,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
     PATH_ARGS = ("stream_name",)
     METHOD = "GET"
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self._instance_name = hs.get_instance_name()
@@ -74,5 +78,5 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
         )
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
     ReplicationGetStreamUpdates(hs).register(http_server)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index e460dd85cd833bd92fe7a4277326a032c54e7d9f..7ecb446e7c785ca296509845e51da17471e34600 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -13,18 +13,21 @@
 # limitations under the License.
 
 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class BaseSlavedStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen: Optional[
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 436d39c3203fb4daed951c6ce83d681ac45e18fd..61cd7e5228007f6c800565d524b63437e3bb85a8 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -12,15 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import TYPE_CHECKING
+
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.util.caches.lrucache import LruCache
 
 from ._base import BaseSlavedStore
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 class SlavedClientIpStore(BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.client_ip_last_seen: LruCache[tuple, int] = LruCache(
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 26bdead5651e8fb494e6dc9de3d6c78abfdfb14f..0a582960896d409f6b68b28886e46981168a6f33 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import TYPE_CHECKING
+
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
@@ -20,9 +22,12 @@ from synapse.storage.databases.main.devices import DeviceWorkerStore
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.hs = hs
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index d4d3f8c44876baf02c6788dbe083be4866f1933c..63ed50caa5eb50dde506f30dac4277678918cfc5 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
@@ -30,6 +31,9 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -54,7 +58,7 @@ class SlavedEventStore(
     RelationsWorkerStore,
     BaseSlavedStore,
 ):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 37875bc9730fb11c90a29b096ef945e84d2b235d..90284c202d55b4ea64f0554ae77767c75f2c76a3 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -12,14 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import TYPE_CHECKING
+
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.filtering import FilteringStore
 
 from ._base import BaseSlavedStore
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 class SlavedFilteringStore(BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
     # Filters are immutable so this cache doesn't need to be expired
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index e9bdc3847006b0d5c6818ad14469d6f177fe7c41..497e16c69e6a9ca8baf405f67da6c021f76ad8d7 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import TYPE_CHECKING
+
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import GroupServerStream
@@ -19,9 +21,12 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.group_server import GroupServerWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.hs = hs
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
index b402f82810faa9edafd5a66174c32a259f9534a5..aaf91e5e025309ba6d41c662e801f588551d65db 100644
--- a/synapse/replication/tcp/external_cache.py
+++ b/synapse/replication/tcp/external_cache.py
@@ -21,6 +21,8 @@ from synapse.logging.context import make_deferred_yieldable
 from synapse.util import json_decoder, json_encoder
 
 if TYPE_CHECKING:
+    from txredisapi import RedisProtocol
+
     from synapse.server import HomeServer
 
 set_counter = Counter(
@@ -59,7 +61,12 @@ class ExternalCache:
     """
 
     def __init__(self, hs: "HomeServer"):
-        self._redis_connection = hs.get_outbound_redis_connection()
+        if hs.config.redis.redis_enabled:
+            self._redis_connection: Optional[
+                "RedisProtocol"
+            ] = hs.get_outbound_redis_connection()
+        else:
+            self._redis_connection = None
 
     def _get_redis_key(self, cache_name: str, key: str) -> str:
         return "cache_v1:%s:%s" % (cache_name, key)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 6aa9318027764635c8a40700b27faeeb387c6f30..06fd06fdf3a6616fea12333fcf83cc4cc2d655dd 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -294,7 +294,7 @@ class ReplicationCommandHandler:
             # This shouldn't be possible
             raise Exception("Unrecognised command %s in stream queue", cmd.NAME)
 
-    def start_replication(self, hs):
+    def start_replication(self, hs: "HomeServer"):
         """Helper method to start a replication connection to the remote server
         using TCP.
         """
@@ -321,6 +321,8 @@ class ReplicationCommandHandler:
                 hs.config.redis.redis_host,  # type: ignore[arg-type]
                 hs.config.redis.redis_port,
                 self._factory,
+                timeout=30,
+                bindAddress=None,
             )
         else:
             client_name = hs.get_instance_name()
@@ -331,6 +333,8 @@ class ReplicationCommandHandler:
                 host,  # type: ignore[arg-type]
                 port,
                 self._factory,
+                timeout=30,
+                bindAddress=None,
             )
 
     def get_streams(self) -> Dict[str, Stream]:
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 80f9b23bfd74eb8a619ffab66a41d5c0334211a1..55326877fd2c9a0dc374a2d76a037f79695bb1bf 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -16,6 +16,7 @@
 
 import logging
 import random
+from typing import TYPE_CHECKING
 
 from prometheus_client import Counter
 
@@ -27,6 +28,9 @@ from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
 from synapse.replication.tcp.streams import EventsStream
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 stream_updates_counter = Counter(
     "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
 )
@@ -37,7 +41,7 @@ logger = logging.getLogger(__name__)
 class ReplicationStreamProtocolFactory(Factory):
     """Factory for new replication connections."""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.command_handler = hs.get_tcp_replication()
         self.clock = hs.get_clock()
         self.server_name = hs.config.server.server_name
@@ -65,7 +69,7 @@ class ReplicationStreamer:
     data is available it will propagate to all connected clients.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 9b905aba9dbbbe943fd88d0beeec0d95fffba54c..c8b188ae4ea4a190d7fbeb65d5660bd95e7627b3 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -241,7 +241,7 @@ class BackfillStream(Stream):
     NAME = "backfill"
     ROW_TYPE = BackfillStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
@@ -363,7 +363,7 @@ class ReceiptsStream(Stream):
     NAME = "receipts"
     ROW_TYPE = ReceiptsStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
@@ -380,7 +380,7 @@ class PushRulesStream(Stream):
     NAME = "push_rules"
     ROW_TYPE = PushRulesStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
         super().__init__(
@@ -405,7 +405,7 @@ class PushersStream(Stream):
     NAME = "pushers"
     ROW_TYPE = PushersStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         store = hs.get_datastore()
 
         super().__init__(
@@ -438,7 +438,7 @@ class CachesStream(Stream):
     NAME = "caches"
     ROW_TYPE = CachesStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
@@ -459,7 +459,7 @@ class DeviceListsStream(Stream):
     NAME = "device_lists"
     ROW_TYPE = DeviceListsStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
@@ -476,7 +476,7 @@ class ToDeviceStream(Stream):
     NAME = "to_device"
     ROW_TYPE = ToDeviceStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
@@ -495,7 +495,7 @@ class TagAccountDataStream(Stream):
     NAME = "tag_account_data"
     ROW_TYPE = TagAccountDataStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
@@ -582,7 +582,7 @@ class GroupServerStream(Stream):
     NAME = "groups"
     ROW_TYPE = GroupsStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
@@ -599,7 +599,7 @@ class UserSignatureStream(Stream):
     NAME = "user_signature"
     ROW_TYPE = UserSignatureStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index a6fa03c90f0a84085f8a36c0268ece09f248aef5..80fbf32f17df68b2e5517ea59c56f969fac172b7 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -110,7 +110,7 @@ class DevicesRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         """
         Args:
-            hs (synapse.server.HomeServer): server
+            hs: server
         """
         self.hs = hs
         self.auth = hs.get_auth()
diff --git a/synapse/server.py b/synapse/server.py
index a64c846d1c490a1100c3d447bdad9712b7386b10..0fbf36ba991e693e649966bf8ea18b37c241bc48 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -800,9 +800,14 @@ class HomeServer(metaclass=abc.ABCMeta):
         return ExternalCache(self)
 
     @cache_in_self
-    def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
-        if not self.config.redis.redis_enabled:
-            return None
+    def get_outbound_redis_connection(self) -> "RedisProtocol":
+        """
+        The Redis connection used for replication.
+
+        Raises:
+            AssertionError: if Redis is not enabled in the homeserver config.
+        """
+        assert self.config.redis.redis_enabled
 
         # We only want to import redis module if we're using it, as we have
         # `txredisapi` as an optional dependency.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index f5a8f90a0f9834a2f1592e9f898ac7dd1b092983..fa4e89d35cd18b63ca63513bcd041db0166dcd33 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -19,6 +19,7 @@ from collections import defaultdict
 from sys import intern
 from time import monotonic as monotonic_time
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Collection,
@@ -52,6 +53,9 @@ from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 # python 3 does not have a maximum int value
 MAX_TXN_ID = 2 ** 63 - 1
 
@@ -392,7 +396,7 @@ class DatabasePool:
 
     def __init__(
         self,
-        hs,
+        hs: "HomeServer",
         database_config: DatabaseConnectionConfig,
         engine: BaseDatabaseEngine,
     ):
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 20b755056b7f35d44a5ec0dc6aeb3ceb98df7a2e..cfe887b7f73d30e6ff06ae5ca0eeed06aa2a1892 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -13,33 +13,49 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Generic, List, Optional, Type, TypeVar
 
+from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, make_conn
 from synapse.storage.databases.main.events import PersistEventsStore
 from synapse.storage.databases.state import StateGroupDataStore
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
-class Databases:
+DataStoreT = TypeVar("DataStoreT", bound=SQLBaseStore, covariant=True)
+
+
+class Databases(Generic[DataStoreT]):
     """The various databases.
 
     These are low level interfaces to physical databases.
 
     Attributes:
-        main (DataStore)
+        databases
+        main
+        state
+        persist_events
     """
 
-    def __init__(self, main_store_class, hs):
+    databases: List[DatabasePool]
+    main: DataStoreT
+    state: StateGroupDataStore
+    persist_events: Optional[PersistEventsStore]
+
+    def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
         # Note we pass in the main store class here as workers use a different main
         # store.
 
         self.databases = []
-        main = None
-        state = None
-        persist_events = None
+        main: Optional[DataStoreT] = None
+        state: Optional[StateGroupDataStore] = None
+        persist_events: Optional[PersistEventsStore] = None
 
         for database_config in hs.config.database.databases:
             db_name = database_config.name
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 5c21402deab96808019429470ff275c210c24ead..259cae5b3711734ebe0700ba8c8c360165e32457 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import DatabasePool
@@ -75,6 +75,9 @@ from .ui_auth import UIAuthStore
 from .user_directory import UserDirectoryStore
 from .user_erasure_store import UserErasureStore
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -126,7 +129,7 @@ class DataStore(
     LockStore,
     SessionStore,
 ):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 70ca3e09f7c615308c0c572bbdb64327c488c7bc..f8bec266ac416b7ec406d7410a9353d86de4a513 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 from synapse.api.constants import AccountDataTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -28,6 +28,9 @@ from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -36,7 +39,7 @@ class AccountDataWorkerStore(SQLBaseStore):
     `get_max_account_data_stream_id` which can be called in the initializer.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         self._instance_name = hs.get_instance_name()
 
         if isinstance(database.engine, PostgresEngine):
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index c57ae5ef15c6a9c3f8a27190365339b7c91761cb..36e8422fc63b8b22da39cb062a276ce5d6e8b869 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -15,7 +15,7 @@
 
 import itertools
 import logging
-from typing import Any, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple
 
 from synapse.api.constants import EventTypes
 from synapse.replication.tcp.streams import BackfillStream, CachesStream
@@ -29,6 +29,9 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.engines import PostgresEngine
 from synapse.util.iterutils import batch_iter
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -38,7 +41,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 
 
 class CacheInvalidationWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 3154906d45f66e63ea9caada5420b5b9595d56ab..81431681070301ae63e628d005d4516496a14f62 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -26,11 +26,14 @@ from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class DeviceInboxWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
@@ -553,7 +556,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 64645203865cf40fdc065c457d482c7cc611a7b1..a01bf2c5b7f180009932762afe2c69c3fa97d278 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -15,7 +15,17 @@
 # limitations under the License.
 import abc
 import logging
-from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+)
 
 from synapse.api.errors import Codes, StoreError
 from synapse.logging.opentracing import (
@@ -38,6 +48,9 @@ from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.stringutils import shortstr
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
@@ -48,7 +61,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
 class DeviceWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
@@ -915,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -1047,7 +1060,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ba9f71a2303309d6e3faff6ed616fb915467cb5d..ef5d1ef01e4875814d00061bcaac389579544679 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -14,7 +14,7 @@
 import itertools
 import logging
 from queue import Empty, PriorityQueue
-from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
 
 from prometheus_client import Counter, Gauge
 
@@ -34,6 +34,9 @@ from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 oldest_pdu_in_federation_staging = Gauge(
     "synapse_federation_server_oldest_inbound_pdu_in_staging",
     "The age in seconds since we received the oldest pdu in the federation staging area",
@@ -59,7 +62,7 @@ class _NoChainCoverIndex(Exception):
 
 
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
@@ -1511,7 +1514,7 @@ class EventFederationStore(EventFederationWorkerStore):
 
     EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 97b3e92d3f13d48abe638613372e1b3c59b2585a..d957e770dcd82ca1506a7fef70de5e3a40f815f6 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import attr
 
@@ -23,6 +23,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -64,7 +67,7 @@ def _deserialize_action(actions, is_highlight):
 
 
 class EventPushActionsWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         # These get correctly set by _find_stream_orderings_for_times_txn
@@ -892,7 +895,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 1afc59fafbf28dc92b8c3092dbac887b4182c684..fc491120632a155729939b2ad44100b11af584ce 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
 import attr
 
@@ -26,6 +26,9 @@ from synapse.storage.databases.main.events import PersistEventsStore
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -76,7 +79,7 @@ class _CalculateChainCover:
 
 
 class EventsBackgroundUpdatesStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 2fa945d171f04edd14cbbf346447e254ea5301e3..717487be28e5e4a28fe3036f6c020589ca6401d4 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -13,11 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from enum import Enum
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
     "media_repository_drop_index_wo_method"
 )
@@ -43,7 +46,7 @@ class MediaSortOrder(Enum):
 
 
 class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -123,7 +126,7 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
 class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     """Persistence for attachments and avatars"""
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
         self.server_name = hs.hostname
 
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index dac3d14da8e1d929269f7080dc97201cfac451c5..d901933ae4f2880634ad40290afbf4a628e0be56 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -14,7 +14,7 @@
 import calendar
 import logging
 import time
-from typing import Dict
+from typing import TYPE_CHECKING, Dict
 
 from synapse.metrics import GaugeBucketCollector
 from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -24,6 +24,9 @@ from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # Collect metrics on the number of forward extremities that exist.
@@ -52,7 +55,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
     stats and prometheus metrics.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         # Read the extrems every 60 minutes
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index a14ac03d4b6eeb9c901cd4f033f69e1e45fbd9cf..b5284e4f67838b1469d129600b78f6dbf2f1b857 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,13 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # Number of msec of granularity to store the monthly_active_user timestamp
@@ -27,7 +30,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
 
 class MonthlyActiveUsersWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
         self._clock = hs.get_clock()
         self.hs = hs
@@ -209,7 +212,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
 
 
 class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self._mau_stats_only = hs.config.server.mau_stats_only
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index fc720f59478bdebe985440557231989e96483ce9..fa782023d4eee92627709b26f0cc368e5afa3c57 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import abc
 import logging
-from typing import Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
 
 from synapse.api.errors import NotFoundError, StoreError
 from synapse.push.baserules import list_with_base_rules
@@ -33,6 +33,9 @@ from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -75,7 +78,7 @@ class PushRulesWorkerStore(
     `get_max_push_rules_stream_id` which can be called in the initializer.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 01a42813011af71e99ae13191a87660cc4eec4c6..c99f8aebdbdddbcd91064ec2a6e8a2ce11809d8d 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
 
 from twisted.internet import defer
 
@@ -29,11 +29,14 @@ from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class ReceiptsWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         self._instance_name = hs.get_instance_name()
 
         if isinstance(database.engine, PostgresEngine):
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 835d7889cbe91e1ebacf5d7f3a05354d771acb08..f879bbe7c720b3e4ed39d6c0e0fc59493bc31d85 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -17,7 +17,7 @@ import collections
 import logging
 from abc import abstractmethod
 from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import EventContentFields, EventTypes, JoinRules
 from synapse.api.errors import StoreError
@@ -32,6 +32,9 @@ from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.stringutils import MXC_REGEX
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -69,7 +72,7 @@ class RoomSortOrder(Enum):
 
 
 class RoomWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
@@ -1026,7 +1029,7 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
 
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
@@ -1411,7 +1414,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index ddb162a4fca13fe1033bc5bb845558e8d8e863bc..4b288bb2e772951dfe89503d03ccfd42e75d8ff5 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -53,6 +53,7 @@ from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
+    from synapse.server import HomeServer
     from synapse.state import _StateCacheEntry
 
 logger = logging.getLogger(__name__)
@@ -63,7 +64,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 
 class RoomMemberWorkerStore(EventsWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
@@ -982,7 +983,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
 
 class RoomMemberBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
         self.db_pool.updates.register_background_update_handler(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
@@ -1132,7 +1133,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
 
 
 class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
     async def forget(self, user_id: str, room_id: str) -> None:
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index c85383c97542cb45b083b808a87f064a7dff3858..7fe233767f763e97ef8fa168a9631c0214dc2e08 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -15,7 +15,7 @@
 import logging
 import re
 from collections import namedtuple
-from typing import Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
 
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
@@ -24,6 +24,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 SearchEntry = namedtuple(
@@ -102,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
     EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         if not hs.config.server.enable_search:
@@ -355,7 +358,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
 
 class SearchStore(SearchBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
     async def search_msgs(self, room_ids, search_term, keys):
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index a8e8dd4577c4df18499a8c2f689ca2e06957b469..fa2c3b1feb91fea6824bc3fc1d6a3b4bf9a77f08 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -15,7 +15,7 @@
 import collections.abc
 import logging
 from collections import namedtuple
-from typing import Iterable, Optional, Set
+from typing import TYPE_CHECKING, Iterable, Optional, Set
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -30,6 +30,9 @@ from synapse.types import StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -53,7 +56,7 @@ class _GetStateGroupDelta(
 class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     """The parts of StateGroupStore that can be called from workers."""
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
     async def get_room_version(self, room_id: str) -> RoomVersion:
@@ -346,7 +349,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
     EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
     DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
@@ -533,5 +536,5 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
       * `state_groups_state`: Maps state group to state events.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index e20033bb2840f295cd70b86d7caf87e5a39adf14..5d7b59d861c971606b51c27ade933f359ee7ef99 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@
 import logging
 from enum import Enum
 from itertools import chain
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 from typing_extensions import Counter
 
@@ -29,6 +29,9 @@ from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # these fields track absolutes (e.g. total number of rooms on the server)
@@ -93,7 +96,7 @@ class UserSortOrder(Enum):
 
 
 class StatsStore(StateDeltasStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 860146cd1bc9cd1c92f6357a6e229347bdd16505..d7dc1f73ac16237362007ff53c7fe66421d610b8 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -26,6 +26,9 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 db_binary_type = memoryview
 
 logger = logging.getLogger(__name__)
@@ -57,7 +60,7 @@ class DestinationRetryTimings:
 
 
 class TransactionWorkerStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 0e8270746d7870962c0b7be6edaad7ed1bfd5cac..402f134d894b1805ea967e6d826ed2d2d2b0cc0b 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -18,6 +18,7 @@ import itertools
 import logging
 from collections import deque
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Callable,
@@ -56,6 +57,9 @@ from synapse.types import (
 from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # The number of times we are recalculating the current state
@@ -272,7 +276,7 @@ class EventsPersistenceStorage:
     current state and forward extremity changes.
     """
 
-    def __init__(self, hs, stores: Databases):
+    def __init__(self, hs: "HomeServer", stores: Databases):
         # We ultimately want to split out the state store from the main store,
         # so we use separate variables here even though they point to the same
         # store for now.