From 3b15f9ebe14c49fe29b31bf8b0744c4c863bfb6c Mon Sep 17 00:00:00 2001 From: Augustin Date: Tue, 29 Nov 2022 14:51:01 +0100 Subject: [PATCH] CDK: Emit control message on config mutation (#19428) --- airbyte-cdk/python/CHANGELOG.md | 3 + .../python/airbyte_cdk/config_observation.py | 76 +++++++++++ .../http/requests_native_auth/__init__.py | 10 +- .../requests_native_auth/abstract_oauth.py | 9 +- .../http/requests_native_auth/oauth.py | 128 +++++++++++++++++- airbyte-cdk/python/setup.py | 2 +- .../test_requests_native_auth.py | 71 ++++++++++ .../unit_tests/test_config_observation.py | 76 +++++++++++ 8 files changed, 368 insertions(+), 7 deletions(-) create mode 100644 airbyte-cdk/python/airbyte_cdk/config_observation.py create mode 100644 airbyte-cdk/python/unit_tests/test_config_observation.py diff --git a/airbyte-cdk/python/CHANGELOG.md b/airbyte-cdk/python/CHANGELOG.md index 2d011511854f..29de1cdbb2a9 100644 --- a/airbyte-cdk/python/CHANGELOG.md +++ b/airbyte-cdk/python/CHANGELOG.md @@ -1,5 +1,8 @@ # Changelog +## 0.11.0 +Declare a new authenticator `SingleUseRefreshTokenOauth2Authenticator` that can perform connector configuration mutation and emit `AirbyteControlMessage.ConnectorConfig`. + ## 0.10.0 Low-code: Add `start_from_page` option to a PageIncrement class diff --git a/airbyte-cdk/python/airbyte_cdk/config_observation.py b/airbyte-cdk/python/airbyte_cdk/config_observation.py new file mode 100644 index 000000000000..8d886e44bed9 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/config_observation.py @@ -0,0 +1,76 @@ +# +# Copyright (c) 2022 Airbyte, Inc., all rights reserved. +# + +from __future__ import ( # Used to evaluate type hints at runtime, a NameError: name 'ConfigObserver' is not defined is thrown otherwise + annotations, +) + +import time +from typing import Any, List, MutableMapping + +from airbyte_cdk.models import AirbyteControlConnectorConfigMessage, AirbyteControlMessage, AirbyteMessage, OrchestratorType, Type + + +class ObservedDict(dict): + def __init__(self, non_observed_mapping: MutableMapping, observer: ConfigObserver, update_on_unchanged_value=True) -> None: + non_observed_mapping = non_observed_mapping.copy() + self.observer = observer + self.update_on_unchanged_value = update_on_unchanged_value + for item, value in non_observed_mapping.items(): + # Observe nested dicts + if isinstance(value, MutableMapping): + non_observed_mapping[item] = ObservedDict(value, observer) + + # Observe nested list of dicts + if isinstance(value, List): + for i, sub_value in enumerate(value): + if isinstance(sub_value, MutableMapping): + value[i] = ObservedDict(sub_value, observer) + super().__init__(non_observed_mapping) + + def __setitem__(self, item: Any, value: Any): + """Override dict.__setitem__ by: + 1. Observing the new value if it is a dict + 2. Call observer update if the new value is different from the previous one + """ + previous_value = self.get(item) + if isinstance(value, MutableMapping): + value = ObservedDict(value, self.observer) + if isinstance(value, List): + for i, sub_value in enumerate(value): + if isinstance(sub_value, MutableMapping): + value[i] = ObservedDict(sub_value, self.observer) + super(ObservedDict, self).__setitem__(item, value) + if self.update_on_unchanged_value or value != previous_value: + self.observer.update() + + +class ConfigObserver: + """This class is made to track mutations on ObservedDict config. + When update is called a CONNECTOR_CONFIG control message is emitted on stdout. + """ + + def set_config(self, config: ObservedDict) -> None: + self.config = config + + def update(self) -> None: + self._emit_airbyte_control_message() + + def _emit_airbyte_control_message(self) -> None: + control_message = AirbyteControlMessage( + type=OrchestratorType.CONNECTOR_CONFIG, + emitted_at=time.time() * 1000, + connectorConfig=AirbyteControlConnectorConfigMessage(config=self.config), + ) + airbyte_message = AirbyteMessage(type=Type.CONTROL, control=control_message) + print(airbyte_message.json(exclude_unset=True)) + + +def observe_connector_config(non_observed_connector_config: MutableMapping[str, Any]): + if isinstance(non_observed_connector_config, ObservedDict): + raise ValueError("This connector configuration is already observed") + connector_config_observer = ConfigObserver() + observed_connector_config = ObservedDict(non_observed_connector_config, connector_config_observer) + connector_config_observer.set_config(observed_connector_config) + return observed_connector_config diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py index c4f64a971ea0..c336ef2b50e3 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py @@ -2,7 +2,13 @@ # Copyright (c) 2022 Airbyte, Inc., all rights reserved. # -from .oauth import Oauth2Authenticator +from .oauth import Oauth2Authenticator, SingleUseRefreshTokenOauth2Authenticator from .token import BasicHttpAuthenticator, MultipleTokenAuthenticator, TokenAuthenticator -__all__ = ["Oauth2Authenticator", "TokenAuthenticator", "MultipleTokenAuthenticator", "BasicHttpAuthenticator"] +__all__ = [ + "Oauth2Authenticator", + "SingleUseRefreshTokenOauth2Authenticator", + "TokenAuthenticator", + "MultipleTokenAuthenticator", + "BasicHttpAuthenticator", +] diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index e7e0ce397e80..fab826bd7b46 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -64,6 +64,11 @@ def build_refresh_request_body(self) -> Mapping[str, Any]: return payload + def _get_refresh_access_token_response(self): + response = requests.request(method="POST", url=self.get_token_refresh_endpoint(), data=self.build_refresh_request_body()) + response.raise_for_status() + return response.json() + def refresh_access_token(self) -> Tuple[str, int]: """ Returns the refresh token and its lifespan in seconds @@ -71,9 +76,7 @@ def refresh_access_token(self) -> Tuple[str, int]: :return: a tuple of (access_token, token_lifespan_in_seconds) """ try: - response = requests.request(method="POST", url=self.get_token_refresh_endpoint(), data=self.build_refresh_request_body()) - response.raise_for_status() - response_json = response.json() + response_json = self._get_refresh_access_token_response() return response_json[self.get_access_token_name()], response_json[self.get_expires_in_name()] except Exception as e: raise Exception(f"Error while refreshing access token: {e}") from e diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index 5f2e21df8841..5be3b3d05d16 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -2,9 +2,11 @@ # Copyright (c) 2022 Airbyte, Inc., all rights reserved. # -from typing import Any, List, Mapping +from typing import Any, List, Mapping, Sequence, Tuple +import dpath import pendulum +from airbyte_cdk.config_observation import observe_connector_config from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import AbstractOauth2Authenticator @@ -12,6 +14,7 @@ class Oauth2Authenticator(AbstractOauth2Authenticator): """ Generates OAuth2.0 access tokens from an OAuth2.0 refresh token and client credentials. The generated access token is attached to each request via the Authorization header. + If a connector_config is provided any mutation of it's value in the scope of this class will emit AirbyteControlConnectorConfigMessage. """ def __init__( @@ -80,3 +83,126 @@ def access_token(self) -> str: @access_token.setter def access_token(self, value: str): self._access_token = value + + +class SingleUseRefreshTokenOauth2Authenticator(Oauth2Authenticator): + """ + Authenticator that should be used for API implementing single use refresh tokens: + when refreshing access token some API returns a new refresh token that needs to used in the next refresh flow. + This authenticator updates the configuration with new refresh token by emitting Airbyte control message from an observed mutation. + By default this authenticator expects a connector config with a"credentials" field with the following nested fields: client_id, client_secret, refresh_token. + This behavior can be changed by defining custom config path (using dpath paths) in client_id_config_path, client_secret_config_path, refresh_token_config_path constructor arguments. + """ + + def __init__( + self, + connector_config: Mapping[str, Any], + token_refresh_endpoint: str, + scopes: List[str] = None, + token_expiry_date: pendulum.DateTime = None, + access_token_name: str = "access_token", + expires_in_name: str = "expires_in", + refresh_token_name: str = "refresh_token", + refresh_request_body: Mapping[str, Any] = None, + grant_type: str = "refresh_token", + client_id_config_path: Sequence[str] = ("credentials", "client_id"), + client_secret_config_path: Sequence[str] = ("credentials", "client_secret"), + refresh_token_config_path: Sequence[str] = ("credentials", "refresh_token"), + ): + """ + + Args: + connector_config (Mapping[str, Any]): The full connector configuration + token_refresh_endpoint (str): Full URL to the token refresh endpoint + scopes (List[str], optional): List of OAuth scopes to pass in the refresh token request body. Defaults to None. + token_expiry_date (pendulum.DateTime, optional): Datetime at which the current token will expire. Defaults to None. + access_token_name (str, optional): Name of the access token field, used to parse the refresh token response. Defaults to "access_token". + expires_in_name (str, optional): Name of the name of the field that characterizes when the current access token will expire, used to parse the refresh token response. Defaults to "expires_in". + refresh_token_name (str, optional): Name of the name of the refresh token field, used to parse the refresh token response. Defaults to "refresh_token". + refresh_request_body (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request body. Defaults to None. + grant_type (str, optional): OAuth grant type. Defaults to "refresh_token". + client_id_config_path (Sequence[str]): Dpath to the client_id field in the connector configuration. Defaults to ("credentials", "client_id"). + client_secret_config_path (Sequence[str]): Dpath to the client_secret field in the connector configuration. Defaults to ("credentials", "client_secret"). + refresh_token_config_path (Sequence[str]): Dpath to the refresh_token field in the connector configuration. Defaults to ("credentials", "refresh_token"). + """ + self._client_id_config_path = client_id_config_path + self._client_secret_config_path = client_secret_config_path + self._refresh_token_config_path = refresh_token_config_path + self._refresh_token_name = refresh_token_name + self._connector_config = observe_connector_config(connector_config) + self._validate_connector_config() + super().__init__( + token_refresh_endpoint, + self.get_client_id(), + self.get_client_secret(), + self.get_refresh_token(), + scopes, + token_expiry_date, + access_token_name, + expires_in_name, + refresh_request_body, + grant_type, + ) + + def _validate_connector_config(self): + """Validates the defined getters for configuration values are returning values. + + Raises: + ValueError: Raised if the defined getters are not returning a value. + """ + for field_path, getter, parameter_name in [ + (self._client_id_config_path, self.get_client_id, "client_id_config_path"), + (self._client_secret_config_path, self.get_client_secret, "client_secret_config_path"), + (self._refresh_token_config_path, self.get_refresh_token, "refresh_token_config_path"), + ]: + try: + assert getter() + except KeyError: + raise ValueError( + f"This authenticator expects a value under the {field_path} field path. Please check your configuration structure or change the {parameter_name} value at initialization of this authenticator." + ) + + def get_refresh_token_name(self) -> str: + return self._refresh_token_name + + def get_client_id(self) -> str: + return dpath.util.get(self._connector_config, self._client_id_config_path) + + def get_client_secret(self) -> str: + return dpath.util.get(self._connector_config, self._client_secret_config_path) + + def get_refresh_token(self) -> str: + return dpath.util.get(self._connector_config, self._refresh_token_config_path) + + def set_refresh_token(self, new_refresh_token: str): + """Set the new refresh token value. The mutation of the connector_config object will emit an Airbyte control message. + + Args: + new_refresh_token (str): The new refresh token value. + """ + dpath.util.set(self._connector_config, self._refresh_token_config_path, new_refresh_token) + + def get_access_token(self) -> str: + """Retrieve new access and refresh token if the access token has expired. + The new refresh token is persisted with the set_refresh_token function + Returns: + str: The current access_token, updated if it was previously expired. + """ + if self.token_has_expired(): + t0 = pendulum.now() + new_access_token, access_token_expires_in, new_refresh_token = self.refresh_access_token() + self.access_token = new_access_token + self.set_token_expiry_date(t0.add(seconds=access_token_expires_in)) + self.set_refresh_token(new_refresh_token) + return self.access_token + + def refresh_access_token(self) -> Tuple[str, int, str]: + try: + response_json = self._get_refresh_access_token_response() + return ( + response_json[self.get_access_token_name()], + response_json[self.get_expires_in_name()], + response_json[self.get_refresh_token_name()], + ) + except Exception as e: + raise Exception(f"Error while refreshing access token and refresh token: {e}") from e diff --git a/airbyte-cdk/python/setup.py b/airbyte-cdk/python/setup.py index 133323cb4827..191b93f5cb3c 100644 --- a/airbyte-cdk/python/setup.py +++ b/airbyte-cdk/python/setup.py @@ -15,7 +15,7 @@ setup( name="airbyte-cdk", - version="0.10.0", + version="0.11.0", description="A framework for writing Airbyte Connectors.", long_description=README, long_description_content_type="text/markdown", diff --git a/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index 97fc2d9e283d..368669bf1223 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -2,14 +2,18 @@ # Copyright (c) 2022 Airbyte, Inc., all rights reserved. # +import json import logging import pendulum +import pytest import requests +from airbyte_cdk.config_observation import ObservedDict from airbyte_cdk.sources.streams.http.requests_native_auth import ( BasicHttpAuthenticator, MultipleTokenAuthenticator, Oauth2Authenticator, + SingleUseRefreshTokenOauth2Authenticator, TokenAuthenticator, ) from requests import Response @@ -175,6 +179,73 @@ def test_auth_call_method(self, mocker): assert {"Authorization": "Bearer access_token"} == prepared_request.headers +class TestSingleUseRefreshTokenOauth2Authenticator: + @pytest.fixture + def connector_config(self): + return { + "credentials": { + "access_token": "my_access_token", + "refresh_token": "my_refresh_token", + "client_id": "my_client_id", + "client_secret": "my_client_secret", + } + } + + @pytest.fixture + def invalid_connector_config(self): + return {"no_credentials_key": "foo"} + + def test_init(self, connector_config): + authenticator = SingleUseRefreshTokenOauth2Authenticator( + connector_config, + token_refresh_endpoint="foobar", + ) + assert isinstance(authenticator._connector_config, ObservedDict) + + def test_init_with_invalid_config(self, invalid_connector_config): + with pytest.raises(ValueError): + SingleUseRefreshTokenOauth2Authenticator( + invalid_connector_config, + token_refresh_endpoint="foobar", + ) + + def test_get_access_token(self, capsys, mocker, connector_config): + authenticator = SingleUseRefreshTokenOauth2Authenticator( + connector_config, + token_refresh_endpoint="foobar", + ) + authenticator.refresh_access_token = mocker.Mock(return_value=("new_access_token", 42, "new_refresh_token")) + authenticator.token_has_expired = mocker.Mock(return_value=True) + access_token = authenticator.get_access_token() + captured = capsys.readouterr() + airbyte_message = json.loads(captured.out) + expected_new_config = connector_config.copy() + expected_new_config["credentials"]["refresh_token"] = "new_refresh_token" + assert airbyte_message["control"]["connectorConfig"]["config"] == expected_new_config + assert authenticator.access_token == access_token == "new_access_token" + assert authenticator.get_refresh_token() == "new_refresh_token" + assert authenticator.get_token_expiry_date() > pendulum.now() + authenticator.token_has_expired = mocker.Mock(return_value=False) + access_token = authenticator.get_access_token() + captured = capsys.readouterr() + assert not captured.out + assert authenticator.access_token == access_token == "new_access_token" + + def test_refresh_access_token(self, mocker, connector_config): + authenticator = SingleUseRefreshTokenOauth2Authenticator( + connector_config, + token_refresh_endpoint="foobar", + ) + authenticator._get_refresh_access_token_response = mocker.Mock( + return_value={ + authenticator.get_access_token_name(): "new_access_token", + authenticator.get_expires_in_name(): 42, + authenticator.get_refresh_token_name(): "new_refresh_token", + } + ) + assert authenticator.refresh_access_token() == ("new_access_token", 42, "new_refresh_token") + + def mock_request(method, url, data): if url == "refresh_end": return resp diff --git a/airbyte-cdk/python/unit_tests/test_config_observation.py b/airbyte-cdk/python/unit_tests/test_config_observation.py new file mode 100644 index 000000000000..38dc2281aece --- /dev/null +++ b/airbyte-cdk/python/unit_tests/test_config_observation.py @@ -0,0 +1,76 @@ +# +# Copyright (c) 2022 Airbyte, Inc., all rights reserved. +# + +import json +import time + +import pytest +from airbyte_cdk.config_observation import ConfigObserver, ObservedDict, observe_connector_config + + +class TestObservedDict: + def test_update_called_on_set_item(self, mocker): + mock_observer = mocker.Mock() + my_observed_dict = ObservedDict( + {"key": "value", "nested_dict": {"key": "value"}, "list_of_dict": [{"key": "value"}, {"key": "value"}]}, mock_observer + ) + assert mock_observer.update.call_count == 0 + + my_observed_dict["nested_dict"]["key"] = "new_value" + assert mock_observer.update.call_count == 1 + + # Setting the same value again should call observer's update + my_observed_dict["key"] = "new_value" + assert mock_observer.update.call_count == 2 + + my_observed_dict["nested_dict"]["new_key"] = "value" + assert mock_observer.update.call_count == 3 + + my_observed_dict["list_of_dict"][0]["key"] = "new_value" + assert mock_observer.update.call_count == 4 + + my_observed_dict["list_of_dict"][0]["new_key"] = "new_value" + assert mock_observer.update.call_count == 5 + + my_observed_dict["new_list_of_dicts"] = [{"foo": "bar"}] + assert mock_observer.update.call_count == 6 + + my_observed_dict["new_list_of_dicts"][0]["new_key"] = "new_value" + assert mock_observer.update.call_count == 7 + + +class TestConfigObserver: + def test_update(self, capsys): + config_observer = ConfigObserver() + config_observer.set_config(ObservedDict({"key": "value"}, config_observer)) + before_time = time.time() * 1000 + config_observer.update() + after_time = time.time() * 1000 + captured = capsys.readouterr() + airbyte_message = json.loads(captured.out) + assert airbyte_message["type"] == "CONTROL" + assert "control" in airbyte_message + raw_control_message = airbyte_message["control"] + assert raw_control_message["type"] == "CONNECTOR_CONFIG" + assert raw_control_message["connectorConfig"] == {"config": dict(config_observer.config)} + assert before_time < raw_control_message["emitted_at"] < after_time + + +def test_observe_connector_config(capsys): + non_observed_config = {"foo": "bar"} + observed_config = observe_connector_config(non_observed_config) + observer = observed_config.observer + assert isinstance(observed_config, ObservedDict) + assert isinstance(observer, ConfigObserver) + assert observed_config.observer.config == observed_config + observed_config["foo"] = "foo" + captured = capsys.readouterr() + airbyte_message = json.loads(captured.out) + assert airbyte_message["control"]["connectorConfig"] == {"config": {"foo": "foo"}} + + +def test_observe_already_observed_config(): + observed_config = observe_connector_config({"foo": "bar"}) + with pytest.raises(ValueError): + observe_connector_config(observed_config)