Skip to content

Commit

Permalink
Refactor OAuthDeviceCode to support non-Entra IdPs (#1892)
Browse files Browse the repository at this point in the history
Co-authored-by: anders-albert <[email protected]>
  • Loading branch information
gregertw and doctrino authored Sep 23, 2024
1 parent 050f53a commit d9631cb
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 30 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ Changes are grouped as follows
- `Fixed` for any bug fixes.
- `Security` in case of vulnerabilities.

## [7.62.1] - 2024-09-23
### Changed
- `OAuthDeviceCode` now supports non-Entra IdPs.

## [7.62.0] - 2024-09-19
### Added
- All `update` methods now accept a new parameter `mode` that controls how non-update objects should be
Expand Down
2 changes: 1 addition & 1 deletion cognite/client/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import annotations

__version__ = "7.62.0"
__version__ = "7.62.1"
__api_subversion__ = "20230101"
263 changes: 237 additions & 26 deletions cognite/client/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import atexit
import inspect
import json
import tempfile
import threading
import time
from abc import abstractmethod
from datetime import datetime
from pathlib import Path
from types import MappingProxyType
from typing import Any, Callable, Protocol, runtime_checkable
Expand All @@ -15,7 +17,7 @@
from requests_oauthlib import OAuth2Session

from cognite.client.exceptions import CogniteAuthError
from cognite.client.utils._auxiliary import load_resource_to_dict
from cognite.client.utils._auxiliary import at_least_one_is_not_none, exactly_one_is_not_none, load_resource_to_dict

_TOKEN_EXPIRY_LEEWAY_SECONDS_DEFAULT = 30 # Do not change without also updating all the docstrings using it

Expand Down Expand Up @@ -192,12 +194,15 @@ def authorization_header(self) -> tuple[str, str]:

class _WithMsalSerializableTokenCache:
@staticmethod
def _create_serializable_token_cache(cache_path: Path) -> SerializableTokenCache:
def _create_serializable_token_cache(cache_path: Path, clear_cache: bool = False) -> SerializableTokenCache:
token_cache = SerializableTokenCache()

if cache_path.exists():
with cache_path.open() as fh:
token_cache.deserialize(fh.read())
if clear_cache:
cache_path.unlink(missing_ok=True)
else:
with cache_path.open() as fh:
token_cache.deserialize(fh.read())

def __at_exit() -> None:
if token_cache.has_state_changed:
Expand All @@ -211,29 +216,61 @@ def __at_exit() -> None:
def _resolve_token_cache_path(token_cache_path: Path | None, client_id: str) -> Path:
return token_cache_path or Path(tempfile.gettempdir()) / f"cognitetokencache.{client_id}.bin"

def _create_client_app(
    self,
    token_cache_path: Path,
    client_id: str,
    authority_url: str | None = None,
    oauth_discovery_url: str | None = None,
    clear_cache: bool = False,
    mem_cache_only: bool = False,
) -> PublicClientApplication:
    """Build the MSAL public client application used to acquire tokens.

    Args:
        token_cache_path (Path): Location of the on-disk token cache (ignored when mem_cache_only is True).
        client_id (str): The application's client id.
        authority_url (str | None): MS Entra authority url. Mutually exclusive with oauth_discovery_url.
        oauth_discovery_url (str | None): OIDC discovery url for other IdPs. Mutually exclusive with authority_url.
        clear_cache (bool): If True, an existing on-disk cache file is removed instead of loaded.
        mem_cache_only (bool): If True, use a fresh in-memory cache and never touch the on-disk cache.

    Returns:
        PublicClientApplication: The configured client application.

    Raises:
        ValueError: If both authority_url and oauth_discovery_url are given.
    """
    from cognite.client.config import global_config

    if authority_url and oauth_discovery_url:
        raise ValueError(
            "Only one of 'authority_url' (for MS Entra) or 'oauth_discovery_url' (for other IdPs) should be provided."
        )

    # In addition to caching in memory, we also cache the token on disk so it can be reused across
    # processes - unless the caller explicitly asked for a memory-only cache. The disk-backed cache
    # must only be created in that branch, otherwise mem_cache_only would still read the cache file.
    if mem_cache_only:
        serializable_token_cache = SerializableTokenCache()
    else:
        serializable_token_cache = self._create_serializable_token_cache(token_cache_path, clear_cache)

    return PublicClientApplication(
        client_id=client_id,
        authority=authority_url,
        token_cache=serializable_token_cache,
        verify=not global_config.disable_ssl,
        oidc_authority=oauth_discovery_url,
        # These two must be set to `False` to support non-Entra authorities.
        instance_discovery=False,
        validate_authority=False,
    )

@staticmethod
def _get_cached_token(cache_path: Path) -> dict[str, Any]:
if not cache_path.exists():
return {}
token = json.loads(cache_path.read_text())
return token


class OAuthDeviceCode(_OAuthCredentialProviderWithTokenRefresh, _WithMsalSerializableTokenCache):
"""OAuth credential provider for the device code login flow.
Args:
authority_url (str): OAuth authority url
client_id (str): Your application's client id.
scopes (list[str]): A list of scopes.
authority_url (str | None): MS Entra OAuth authority url, typically "https://login.microsoftonline.com/{tenant_id}"
client_id (str): Your application's client id that allows device code flows.
scopes (list[str] | None): A list of scopes.
cdf_cluster (str | None): The CDF cluster where the CDF project is located. If provided, scopes will be set to
[f"https://{cdf_cluster}.cognitedata.com/IDENTITY https://{cdf_cluster}.cognitedata.com/user_impersonation openid profile"].
oauth_discovery_url (str | None): Standard OAuth discovery URL, should be where "/.well-known/openid-configuration" is found.
token_cache_path (Path | None): Location to store token cache, defaults to os temp directory/cognitetokencache.{client_id}.bin.
token_expiry_leeway_seconds (int): The token is refreshed at the earliest when this number of seconds is left before expiry. Default: 30 sec
clear_cache (bool): If True, the token cache will be cleared on initialization. Default: False
mem_cache_only (bool): If True, the token cache will only be stored in memory. Default: False
**token_custom_args (Any): Additional request parameters to pass to the authorization endpoint.
Examples:
>>> from cognite.client.credentials import OAuthDeviceCode
Expand All @@ -242,23 +279,59 @@ class OAuthDeviceCode(_OAuthCredentialProviderWithTokenRefresh, _WithMsalSeriali
... client_id="abcd",
... scopes=["https://greenfield.cognitedata.com/.default"],
... )
Create credentials with auth0
>>> from cognite.client.credentials import OAuthDeviceCode
>>> oauth_provider = OAuthDeviceCode(
... authority_url=None,
... oauth_discovery_url="https://my-tenant.auth0.com/oauth",
... client_id="abcd",
... scopes=["IDENTITY", "user_impersonation"],
... )
"""

def __init__(
    self,
    authority_url: str | None,
    client_id: str,
    scopes: list[str] | None = None,
    cdf_cluster: str | None = None,
    oauth_discovery_url: str | None = None,
    token_cache_path: Path | None = None,
    token_expiry_leeway_seconds: int = _TOKEN_EXPIRY_LEEWAY_SECONDS_DEFAULT,
    clear_cache: bool = False,
    mem_cache_only: bool = False,
    **token_custom_args: Any,
) -> None:
    """Validate the configuration and create the underlying MSAL client application.

    Raises:
        ValueError: If not exactly one of 'authority_url' and 'oauth_discovery_url' is given, if neither
            'scopes' nor 'cdf_cluster' can be used to determine the scopes, or if 'client_id' is empty.
    """
    super().__init__(token_expiry_leeway_seconds)
    if not exactly_one_is_not_none(authority_url, oauth_discovery_url):
        raise ValueError("Either 'authority_url' or 'oauth_discovery_url' must be provided, and not both.")
    if not at_least_one_is_not_none(scopes, cdf_cluster):
        raise ValueError("Either 'scopes' or 'cdf_cluster' must be provided.")
    if not scopes and cdf_cluster is None:
        # An empty scope list falls through to the cluster-based defaults below; without a cluster
        # those would be built from the literal string "None".
        raise ValueError("'scopes' must not be empty when 'cdf_cluster' is not provided.")
    if not client_id:
        raise ValueError("'client_id' must be provided.")

    self.__authority_url = authority_url
    self.__oauth_discovery_url = oauth_discovery_url
    self.__client_id = client_id
    # Explicit scopes win; otherwise fall back to the default scopes for the given cluster.
    self.__scopes = scopes or [
        f"https://{cdf_cluster}.cognitedata.com/IDENTITY",
        f"https://{cdf_cluster}.cognitedata.com/user_impersonation",
        "openid",
        "profile",
    ]
    self.__mem_cache_only = mem_cache_only
    # Extra request parameters that are sent along with the device authorization request.
    self.__token_custom_args = token_custom_args

    self._token_cache_path = self._resolve_token_cache_path(token_cache_path, client_id)
    self.__app = self._create_client_app(
        self._token_cache_path,
        client_id,
        authority_url,
        oauth_discovery_url,
        clear_cache,
        mem_cache_only,
    )

def __getstate__(self) -> dict[str, Any]:
# PublicClientApplication is not picklable, temporarily remove:
Expand All @@ -269,12 +342,23 @@ def __getstate__(self) -> dict[str, Any]:

def __setstate__(self, state: dict[str, Any]) -> None:
    """Restore pickled state and re-create the client application.

    The client application is rebuilt exactly once, from the restored settings. The token cache is
    never cleared on unpickling (clear_cache=False).
    """
    super().__setstate__(state)
    self.__app = self._create_client_app(
        token_cache_path=self._token_cache_path,
        client_id=self.__client_id,
        authority_url=self.__authority_url,
        oauth_discovery_url=self.__oauth_discovery_url,
        clear_cache=False,
        mem_cache_only=self.__mem_cache_only,
    )

@property
def authority_url(self) -> str | None:
    """The stored authority url; the return annotation allows None."""
    return self.__authority_url

@property
def oauth_discovery_url(self) -> str | None:
    """The stored OAuth discovery url; the return annotation allows None."""
    return self.__oauth_discovery_url

@property
def client_id(self) -> str:
    """The application's client id."""
    return self.__client_id
Expand All @@ -283,22 +367,100 @@ def client_id(self) -> str:
def scopes(self) -> list[str]:
return self.__scopes

def scope_string(self) -> str:
    """Return all scopes joined into a single space-separated string."""
    return " ".join(self.__scopes)

def _get_token(self, convert_timestamps: bool = True) -> dict[str, Any]:
    """Return the content of the on-disk token cache as a dictionary.

    If the cache file does not exist yet, any changed in-memory cache state is written to disk
    first, and the file is then read back.

    Args:
        convert_timestamps (bool): If True, the 'expires_on', 'extended_expires_on' and 'cached_at' fields of
            every access token entry are converted from epoch seconds to ISO 8601 strings (naive local time).

    Returns:
        dict[str, Any]: The parsed token cache as returned by _get_cached_token.
    """
    if not self._token_cache_path.exists():
        # Look the app up by its name-mangled attribute on the defining class. Building the name from
        # type(self).__name__ silently misses the attribute for subclasses. A default of None keeps
        # this safe when the app attribute is not set.
        app = getattr(self, "_OAuthDeviceCode__app", None)
        if app and app.token_cache.has_state_changed:
            with open(self._token_cache_path, "w+") as fh:
                fh.write(app.token_cache.serialize())
    token = self._get_cached_token(self._token_cache_path)

    if convert_timestamps:
        for entry in token.get("AccessToken", {}).values():
            for subkey in ("expires_on", "extended_expires_on", "cached_at"):
                if subkey in entry:
                    entry[subkey] = datetime.fromtimestamp(int(entry[subkey])).isoformat()
    return token

def _refresh_access_token(self) -> tuple[str, float]:
    """Obtain an access token, preferring cached credentials over a new device code login.

    Order of attempts:
        1. A cached refresh token whose 'expires_on' lies in the future is redeemed for a new access token.
        2. Otherwise, a cached access token that has not yet expired is reused.
        3. Otherwise, the device code flow is started and login instructions are printed to stdout.

    Returns:
        tuple[str, float]: The access token and its expiry time as a Unix timestamp.

    Raises:
        CogniteAuthError: If the device authorization request fails or its response contains neither
            a verification uri nor a message.
    """
    credentials = None
    token_cache = self.__app.token_cache

    for token in token_cache.search(token_cache.CredentialType.REFRESH_TOKEN):
        # 'expires_on' may be stored as a string (the access token lookup below converts it too),
        # so convert before comparing with the current time.
        if "expires_on" in token and float(token["expires_on"]) > time.time():
            credentials = token
            break

    if credentials is not None:
        credentials = self.__app.client.obtain_token_by_refresh_token(credentials.get("secret", ""))
    else:
        for token in token_cache.search(token_cache.CredentialType.ACCESS_TOKEN):
            # The parentheses matter: without them the walrus binds the boolean result of the
            # comparison and 'expires_in' ends up as True instead of the remaining seconds.
            if (expiry := int(token.get("expires_on", 0)) - time.time()) > 0:
                credentials = {
                    "access_token": token.get("secret"),
                    "expires_in": expiry,
                }
                break

    # If we're unable to find (or acquire a new) access token, we initiate the device code auth flow.
    # The msal device_code flow does not support setting the audience, so we need to handle it manually.
    # We use the http client instantiated as part of the msal client, as well as the details found
    # in oauth discovery.
    if credentials is None:
        data = {
            "scope": self.scope_string(),
            "client_id": self.client_id,
        }
        data.update(self.__token_custom_args)
        try:
            device_flow = self.__app.http_client.post(
                self.__app.authority.device_authorization_endpoint,
                data=data,
                headers={
                    "Accept": "application/json",
                    "Content-Type": "application/x-www-form-urlencoded;charset=UTF-8",
                },
            ).json()
        except Exception as e:
            raise CogniteAuthError("Error initiating device flow") from e

        if "verification_uri" in device_flow:
            print(  # noqa: T201
                f"Visit {device_flow['verification_uri']} and enter the code: {device_flow.get('user_code', 'ERROR')}"
            )
        elif "message" in device_flow:
            print(f"Device code: {device_flow.get('message', device_flow.get('user_code', 'ERROR'))}")  # noqa: T201
        else:
            raise CogniteAuthError(
                f"Error initiating device flow: {device_flow.get('error')} - {device_flow.get('error_description')}"
            )

        if "interval" not in device_flow:
            # Set default interval according to standard
            device_flow["interval"] = 5
        if "expires_in" in device_flow:
            # msal library uses expires_at instead of the standard expires_in
            device_flow["expires_at"] = device_flow["expires_in"] + time.time()

        # Poll for token. Passing the device code as 'code' is a hack from the msal library, not standard.
        credentials = self.__app.client.obtain_token_by_device_flow(
            flow=device_flow,
            data=dict(data, code=device_flow.get("device_code")),
        )

    self._verify_credentials(credentials)
    token_cache.add(
        dict(credentials, environment=self.__app.authority.instance),
    )
    return credentials["access_token"], time.time() + float(credentials["expires_in"])

@classmethod
Expand Down Expand Up @@ -326,13 +488,62 @@ def load(cls, config: dict[str, Any] | str) -> OAuthDeviceCode:
return cls(
authority_url=loaded["authority_url"],
client_id=loaded["client_id"],
scopes=loaded["scopes"],
scopes=loaded.get("scopes"),
cdf_cluster=loaded.get("cdf_cluster"),
token_cache_path=Path(token_cache_path) if token_cache_path else None,
token_expiry_leeway_seconds=int(
loaded.get("token_expiry_leeway_seconds", _TOKEN_EXPIRY_LEEWAY_SECONDS_DEFAULT)
),
)

@classmethod
def default_for_azure_ad(
    cls,
    tenant_id: str,
    client_id: str,
    cdf_cluster: str,
    token_cache_path: Path | None = None,
    token_expiry_leeway_seconds: int = _TOKEN_EXPIRY_LEEWAY_SECONDS_DEFAULT,
    clear_cache: bool = False,
    mem_cache_only: bool = False,
) -> OAuthDeviceCode:
    """
    Create an OAuthDeviceCode instance for Azure with default URLs and scopes. The given client_id must
    belong to an app registration that allows device code flow. If you need other URLs or scopes,
    instantiate OAuthDeviceCode directly.
    The default configuration creates the URLs based on the tenant id and cluster:
    * Authority URL: "https://login.microsoftonline.com/{tenant_id}"
    * Scopes: [f"https://{cdf_cluster}.cognitedata.com/IDENTITY", f"https://{cdf_cluster}.cognitedata.com/user_impersonation", "profile", "openid"]
    * Additional keyword argument: audience=f"https://{cdf_cluster}.cognitedata.com"
    Args:
        tenant_id (str): The Azure tenant id
        client_id (str): An app registration that allows device code flow.
        cdf_cluster (str): The CDF cluster where the CDF project is located.
        token_cache_path (Path | None): Location to store token cache, defaults to os temp directory/cognitetokencache.{client_id}.bin.
        token_expiry_leeway_seconds (int): The token is refreshed at the earliest when this number of seconds is left before expiry. Default: 30 sec
        clear_cache (bool): If True, the token cache will be cleared on initialization. Default: False
        mem_cache_only (bool): If True, the token cache will only be stored in memory. Default: False
    Returns:
        OAuthDeviceCode: An OAuthDeviceCode instance
    """
    return cls(
        authority_url=f"https://login.microsoftonline.com/{tenant_id}",
        client_id=client_id,  # Caller-supplied app registration; no default client id is filled in here
        scopes=[
            f"https://{cdf_cluster}.cognitedata.com/IDENTITY",
            f"https://{cdf_cluster}.cognitedata.com/user_impersonation",
            "profile",
            "openid",
        ],
        token_cache_path=token_cache_path,
        token_expiry_leeway_seconds=token_expiry_leeway_seconds,
        clear_cache=clear_cache,
        mem_cache_only=mem_cache_only,
        # Not a named parameter of this call; passed as an additional keyword argument.
        audience=f"https://{cdf_cluster}.cognitedata.com",
    )


class OAuthInteractive(_OAuthCredentialProviderWithTokenRefresh, _WithMsalSerializableTokenCache):
"""OAuth credential provider for an interactive login flow.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool.poetry]
name = "cognite-sdk"

version = "7.62.0"
version = "7.62.1"
description = "Cognite Python SDK"
readme = "README.md"
documentation = "https://cognite-sdk-python.readthedocs-hosted.com"
Expand Down
Loading

0 comments on commit d9631cb

Please sign in to comment.