Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Land support for multiple OIDC providers #9110

Merged
merged 9 commits into from
Jan 15, 2021
57 changes: 43 additions & 14 deletions synapse/config/oidc_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

import string
from typing import Optional, Type
from typing import Iterable, Optional, Type

import attr

Expand All @@ -33,16 +33,8 @@ class OIDCConfig(Config):
section = "oidc"

def read_config(self, config, **kwargs):
validate_config(MAIN_CONFIG_SCHEMA, config, ())

self.oidc_provider = None # type: Optional[OidcProviderConfig]

oidc_config = config.get("oidc_config")
if oidc_config and oidc_config.get("enabled", False):
validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, "oidc_config")
self.oidc_provider = _parse_oidc_config_dict(oidc_config)

if not self.oidc_provider:
self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
if not self.oidc_providers:
return

try:
Expand All @@ -58,7 +50,7 @@ def read_config(self, config, **kwargs):
@property
def oidc_enabled(self) -> bool:
# OIDC is enabled if we have a provider
return bool(self.oidc_provider)
return bool(self.oidc_providers)

def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
Expand Down Expand Up @@ -234,7 +226,22 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs):
},
}

# the `oidc_config` setting can either be None (as it is in the default
# the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name
OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = {
"allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}]
}


# the `oidc_providers` list can either be None (as it is in the default config), or
# a list of provider configs, each of which requires an explicit ID and name.
OIDC_PROVIDER_LIST_SCHEMA = {
"oneOf": [
{"type": "null"},
{"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA},
]
}

# the `oidc_config` setting can either be None (which it used to be in the default
# config), or an object. If an object, it is ignored unless it has an "enabled: True"
# property.
#
Expand All @@ -243,12 +250,34 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs):
# additional checks in the code.
OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]}

# the top-level schema can contain an "oidc_config" and/or an "oidc_providers".
MAIN_CONFIG_SCHEMA = {
"type": "object",
"properties": {"oidc_config": OIDC_CONFIG_SCHEMA},
"properties": {
"oidc_config": OIDC_CONFIG_SCHEMA,
"oidc_providers": OIDC_PROVIDER_LIST_SCHEMA,
},
}


def _parse_oidc_provider_configs(config: JsonDict,) -> Iterable["OidcProviderConfig"]:
"""extract and parse the OIDC provider configs from the config dict

Returns a generator which yields the OidcProviderConfig objects
clokep marked this conversation as resolved.
Show resolved Hide resolved
"""
validate_config(MAIN_CONFIG_SCHEMA, config, ())

for p in config.get("oidc_providers") or []:
yield _parse_oidc_config_dict(p)

# for backwards-compatibility, it is also possible to provide a single "oidc_config"
# object with an "enabled: True" property.
oidc_config = config.get("oidc_config")
if oidc_config and oidc_config.get("enabled", False):
validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
clokep marked this conversation as resolved.
Show resolved Hide resolved
yield _parse_oidc_config_dict(oidc_config)


def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
"""Take the configuration dict and parse it into an OidcProviderConfig

Expand Down
27 changes: 20 additions & 7 deletions synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,28 @@ class OidcHandler:
def __init__(self, hs: "HomeServer"):
self._sso_handler = hs.get_sso_handler()

provider_conf = hs.config.oidc.oidc_provider
provider_confs = hs.config.oidc.oidc_providers
# we should not have been instantiated if there is no configured provider.
assert provider_conf is not None
assert provider_confs

self._token_generator = OidcSessionTokenGenerator(hs)

self._provider = OidcProvider(hs, self._token_generator, provider_conf)
self._providers = {
p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
}

async def load_metadata(self) -> None:
"""Validate the config and load the metadata from the remote endpoint.

Called at startup to ensure we have everything we need.
"""
await self._provider.load_metadata()
await self._provider.load_jwks()
for idp_id, p in self._providers.items():
try:
await p.load_metadata()
await p.load_jwks()
except Exception as e:
raise Exception(
"Error while initialising OIDC provider %r" % (idp_id,)
) from e

async def handle_oidc_callback(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
Expand Down Expand Up @@ -184,6 +191,12 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
self._sso_handler.render_error(request, "mismatching_session", str(e))
return

oidc_provider = self._providers.get(session_data.idp_id)
if not oidc_provider:
logger.error("OIDC session uses unknown IdP %r", oidc_provider)
self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
return

if b"code" not in request.args:
logger.info("Code parameter is missing")
self._sso_handler.render_error(
Expand All @@ -193,7 +206,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:

code = request.args[b"code"][0].decode()

await self._provider.handle_oidc_callback(request, session_data, code)
await oidc_provider.handle_oidc_callback(request, session_data, code)


class OidcError(Exception):
Expand Down
4 changes: 2 additions & 2 deletions tests/handlers/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)

self.handler = hs.get_oidc_handler()
self.provider = self.handler._provider
self.provider = self.handler._providers["oidc"]
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
Expand Down Expand Up @@ -982,7 +982,7 @@ async def _make_callback_with_userinfo(
from synapse.handlers.oidc_handler import OidcSessionData

handler = hs.get_oidc_handler()
provider = handler._provider
provider = handler._providers["oidc"]
provider._exchange_code = simple_async_mock(return_value={})
provider._parse_id_token = simple_async_mock(return_value=userinfo)
provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
Expand Down