From fd7c7434457e215d73873748604f430c52586498 Mon Sep 17 00:00:00 2001
From: Patrick Cloke <clokep@users.noreply.github.com>
Date: Fri, 30 Oct 2020 07:15:07 -0400
Subject: [PATCH] Fail test cases if they fail to await all awaitables (#8690)

---
 changelog.d/8690.misc        |  1 +
 tests/test_utils/__init__.py | 34 +++++++++++++++++++++++++++++++++-
 tests/unittest.py            |  6 +++++-
 3 files changed, 39 insertions(+), 2 deletions(-)
 create mode 100644 changelog.d/8690.misc

diff --git a/changelog.d/8690.misc b/changelog.d/8690.misc
new file mode 100644
index 0000000000..0f38ba1f5d
--- /dev/null
+++ b/changelog.d/8690.misc
@@ -0,0 +1 @@
+Fail tests if they do not await coroutines.
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index a298cc0fd3..d232b72264 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,8 +17,10 @@
 """
 Utilities for running the unit tests
 """
+import sys
+import warnings
 from asyncio import Future
-from typing import Any, Awaitable, TypeVar
+from typing import Any, Awaitable, Callable, TypeVar
 
 TV = TypeVar("TV")
 
@@ -48,3 +50,33 @@ def make_awaitable(result: Any) -> Awaitable[Any]:
     future = Future()  # type: ignore
     future.set_result(result)
     return future
+
+
+def setup_awaitable_errors() -> Callable[[], None]:
+    """
+    Convert warnings from a non-awaited coroutines into errors.
+    """
+    warnings.simplefilter("error", RuntimeWarning)
+
+    # unraisablehook was added in Python 3.8.
+    if not hasattr(sys, "unraisablehook"):
+        return lambda: None
+
+    # State shared between unraisablehook and check_for_unraisable_exceptions.
+    unraisable_exceptions = []
+    orig_unraisablehook = sys.unraisablehook  # type: ignore
+
+    def unraisablehook(unraisable):
+        unraisable_exceptions.append(unraisable.exc_value)
+
+    def cleanup():
+        """
+        A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
+        """
+        sys.unraisablehook = orig_unraisablehook  # type: ignore
+        if unraisable_exceptions:
+            raise unraisable_exceptions.pop()
+
+    sys.unraisablehook = unraisablehook  # type: ignore
+
+    return cleanup
diff --git a/tests/unittest.py b/tests/unittest.py
index 257f465897..08cf9b10c5 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -54,7 +54,7 @@ from tests.server import (
     render,
     setup_test_homeserver,
 )
-from tests.test_utils import event_injection
+from tests.test_utils import event_injection, setup_awaitable_errors
 from tests.test_utils.logging_setup import setup_logging
 from tests.utils import default_config, setupdb
 
@@ -119,6 +119,10 @@ class TestCase(unittest.TestCase):
 
                 logging.getLogger().setLevel(level)
 
+            # Trial messes with the warnings configuration, thus this has to be
+            # done in the context of an individual TestCase.
+            self.addCleanup(setup_awaitable_errors())
+
             return orig()
 
         @around(self)
-- 
GitLab