Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Clean up caching/locking of OIDC metadata load #9362

Merged
merged 6 commits into from
Feb 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9362.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Clean up the code to load the metadata for OpenID Connect identity providers.
89 changes: 53 additions & 36 deletions synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -245,6 +246,7 @@ def __init__(

self._token_generator = token_generator

self._config = provider
self._callback_url = hs.config.oidc_callback_url # type: str

self._scopes = provider.scopes
Expand All @@ -253,14 +255,16 @@ def __init__(
provider.client_id, provider.client_secret, provider.client_auth_method,
) # type: ClientAuth
self._client_auth_method = provider.client_auth_method
self._provider_metadata = OpenIDProviderMetadata(
issuer=provider.issuer,
authorization_endpoint=provider.authorization_endpoint,
token_endpoint=provider.token_endpoint,
userinfo_endpoint=provider.userinfo_endpoint,
jwks_uri=provider.jwks_uri,
) # type: OpenIDProviderMetadata
self._provider_needs_discovery = provider.discover

# cache of metadata for the identity provider (endpoint uris, mostly). This is
# loaded on-demand from the discovery endpoint (if discovery is enabled), with
# possible overrides from the config. Access via `load_metadata`.
self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)

# cache of JWKs used by the identity provider to sign tokens. Loaded on demand
# from the IdP's jwks_uri, if required.
self._jwks = RetryOnExceptionCachedCall(self._load_jwks)

self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config
)
Expand All @@ -286,7 +290,7 @@ def __init__(

self._sso_handler.register_identity_provider(self)

def _validate_metadata(self):
def _validate_metadata(self, m: OpenIDProviderMetadata) -> None:
"""Verifies the provider metadata.

This checks the validity of the currently loaded provider. Not
Expand All @@ -305,7 +309,6 @@ def _validate_metadata(self):
if self._skip_verification is True:
return

m = self._provider_metadata
m.validate_issuer()
m.validate_authorization_endpoint()
m.validate_token_endpoint()
Expand Down Expand Up @@ -340,11 +343,7 @@ def _validate_metadata(self):
)
else:
# If we're not using userinfo, we need a valid jwks to validate the ID token
if m.get("jwks") is None:
if m.get("jwks_uri") is not None:
m.validate_jwks_uri()
else:
raise ValueError('"jwks_uri" must be set')
m.validate_jwks_uri()

@property
def _uses_userinfo(self) -> bool:
Expand All @@ -361,30 +360,48 @@ def _uses_userinfo(self) -> bool:
or self._user_profile_method == "userinfo_endpoint"
)

async def load_metadata(self) -> OpenIDProviderMetadata:
"""Load and validate the provider metadata.
async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
"""Return the provider metadata.

If this is the first call, the metadata is built from the config and from the
metadata discovery endpoint (if enabled), and then validated. If the metadata
is successfully validated, it is then cached for future use.

The values metadatas are discovered if ``oidc_config.discovery`` is
``True`` and then cached.
Args:
force: If true, any cached metadata is discarded to force a reload.

Raises:
ValueError: if something in the provider is not valid

Returns:
The provider's metadata.
"""
# If we are using the OpenID Discovery documents, it needs to be loaded once
# FIXME: should there be a lock here?
if self._provider_needs_discovery:
url = get_well_known_url(self._provider_metadata["issuer"], external=True)
if force:
# reset the cached call to ensure we get a new result
self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)

return await self._provider_metadata.get()

async def _load_metadata(self) -> OpenIDProviderMetadata:
# init the metadata from our config
metadata = OpenIDProviderMetadata(
issuer=self._config.issuer,
authorization_endpoint=self._config.authorization_endpoint,
token_endpoint=self._config.token_endpoint,
userinfo_endpoint=self._config.userinfo_endpoint,
jwks_uri=self._config.jwks_uri,
)

# load any data from the discovery endpoint, if enabled
if self._config.discover:
url = get_well_known_url(self._config.issuer, external=True)
metadata_response = await self._http_client.get_json(url)
# TODO: maybe update the other way around to let user override some values?
self._provider_metadata.update(metadata_response)
self._provider_needs_discovery = False
metadata.update(metadata_response)

self._validate_metadata()
self._validate_metadata(metadata)

return self._provider_metadata
return metadata

async def load_jwks(self, force: bool = False) -> JWKS:
"""Load the JSON Web Key Set used to sign ID tokens.
Expand Down Expand Up @@ -414,27 +431,27 @@ async def load_jwks(self, force: bool = False) -> JWKS:
]
}
"""
if force:
# reset the cached call to ensure we get a new result
self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
return await self._jwks.get()

async def _load_jwks(self) -> JWKS:
if self._uses_userinfo:
# We're not using jwt signing, return an empty jwk set
return {"keys": []}

# First check if the JWKS are loaded in the provider metadata.
# It can happen either if the provider gives its JWKS in the discovery
# document directly or if it was already loaded once.
metadata = await self.load_metadata()
jwk_set = metadata.get("jwks")
if jwk_set is not None and not force:
return jwk_set

# Loading the JWKS using the `jwks_uri` metadata
# Load the JWKS using the `jwks_uri` metadata.
uri = metadata.get("jwks_uri")
if not uri:
# this should be unreachable: load_metadata validates that
# there is a jwks_uri in the metadata if _uses_userinfo is unset
raise RuntimeError('Missing "jwks_uri" in metadata')

jwk_set = await self._http_client.get_json(uri)

# Caching the JWKS in the provider's metadata
self._provider_metadata["jwks"] = jwk_set
return jwk_set

async def _exchange_code(self, code: str) -> Token:
Expand Down
129 changes: 129 additions & 0 deletions synapse/util/caches/cached_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union

from twisted.internet.defer import Deferred
from twisted.python.failure import Failure

from synapse.logging.context import make_deferred_yieldable, run_in_background

TV = TypeVar("TV")


class CachedCall(Generic[TV]):
"""A wrapper for asynchronous calls whose results should be shared

This is useful for wrapping asynchronous functions, where there might be multiple
callers, but we only want to call the underlying function once (and have the result
returned to all callers).

Similar results can be achieved via a lock of some form, but that typically requires
more boilerplate (and ends up being less efficient).

Correctly handles Synapse logcontexts (logs and resource usage for the underlying
function are logged against the logcontext which is active when get() is first
called).

Example usage:

_cached_val = CachedCall(_load_prop)

async def handle_request() -> X:
# We can call this multiple times, but it will result in a single call to
# _load_prop().
return await _cached_val.get()

async def _load_prop() -> X:
await difficult_operation()


The implementation is deliberately single-shot (ie, once the call is initiated,
there is no way to ask for it to be run). This keeps the implementation and
semantics simple. If you want to make a new call, simply replace the whole
CachedCall object.
"""

__slots__ = ["_callable", "_deferred", "_result"]

def __init__(self, f: Callable[[], Awaitable[TV]]):
"""
Args:
f: The underlying function. Only one call to this function will be alive
at once (per instance of CachedCall)
"""
self._callable = f # type: Optional[Callable[[], Awaitable[TV]]]
self._deferred = None # type: Optional[Deferred]
self._result = None # type: Union[None, Failure, TV]

async def get(self) -> TV:
"""Kick off the call if necessary, and return the result"""

# Fire off the callable now if this is our first time
if not self._deferred:
self._deferred = run_in_background(self._callable)

# we will never need the callable again, so make sure it can be GCed
self._callable = None

# once the deferred completes, store the result. We cannot simply leave the
# result in the deferred, since if it's a Failure, GCing the deferred
# would then log a critical error about unhandled Failures.
def got_result(r):
self._result = r

self._deferred.addBoth(got_result)

# TODO: consider cancellation semantics. Currently, if the call to get()
# is cancelled, the underlying call will continue (and any future calls
# will get the result/exception), which I think is *probably* ok, modulo
# the fact the underlying call may be logged to a cancelled logcontext,
# and any eventual exception may not be reported.

# we can now await the deferred, and once it completes, return the result.
await make_deferred_yieldable(self._deferred)

# I *think* this is the easiest way to correctly raise a Failure without having
# to gut-wrench into the implementation of Deferred.
d = Deferred()
d.callback(self._result)
return await d


class RetryOnExceptionCachedCall(Generic[TV]):
"""A wrapper around CachedCall which will retry the call if an exception is thrown

This is used in much the same way as CachedCall, but adds some extra functionality
so that if the underlying function throws an exception, then the next call to get()
will initiate another call to the underlying function. (Any calls to get() which
are already pending will raise the exception.)
"""

slots = ["_cachedcall"]

def __init__(self, f: Callable[[], Awaitable[TV]]):
async def _wrapper() -> TV:
try:
return await f()
except Exception:
# the call raised an exception: replace the underlying CachedCall to
# trigger another call next time get() is called
self._cachedcall = CachedCall(_wrapper)
raise

self._cachedcall = CachedCall(_wrapper)

async def get(self) -> TV:
return await self._cachedcall.get()
Loading