Skip to content
Snippets Groups Projects
Commit 597c79be authored by Erik Johnston's avatar Erik Johnston
Browse files

Change the way we specify if we require auth or not

parent 32fc39fd
No related branches found
No related tags found
No related merge requests found
......@@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
from synapse.http.servlet import parse_json_object_from_request, parse_string
from synapse.http.servlet import parse_json_object_from_request
from synapse.util.ratelimitutils import FederationRateLimiter
import functools
......@@ -60,6 +60,16 @@ class TransportLayerServer(JsonResource):
)
class AuthenticationError(SynapseError):
"""There was a problem authenticating the request"""
pass
class NoAuthenticationError(AuthenticationError):
"""The request had no authentication information"""
pass
class Authenticator(object):
def __init__(self, hs):
self.keyring = hs.get_keyring()
......@@ -67,7 +77,7 @@ class Authenticator(object):
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
def authenticate_request(self, request):
def authenticate_request(self, request, content):
json_request = {
"method": request.method,
"uri": request.uri,
......@@ -75,17 +85,10 @@ class Authenticator(object):
"signatures": {},
}
content = None
origin = None
if content is not None:
json_request["content"] = content
if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types?
try:
content_bytes = request.content.read()
content = json.loads(content_bytes)
json_request["content"] = content
except:
raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
origin = None
def parse_auth_header(header_str):
try:
......@@ -103,14 +106,14 @@ class Authenticator(object):
sig = strip_quotes(param_dict["sig"])
return (origin, key, sig)
except:
raise SynapseError(
raise AuthenticationError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED
)
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
raise SynapseError(
raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
......@@ -121,7 +124,7 @@ class Authenticator(object):
json_request["signatures"].setdefault(origin, {})[key] = sig
if not json_request["signatures"]:
raise SynapseError(
raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
......@@ -130,10 +133,12 @@ class Authenticator(object):
logger.info("Request from %s", origin)
request.authenticated_entity = origin
defer.returnValue((origin, content))
defer.returnValue(origin)
class BaseFederationServlet(object):
REQUIRE_AUTH = True
def __init__(self, handler, authenticator, ratelimiter, server_name,
room_list_handler):
self.handler = handler
......@@ -141,29 +146,46 @@ class BaseFederationServlet(object):
self.ratelimiter = ratelimiter
self.room_list_handler = room_list_handler
def _wrap(self, code):
def _wrap(self, func):
authenticator = self.authenticator
ratelimiter = self.ratelimiter
@defer.inlineCallbacks
@functools.wraps(code)
def new_code(request, *args, **kwargs):
@functools.wraps(func)
def new_func(request, *args, **kwargs):
content = None
if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types?
content = parse_json_object_from_request(request)
try:
(origin, content) = yield authenticator.authenticate_request(request)
origin = yield authenticator.authenticate_request(request, content)
except NoAuthenticationError:
origin = None
if self.REQUIRE_AUTH:
logger.exception("authenticate_request failed")
raise
except:
logger.exception("authenticate_request failed")
raise
if origin:
with ratelimiter.ratelimit(origin) as d:
yield d
response = yield code(
response = yield func(
origin, content, request.args, *args, **kwargs
)
except:
logger.exception("authenticate_request failed")
raise
else:
response = yield func(
origin, content, request.args, *args, **kwargs
)
defer.returnValue(response)
# Extra logic that functools.wraps() doesn't finish
new_code.__self__ = code.__self__
new_func.__self__ = func.__self__
return new_code
return new_func
def register(self, server):
pattern = re.compile("^" + PREFIX + self.PATH + "$")
......@@ -429,9 +451,10 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
class On3pidBindServlet(BaseFederationServlet):
PATH = "/3pid/onbind"
REQUIRE_AUTH = False
@defer.inlineCallbacks
def on_POST(self, request):
content = parse_json_object_from_request(request)
def on_POST(self, origin, content, query):
if "invites" in content:
last_exception = None
for invite in content["invites"]:
......@@ -453,11 +476,6 @@ class On3pidBindServlet(BaseFederationServlet):
raise last_exception
defer.returnValue((200, {}))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
class OpenIdUserInfo(BaseFederationServlet):
"""
......@@ -478,9 +496,11 @@ class OpenIdUserInfo(BaseFederationServlet):
PATH = "/openid/userinfo"
REQUIRE_AUTH = False
@defer.inlineCallbacks
def on_GET(self, request):
token = parse_string(request, "access_token")
def on_GET(self, origin, content, query):
token = query.get("access_token", [None])[0]
if token is None:
defer.returnValue((401, {
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
......@@ -497,11 +517,6 @@ class OpenIdUserInfo(BaseFederationServlet):
defer.returnValue((200, {"sub": user_id}))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
class PublicRoomList(BaseFederationServlet):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment