From d3523e3e9727f0a81f88e9aa58a8f0fc2b3ee260 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <richard@matrix.org>
Date: Fri, 13 Nov 2020 22:34:08 +0000
Subject: [PATCH] pass a Site into RestHelper

---
 tests/rest/client/v1/utils.py | 11 ++++++-----
 tests/unittest.py             |  2 +-
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index afaf9f7b85..dc789fbdaa 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -23,6 +23,7 @@ from typing import Any, Dict, Optional
 import attr
 
 from twisted.web.resource import Resource
+from twisted.web.server import Site
 
 from synapse.api.constants import Membership
 
@@ -36,7 +37,7 @@ class RestHelper:
     """
 
     hs = attr.ib()
-    resource = attr.ib()
+    site = attr.ib(type=Site)
     auth_user_id = attr.ib()
 
     def create_room_as(
@@ -54,7 +55,7 @@ class RestHelper:
         request, channel = make_request(
             self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8")
         )
-        render(request, self.resource, self.hs.get_reactor())
+        render(request, self.site.resource, self.hs.get_reactor())
 
         assert channel.result["code"] == b"%d" % expect_code, channel.result
         self.auth_user_id = temp_id
@@ -128,7 +129,7 @@ class RestHelper:
             self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8")
         )
 
-        render(request, self.resource, self.hs.get_reactor())
+        render(request, self.site.resource, self.hs.get_reactor())
 
         assert int(channel.result["code"]) == expect_code, (
             "Expected: %d, got: %d, resp: %r"
@@ -160,7 +161,7 @@ class RestHelper:
         request, channel = make_request(
             self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8")
         )
-        render(request, self.resource, self.hs.get_reactor())
+        render(request, self.site.resource, self.hs.get_reactor())
 
         assert int(channel.result["code"]) == expect_code, (
             "Expected: %d, got: %d, resp: %r"
@@ -212,7 +213,7 @@ class RestHelper:
 
         request, channel = make_request(self.hs.get_reactor(), method, path, content)
 
-        render(request, self.resource, self.hs.get_reactor())
+        render(request, self.site.resource, self.hs.get_reactor())
 
         assert int(channel.result["code"]) == expect_code, (
             "Expected: %d, got: %d, resp: %r"
diff --git a/tests/unittest.py b/tests/unittest.py
index e36ac89196..0a24c2f6b2 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -253,7 +253,7 @@ class HomeserverTestCase(TestCase):
 
         from tests.rest.client.v1.utils import RestHelper
 
-        self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
+        self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
 
         if hasattr(self, "user_id"):
             if self.hijack_auth:
-- 
GitLab