Skip to content
Snippets Groups Projects
Unverified Commit 8b408433 authored by Patrick Cloke's avatar Patrick Cloke Committed by GitHub
Browse files

Allow additional SSO properties to be passed to the client (#8413)

parent ceafb5a1
No related branches found
No related tags found
Loading
Support passing additional single sign-on parameters to the client.
......@@ -1748,6 +1748,14 @@ oidc_config:
#
#display_name_template: "{{ user.given_name }} {{ user.last_name }}"
# Jinja2 templates for extra attributes to send back to the client during
# login.
#
# Note that these are non-standard and clients will ignore them without modifications.
#
#extra_attributes:
#birthdate: "{{ user.birthdate }}"
# Enable CAS for registration and login.
......
......@@ -57,7 +57,7 @@ A custom mapping provider must specify the following methods:
- This method must return a string, which is the unique identifier for the
user. Commonly the ``sub`` claim of the response.
* `map_user_attributes(self, userinfo, token)`
- This method should be async.
- This method must be async.
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
information from.
......@@ -66,6 +66,18 @@ A custom mapping provider must specify the following methods:
- Returns a dictionary with two keys:
- localpart: A required string, used to generate the Matrix ID.
- displayname: An optional string, the display name for the user.
* `get_extra_attributes(self, userinfo, token)`
- This method must be async.
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
information from.
- `token` - A dictionary which includes information necessary to make
further requests to the OpenID provider.
- Returns a dictionary that is suitable to be serialized to JSON. This
will be returned as part of the response during a successful login.
Note that care should be taken to not overwrite any of the parameters
usually returned as part of the [login response](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-login).
### Default OpenID Mapping Provider
......
......@@ -243,6 +243,22 @@ for the room are in flight:
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/messages$
Additionally, the following endpoints should be included if Synapse is configured
to use SSO (you only need to include the ones for whichever SSO provider you're
using):
# OpenID Connect requests.
^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$
^/_synapse/oidc/callback$
# SAML requests.
^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$
^/_matrix/saml2/authn_response$
# CAS requests.
^/_matrix/client/(api/v1|r0|unstable)/login/(cas|sso)/redirect$
^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$
Note that a HTTP listener with `client` and `federation` resources must be
configured in the `worker_listeners` option in the worker config.
......
......@@ -204,6 +204,14 @@ class OIDCConfig(Config):
# If unset, no displayname will be set.
#
#display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
# Jinja2 templates for extra attributes to send back to the client during
# login.
#
# Note that these are non-standard and clients will ignore them without modifications.
#
#extra_attributes:
#birthdate: "{{{{ user.birthdate }}}}"
""".format(
mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
)
......@@ -137,6 +137,15 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
}
@attr.s(slots=True)
class SsoLoginExtraAttributes:
"""Data we track about SAML2 sessions"""
# time the session was created, in milliseconds
creation_time = attr.ib(type=int)
extra_attributes = attr.ib(type=JsonDict)
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
......@@ -239,6 +248,10 @@ class AuthHandler(BaseHandler):
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
# A mapping of user ID to extra attributes to include in the login
# response.
self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes]
async def validate_user_via_ui_auth(
self,
requester: Requester,
......@@ -1165,6 +1178,7 @@ class AuthHandler(BaseHandler):
registered_user_id: str,
request: SynapseRequest,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
):
"""Having figured out a mxid for this user, complete the HTTP request
......@@ -1173,6 +1187,8 @@ class AuthHandler(BaseHandler):
request: The request to complete.
client_redirect_url: The URL to which to redirect the user at the end of the
process.
extra_attributes: Extra attributes which will be passed to the client
during successful login. Must be JSON serializable.
"""
# If the account has been deactivated, do not proceed with the login
# flow.
......@@ -1181,19 +1197,30 @@ class AuthHandler(BaseHandler):
respond_with_html(request, 403, self._sso_account_deactivated_template)
return
self._complete_sso_login(registered_user_id, request, client_redirect_url)
self._complete_sso_login(
registered_user_id, request, client_redirect_url, extra_attributes
)
def _complete_sso_login(
self,
registered_user_id: str,
request: SynapseRequest,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
):
"""
The synchronous portion of complete_sso_login.
This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
"""
# Store any extra attributes which will be passed in the login response.
# Note that this is per-user so it may overwrite a previous value, this
# is considered OK since the newest SSO attributes should be most valid.
if extra_attributes:
self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes(
self._clock.time_msec(), extra_attributes,
)
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id
......@@ -1226,6 +1253,37 @@ class AuthHandler(BaseHandler):
)
respond_with_html(request, 200, html)
async def _sso_login_callback(self, login_result: JsonDict) -> None:
"""
A login callback which might add additional attributes to the login response.
Args:
login_result: The data to be sent to the client. Includes the user
ID and access token.
"""
# Expire attributes before processing. Note that there shouldn't be any
# valid logins that still have extra attributes.
self._expire_sso_extra_attributes()
extra_attributes = self._extra_attributes.get(login_result["user_id"])
if extra_attributes:
login_result.update(extra_attributes.extra_attributes)
def _expire_sso_extra_attributes(self) -> None:
"""
Iterate through the mapping of user IDs to extra attributes and remove any that are no longer valid.
"""
# TODO This should match the amount of time the macaroon is valid for.
LOGIN_TOKEN_EXPIRATION_TIME = 2 * 60 * 1000
expire_before = self._clock.time_msec() - LOGIN_TOKEN_EXPIRATION_TIME
to_expire = set()
for user_id, data in self._extra_attributes.items():
if data.creation_time < expire_before:
to_expire.add(user_id)
for user_id in to_expire:
logger.debug("Expiring extra attributes for user %s", user_id)
del self._extra_attributes[user_id]
@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
url_parts = list(urllib.parse.urlparse(url))
......
......@@ -37,7 +37,7 @@ from synapse.config import ConfigError
from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
if TYPE_CHECKING:
......@@ -707,6 +707,15 @@ class OidcHandler:
self._render_error(request, "mapping_error", str(e))
return
# Mapping providers might not have get_extra_attributes: only call this
# method if it exists.
extra_attributes = None
get_extra_attributes = getattr(
self._user_mapping_provider, "get_extra_attributes", None
)
if get_extra_attributes:
extra_attributes = await get_extra_attributes(userinfo, token)
# and finally complete the login
if ui_auth_session_id:
await self._auth_handler.complete_sso_ui_auth(
......@@ -714,7 +723,7 @@ class OidcHandler:
)
else:
await self._auth_handler.complete_sso_login(
user_id, request, client_redirect_url
user_id, request, client_redirect_url, extra_attributes
)
def _generate_oidc_session_token(
......@@ -984,7 +993,7 @@ class OidcMappingProvider(Generic[C]):
async def map_user_attributes(
self, userinfo: UserInfo, token: Token
) -> UserAttribute:
"""Map a ``UserInfo`` objects into user attributes.
"""Map a `UserInfo` object into user attributes.
Args:
userinfo: An object representing the user given by the OIDC provider
......@@ -995,6 +1004,18 @@ class OidcMappingProvider(Generic[C]):
"""
raise NotImplementedError()
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
"""Map a `UserInfo` object into additional attributes passed to the client during login.
Args:
userinfo: An object representing the user given by the OIDC provider
token: A dict with the tokens returned by the provider
Returns:
A dict containing additional attributes. Must be JSON serializable.
"""
return {}
# Used to clear out "None" values in templates
def jinja_finalize(thing):
......@@ -1009,6 +1030,7 @@ class JinjaOidcMappingConfig:
subject_claim = attr.ib() # type: str
localpart_template = attr.ib() # type: Template
display_name_template = attr.ib() # type: Optional[Template]
extra_attributes = attr.ib() # type: Dict[str, Template]
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
......@@ -1047,10 +1069,28 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
% (e,)
)
extra_attributes = {} # type Dict[str, Template]
if "extra_attributes" in config:
extra_attributes_config = config.get("extra_attributes") or {}
if not isinstance(extra_attributes_config, dict):
raise ConfigError(
"oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
)
for key, value in extra_attributes_config.items():
try:
extra_attributes[key] = env.from_string(value)
except Exception as e:
raise ConfigError(
"invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
% (key, e)
)
return JinjaOidcMappingConfig(
subject_claim=subject_claim,
localpart_template=localpart_template,
display_name_template=display_name_template,
extra_attributes=extra_attributes,
)
def get_remote_user_id(self, userinfo: UserInfo) -> str:
......@@ -1071,3 +1111,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
display_name = None
return UserAttribute(localpart=localpart, display_name=display_name)
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]
for key, template in self._config.extra_attributes.items():
try:
extras[key] = template.render(user=userinfo).strip()
except Exception as e:
# Log an error and skip this value (don't break login for this).
logger.error("Failed to render OIDC extra attribute %s: %s" % (key, e))
return extras
......@@ -284,9 +284,7 @@ class LoginRestServlet(RestServlet):
self,
user_id: str,
login_submission: JsonDict,
callback: Optional[
Callable[[Dict[str, str]], Awaitable[Dict[str, str]]]
] = None,
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
......@@ -299,12 +297,12 @@ class LoginRestServlet(RestServlet):
Args:
user_id: ID of the user to register.
login_submission: Dictionary of login information.
callback: Callback function to run after registration.
callback: Callback function to run after login.
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
Returns:
result: Dictionary of account information after successful registration.
result: Dictionary of account information after successful login.
"""
# Before we actually log them in we check if they've already logged in
......@@ -339,14 +337,24 @@ class LoginRestServlet(RestServlet):
return result
async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""
Handle the final stage of SSO login.
Args:
login_submission: The JSON request body.
Returns:
The body of the JSON response.
"""
token = login_submission["token"]
auth_handler = self.auth_handler
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
token
)
result = await self._complete_login(user_id, login_submission)
return result
return await self._complete_login(
user_id, login_submission, self.auth_handler._sso_login_callback
)
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
token = login_submission.get("token", None)
......
......@@ -21,7 +21,6 @@ from mock import Mock, patch
import attr
import pymacaroons
from twisted.internet import defer
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
......@@ -87,6 +86,13 @@ class TestMappingProvider(OidcMappingProvider):
async def map_user_attributes(self, userinfo, token):
return {"localpart": userinfo["username"], "display_name": None}
# Do not include get_extra_attributes to test backwards compatibility paths.
class TestMappingProviderExtra(TestMappingProvider):
async def get_extra_attributes(self, userinfo, token):
return {"phone": userinfo["phone"]}
def simple_async_mock(return_value=None, raises=None):
# AsyncMock is not available in python3.5, this mimics part of its behaviour
......@@ -126,7 +132,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
config = self.default_config()
config["public_baseurl"] = BASE_URL
oidc_config = config.get("oidc_config", {})
oidc_config = {}
oidc_config["enabled"] = True
oidc_config["client_id"] = CLIENT_ID
oidc_config["client_secret"] = CLIENT_SECRET
......@@ -135,6 +141,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
oidc_config["user_mapping_provider"] = {
"module": __name__ + ".TestMappingProvider",
}
# Update this config with what's in the default config so that
# override_config works as expected.
oidc_config.update(config.get("oidc_config", {}))
config["oidc_config"] = oidc_config
hs = self.setup_test_homeserver(
......@@ -165,11 +175,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {"discover": True}})
@defer.inlineCallbacks
def test_discovery(self):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
metadata = yield defer.ensureDeferred(self.handler.load_metadata())
metadata = self.get_success(self.handler.load_metadata())
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.assertEqual(metadata.issuer, ISSUER)
......@@ -181,43 +190,40 @@ class OidcHandlerTestCase(HomeserverTestCase):
# subsequent calls should be cached
self.http_client.reset_mock()
yield defer.ensureDeferred(self.handler.load_metadata())
self.get_success(self.handler.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
@defer.inlineCallbacks
def test_no_discovery(self):
"""When discovery is disabled, it should not try to load from discovery document."""
yield defer.ensureDeferred(self.handler.load_metadata())
self.get_success(self.handler.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
@defer.inlineCallbacks
def test_load_jwks(self):
"""JWKS loading is done once (then cached) if used."""
jwks = yield defer.ensureDeferred(self.handler.load_jwks())
jwks = self.get_success(self.handler.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI)
self.assertEqual(jwks, {"keys": []})
# subsequent calls should be cached…
self.http_client.reset_mock()
yield defer.ensureDeferred(self.handler.load_jwks())
self.get_success(self.handler.load_jwks())
self.http_client.get_json.assert_not_called()
# …unless forced
self.http_client.reset_mock()
yield defer.ensureDeferred(self.handler.load_jwks(force=True))
self.get_success(self.handler.load_jwks(force=True))
self.http_client.get_json.assert_called_once_with(JWKS_URI)
# Throw if the JWKS uri is missing
with self.metadata_edit({"jwks_uri": None}):
with self.assertRaises(RuntimeError):
yield defer.ensureDeferred(self.handler.load_jwks(force=True))
self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
# Return empty key set if JWKS are not used
self.handler._scopes = [] # not asking the openid scope
self.http_client.get_json.reset_mock()
jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True))
jwks = self.get_success(self.handler.load_jwks(force=True))
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})
......@@ -299,11 +305,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
# This should not throw
self.handler._validate_metadata()
@defer.inlineCallbacks
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["addCookie"])
url = yield defer.ensureDeferred(
url = self.get_success(
self.handler.handle_redirect_request(req, b"http://client/redirect")
)
url = urlparse(url)
......@@ -343,20 +348,18 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(params["nonce"], [nonce])
self.assertEqual(redirect, "http://client/redirect")
@defer.inlineCallbacks
def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed."""
self.handler._render_error = Mock()
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_client", "")
request.args[b"error_description"] = [b"some description"]
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_client", "some description")
@defer.inlineCallbacks
def test_callback(self):
"""Code callback works and display errors if something went wrong.
......@@ -377,7 +380,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "foo",
"preferred_username": "bar",
}
user_id = UserID("foo", "domain.org")
user_id = "@foo:domain.org"
self.handler._render_error = Mock(return_value=None)
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
......@@ -394,13 +397,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url = "http://client/redirect"
user_agent = "Browser"
ip_address = "10.0.0.1"
session = self.handler._generate_oidc_session_token(
request.getCookie.return_value = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
request.getCookie.return_value = session
request.args = {}
request.args[b"code"] = [code.encode("utf-8")]
......@@ -410,10 +412,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
request.getClientIP.return_value = ip_address
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
user_id, request, client_redirect_url,
user_id, request, client_redirect_url, {},
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
......@@ -427,13 +429,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._map_userinfo_to_user = simple_async_mock(
raises=MappingException()
)
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error")
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
# Handle ID token errors
self.handler._parse_id_token = simple_async_mock(raises=Exception())
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
self.handler._auth_handler.complete_sso_login.reset_mock()
......@@ -444,10 +446,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
# With userinfo fetching
self.handler._scopes = [] # do not ask the "openid" scope
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
user_id, request, client_redirect_url,
user_id, request, client_redirect_url, {},
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
......@@ -459,17 +461,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Handle userinfo fetching error
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
# Handle code exchange failure
self.handler._exchange_code = simple_async_mock(
raises=OidcError("invalid_request")
)
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
@defer.inlineCallbacks
def test_callback_session(self):
"""The callback verifies the session presence and validity"""
self.handler._render_error = Mock(return_value=None)
......@@ -478,20 +479,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Missing cookie
request.args = {}
request.getCookie.return_value = None
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("missing_session", "No session cookie found")
# Missing session parameter
request.args = {}
request.getCookie.return_value = "session"
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request", "State parameter is missing")
# Invalid cookie
request.args = {}
request.args[b"state"] = [b"state"]
request.getCookie.return_value = "session"
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_session")
# Mismatching session
......@@ -504,18 +505,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args = {}
request.args[b"state"] = [b"mismatching state"]
request.getCookie.return_value = session
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mismatching_session")
# Valid session
request.args = {}
request.args[b"state"] = [b"state"]
request.getCookie.return_value = session
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
@override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
@defer.inlineCallbacks
def test_exchange_code(self):
"""Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"}
......@@ -524,7 +524,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
)
code = "code"
ret = yield defer.ensureDeferred(self.handler._exchange_code(code))
ret = self.get_success(self.handler._exchange_code(code))
kwargs = self.http_client.request.call_args[1]
self.assertEqual(ret, token)
......@@ -546,10 +546,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
body=b'{"error": "foo", "error_description": "bar"}',
)
)
with self.assertRaises(OidcError) as exc:
yield defer.ensureDeferred(self.handler._exchange_code(code))
self.assertEqual(exc.exception.error, "foo")
self.assertEqual(exc.exception.error_description, "bar")
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "foo")
self.assertEqual(exc.value.error_description, "bar")
# Internal server error with no JSON body
self.http_client.request = simple_async_mock(
......@@ -557,9 +556,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=500, phrase=b"Internal Server Error", body=b"Not JSON",
)
)
with self.assertRaises(OidcError) as exc:
yield defer.ensureDeferred(self.handler._exchange_code(code))
self.assertEqual(exc.exception.error, "server_error")
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# Internal server error with JSON body
self.http_client.request = simple_async_mock(
......@@ -569,17 +567,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
body=b'{"error": "internal_server_error"}',
)
)
with self.assertRaises(OidcError) as exc:
yield defer.ensureDeferred(self.handler._exchange_code(code))
self.assertEqual(exc.exception.error, "internal_server_error")
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "internal_server_error")
# 4xx error without "error" field
self.http_client.request = simple_async_mock(
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
)
with self.assertRaises(OidcError) as exc:
yield defer.ensureDeferred(self.handler._exchange_code(code))
self.assertEqual(exc.exception.error, "server_error")
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# 2xx error with "error" field
self.http_client.request = simple_async_mock(
......@@ -587,9 +584,62 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=200, phrase=b"OK", body=b'{"error": "some_error"}',
)
)
with self.assertRaises(OidcError) as exc:
yield defer.ensureDeferred(self.handler._exchange_code(code))
self.assertEqual(exc.exception.error, "some_error")
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "some_error")
@override_config(
{
"oidc_config": {
"user_mapping_provider": {
"module": __name__ + ".TestMappingProviderExtra"
}
}
}
)
def test_extra_attributes(self):
"""
Login while using a mapping provider that implements get_extra_attributes.
"""
token = {
"type": "bearer",
"id_token": "id_token",
"access_token": "access_token",
}
userinfo = {
"sub": "foo",
"phone": "1234567",
}
user_id = "@foo:domain.org"
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
request = Mock(
spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
)
state = "state"
client_redirect_url = "http://client/redirect"
request.getCookie.return_value = self.handler._generate_oidc_session_token(
state=state,
nonce="nonce",
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
request.args = {}
request.args[b"code"] = [b"code"]
request.args[b"state"] = [state.encode("utf-8")]
request.requestHeaders = Mock(spec=["getRawHeaders"])
request.requestHeaders.getRawHeaders.return_value = [b"Browser"]
request.getClientIP.return_value = "10.0.0.1"
self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
user_id, request, client_redirect_url, {"phone": "1234567"},
)
def test_map_userinfo_to_user(self):
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment