Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CDK: Emit control message on config mutation #19428

Merged
merged 30 commits into from
Nov 29, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f3d4b5a
wip
alafanechere Nov 14, 2022
24d24dd
implementation
alafanechere Nov 15, 2022
355b90b
format
alafanechere Nov 15, 2022
61e9d4b
bump version
alafanechere Nov 15, 2022
1101ccc
Merge branch 'master' into augustin/cdk/emit-updated-configs
alafanechere Nov 15, 2022
22f75c7
always update by default, even if same value
alafanechere Nov 16, 2022
d562117
rename split_config to filter_internal_keywords
alafanechere Nov 16, 2022
733b89a
bing ads example
alafanechere Nov 16, 2022
a59a1e6
wrap around AirbyteMessage
alafanechere Nov 16, 2022
616256e
exclude unset
alafanechere Nov 16, 2022
3ef545d
observer does not write config to disk
alafanechere Nov 21, 2022
0ff7318
revert global changes
alafanechere Nov 21, 2022
9e1483c
revert global changes
alafanechere Nov 21, 2022
2bbff33
revert global changes
alafanechere Nov 21, 2022
1f02860
observe from Oauth2Authenticator
alafanechere Nov 21, 2022
f40d99c
Merge branch 'master' into augustin/cdk/emit-updated-configs
alafanechere Nov 21, 2022
04041d8
ref
alafanechere Nov 21, 2022
12db1b7
handle list of dicts
alafanechere Nov 21, 2022
63d129c
implement SingleUseRefreshTokenOauth2Authenticator
alafanechere Nov 23, 2022
62ef553
test SingleUseRefreshTokenOauth2Authenticator
alafanechere Nov 23, 2022
bbe802a
call copy in ObservedDict
alafanechere Nov 23, 2022
8fd52f9
add docstring
alafanechere Nov 23, 2022
c5a0b6d
source harvest example
alafanechere Nov 23, 2022
803b70a
use dpath
alafanechere Nov 25, 2022
37f474f
better doc string
alafanechere Nov 25, 2022
08f64e4
update changelog
alafanechere Nov 25, 2022
518a36f
use sequence instead of string path for dpath declaration
alafanechere Nov 29, 2022
64bbc48
Merge branch 'master' into augustin/cdk/emit-updated-configs
alafanechere Nov 29, 2022
4da6581
revert connector changes
alafanechere Nov 29, 2022
380e175
format
alafanechere Nov 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion airbyte-cdk/python/airbyte_cdk/config_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

class ObservedDict(dict):
def __init__(self, non_observed_mapping: MutableMapping, observer: ConfigObserver, update_on_unchanged_value=True) -> None:
non_observed_mapping = non_observed_mapping.copy()
self.observer = observer
self.update_on_unchanged_value = update_on_unchanged_value
for item, value in non_observed_mapping.items():
Expand Down Expand Up @@ -46,7 +47,7 @@ def __setitem__(self, item: Any, value: Any):

class ConfigObserver:
"""This class is made to track mutations on ObservedDict config.
When update is called the observed configuration is saved on disk a CONNECTOR_CONFIG control message is emitted on stdout.
When update is called a CONNECTOR_CONFIG control message is emitted on stdout.
"""

def set_config(self, config: ObservedDict) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from .oauth import Oauth2Authenticator
from .oauth import Oauth2Authenticator, SingleUseRefreshTokenOauth2Authenticator
from .token import BasicHttpAuthenticator, MultipleTokenAuthenticator, TokenAuthenticator

__all__ = ["Oauth2Authenticator", "TokenAuthenticator", "MultipleTokenAuthenticator", "BasicHttpAuthenticator"]
__all__ = [
"Oauth2Authenticator",
"SingleUseRefreshTokenOauth2Authenticator",
"TokenAuthenticator",
"MultipleTokenAuthenticator",
"BasicHttpAuthenticator",
]
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,19 @@ def build_refresh_request_body(self) -> Mapping[str, Any]:

return payload

def _get_refresh_access_token_response(self):
    """Call the token refresh endpoint and return its decoded JSON payload.

    A POST request is sent to ``get_token_refresh_endpoint()`` with
    ``build_refresh_request_body()`` passed as the request ``data``.
    ``raise_for_status()`` is called before the body is decoded.
    """
    refresh_response = requests.request(
        method="POST",
        url=self.get_token_refresh_endpoint(),
        data=self.build_refresh_request_body(),
    )
    refresh_response.raise_for_status()
    return refresh_response.json()

def refresh_access_token(self) -> Tuple[str, int]:
"""
Returns the access token and its lifespan in seconds

:return: a tuple of (access_token, token_lifespan_in_seconds)
"""
try:
response = requests.request(method="POST", url=self.get_token_refresh_endpoint(), data=self.build_refresh_request_body())
response.raise_for_status()
response_json = response.json()
response_json = self._get_refresh_access_token_response()
return response_json[self.get_access_token_name()], response_json[self.get_expires_in_name()]
except Exception as e:
raise Exception(f"Error while refreshing access token: {e}") from e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from typing import Any, List, Mapping
from typing import Any, List, Mapping, Tuple

import pendulum
from airbyte_cdk.config_observation import observe_connector_config
Expand All @@ -28,7 +28,6 @@ def __init__(
expires_in_name: str = "expires_in",
refresh_request_body: Mapping[str, Any] = None,
grant_type: str = "refresh_token",
connector_config: Mapping[str, Any] = None,
):
self._token_refresh_endpoint = token_refresh_endpoint
self._client_secret = client_secret
Expand All @@ -42,7 +41,6 @@ def __init__(

self._token_expiry_date = token_expiry_date or pendulum.now().subtract(days=1)
self._access_token = None
self._connector_config = observe_connector_config(connector_config) if connector_config else None

def get_token_refresh_endpoint(self) -> str:
return self._token_refresh_endpoint
Expand Down Expand Up @@ -84,3 +82,109 @@ def access_token(self) -> str:
@access_token.setter
def access_token(self, value: str):
self._access_token = value


class SingleUseRefreshTokenOauth2Authenticator(Oauth2Authenticator):
"""
Authenticator that should be used for APIs implementing single-use refresh tokens:
when refreshing the access token, some APIs return a new refresh token that needs to be used in the next refresh flow.
This authenticator updates the configuration with the new refresh token by emitting an Airbyte control message from an observed mutation.
This authenticator expects a connector config with a "credentials" field with the following nested fields: client_id, client_secret, refresh_token.
alafanechere marked this conversation as resolved.
Show resolved Hide resolved
This behavior can be changed by overriding getters or changing the default "credentials_configuration_field_name" value.
"""

def __init__(
self,
connector_config: Mapping[str, Any],
alafanechere marked this conversation as resolved.
Show resolved Hide resolved
token_refresh_endpoint: str,
scopes: List[str] = None,
token_expiry_date: pendulum.DateTime = None,
access_token_name: str = "access_token",
expires_in_name: str = "expires_in",
refresh_token_name: str = "refresh_token",
refresh_request_body: Mapping[str, Any] = None,
grant_type: str = "refresh_token",
credentials_configuration_field_name: str = "credentials",
):
self.credentials_configuration_field_name = credentials_configuration_field_name
self._refresh_token_name = refresh_token_name
self._connector_config = observe_connector_config(connector_config)
self._validate_connector_config()
super().__init__(
token_refresh_endpoint,
self.get_client_id(),
self.get_client_secret(),
self.get_refresh_token(),
scopes,
token_expiry_date,
access_token_name,
expires_in_name,
refresh_request_body,
grant_type,
)

def _validate_connector_config(self):
"""Validates the defined getters for configuration values are returning values.

Raises:
ValueError: Raised if the defined getters are not returning a value.
"""
for field_name, getter in [
("client_id", self.get_client_id),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we care about retaining and validating client_id and client_secret in this class instead of just passing up the concern to the super?

Copy link
Contributor Author

@alafanechere alafanechere Nov 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As client_id, client_secret, and refresh_token are mandatory arguments of the parent class, and the current constructor signature of the new class only takes a connector_config dict, I thought it would be safer to fail early if these fields' values could not be parsed from the config.

("client_secret", self.get_client_secret),
(self.get_refresh_token_name(), self.get_refresh_token),
]:
try:
assert getter()
except (AssertionError, KeyError):
raise ValueError(
f"This authenticator expects a {field_name} field under the {self.credentials_configuration_field_name} field. Please override this class getters or change your configuration structure."
)

def get_refresh_token_name(self) -> str:
    """Return the refresh token field name stored on this authenticator."""
    refresh_token_name = self._refresh_token_name
    return refresh_token_name

def _get_config_credentials_field(self, field_name):
alafanechere marked this conversation as resolved.
Show resolved Hide resolved
return self._connector_config[self.credentials_configuration_field_name][field_name]
sherifnada marked this conversation as resolved.
Show resolved Hide resolved

def get_client_id(self) -> str:
    """Return the "client_id" value looked up through the credentials field helper."""
    client_id = self._get_config_credentials_field("client_id")
    return client_id

def get_client_secret(self) -> str:
    """Return the "client_secret" value looked up through the credentials field helper."""
    client_secret = self._get_config_credentials_field("client_secret")
    return client_secret

def set_refresh_token(self, new_refresh_token: str):
    """Store a new refresh token in the credentials section of the connector config.

    The mutation of the connector_config object will emit an Airbyte control message.

    Args:
        new_refresh_token (str): The new refresh token value.
    """
    # Resolve the nested credentials mapping first, then assign the token into it.
    credentials = self._connector_config[self.credentials_configuration_field_name]
    credentials[self.get_refresh_token_name()] = new_refresh_token

def get_refresh_token(self) -> str:
    """Return the credentials value stored under the configured refresh token field name."""
    field_name = self.get_refresh_token_name()
    return self._get_config_credentials_field(field_name)

def get_access_token(self) -> str:
    """Return the current access token, refreshing it first when it has expired.

    On refresh, the new access token and its expiry date are stored on the
    authenticator, and the new refresh token is persisted with the
    set_refresh_token function.

    Returns:
        str: The current access_token, updated if it was previously expired.
    """
    if not self.token_has_expired():
        return self.access_token

    # Take the timestamp before the refresh call so the expiry date is computed from it.
    refresh_started_at = pendulum.now()
    refreshed_access_token, expires_in_seconds, refreshed_refresh_token = self.refresh_access_token()
    self.access_token = refreshed_access_token
    self.set_token_expiry_date(refresh_started_at.add(seconds=expires_in_seconds))
    self.set_refresh_token(refreshed_refresh_token)
    return self.access_token

def refresh_access_token(self) -> Tuple[str, int, str]:
try:
response_json = self._get_refresh_access_token_response()
return (
response_json[self.get_access_token_name()],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would recommend using dpath

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed but for implementation consistency with parent class and abstract class, I'd prefer to do this in a separate PR in which we can replace the access_token_name, expires_in_name, and refresh_token_name by dpaths that can be used when parsing these responses. Wdyt?

Copy link
Contributor

@sherifnada sherifnada Nov 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense!

response_json[self.get_expires_in_name()],
response_json[self.get_refresh_token_name()],
)
except Exception as e:
raise Exception(f"Error while refreshing access token and refresh token: {e}") from e
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
import logging

import pendulum
import pytest
import requests
from airbyte_cdk.config_observation import ObservedDict
from airbyte_cdk.sources.streams.http.requests_native_auth import (
BasicHttpAuthenticator,
MultipleTokenAuthenticator,
Oauth2Authenticator,
SingleUseRefreshTokenOauth2Authenticator,
TokenAuthenticator,
)
from requests import Response
Expand Down Expand Up @@ -175,20 +178,72 @@ def test_auth_call_method(self, mocker):

assert {"Authorization": "Bearer access_token"} == prepared_request.headers

def test_auth_with_config_mutation(self, capsys):
original_connector_config = {"refresh_token": "foo"}
oauth = Oauth2Authenticator(
token_refresh_endpoint=TestOauth2Authenticator.refresh_endpoint,
client_id=TestOauth2Authenticator.client_id,
client_secret=TestOauth2Authenticator.client_secret,
refresh_token=TestOauth2Authenticator.refresh_token,
connector_config=original_connector_config,

class TestSingleUseRefreshTokenOauth2Authenticator:
@pytest.fixture
def connector_config(self):
    """Connector config with access_token, refresh_token, client_id and client_secret under "credentials"."""
    credentials = {
        "access_token": "my_access_token",
        "refresh_token": "my_refresh_token",
        "client_id": "my_client_id",
        "client_secret": "my_client_secret",
    }
    return {"credentials": credentials}

@pytest.fixture
def invalid_connector_config(self):
    """Connector config that has no "credentials" key."""
    config_without_credentials = {"no_credentials_key": "foo"}
    return config_without_credentials

def test_init(self, connector_config):
authenticator = SingleUseRefreshTokenOauth2Authenticator(
connector_config,
token_refresh_endpoint="foobar",
)
oauth._connector_config["refresh_token"] = "bar"
assert isinstance(authenticator._connector_config, ObservedDict)

def test_init_with_invalid_config(self, invalid_connector_config):
    """Building the authenticator from the invalid config fixture raises ValueError."""
    constructor_kwargs = {"token_refresh_endpoint": "foobar"}
    with pytest.raises(ValueError):
        SingleUseRefreshTokenOauth2Authenticator(invalid_connector_config, **constructor_kwargs)

def test_get_access_token(self, capsys, mocker, connector_config):
authenticator = SingleUseRefreshTokenOauth2Authenticator(
connector_config,
token_refresh_endpoint="foobar",
)
authenticator.refresh_access_token = mocker.Mock(return_value=("new_access_token", 42, "new_refresh_token"))
authenticator.token_has_expired = mocker.Mock(return_value=True)
access_token = authenticator.get_access_token()
captured = capsys.readouterr()
airbyte_message = json.loads(captured.out)
assert airbyte_message["control"]["connectorConfig"] == {"config": {"refresh_token": "bar"}}
assert original_connector_config["refresh_token"] == "foo"
expected_new_config = connector_config.copy()
expected_new_config["credentials"]["refresh_token"] = "new_refresh_token"
assert airbyte_message["control"]["connectorConfig"]["config"] == expected_new_config
assert authenticator.access_token == access_token == "new_access_token"
assert authenticator.get_refresh_token() == "new_refresh_token"
assert authenticator.get_token_expiry_date() > pendulum.now()
authenticator.token_has_expired = mocker.Mock(return_value=False)
access_token = authenticator.get_access_token()
captured = capsys.readouterr()
assert not captured.out
assert authenticator.access_token == access_token == "new_access_token"

def test_refresh_access_token(self, mocker, connector_config):
    """refresh_access_token returns the access token, expiry and refresh token read from the refresh response."""
    authenticator = SingleUseRefreshTokenOauth2Authenticator(connector_config, token_refresh_endpoint="foobar")
    # Key the stubbed response by the authenticator's own field names.
    stubbed_response = {
        authenticator.get_access_token_name(): "new_access_token",
        authenticator.get_expires_in_name(): 42,
        authenticator.get_refresh_token_name(): "new_refresh_token",
    }
    authenticator._get_refresh_access_token_response = mocker.Mock(return_value=stubbed_response)
    refreshed = authenticator.refresh_access_token()
    assert refreshed == ("new_access_token", 42, "new_refresh_token")


def mock_request(method, url, data):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing import Any, Mapping

from airbyte_cdk.sources.streams.http.auth import Oauth2Authenticator, TokenAuthenticator
from airbyte_cdk.sources.streams.http.requests_native_auth import SingleUseRefreshTokenOauth2Authenticator, TokenAuthenticator


class HarvestMixin:
Expand All @@ -29,8 +29,11 @@ class HarvestTokenAuthenticator(HarvestMixin, TokenAuthenticator):
"""


class HarvestOauth2Authenticator(HarvestMixin, Oauth2Authenticator):
class HarvestOauth2Authenticator(SingleUseRefreshTokenOauth2Authenticator):
"""
Auth class for OAuth2
https://help.getharvest.com/api-v2/authentication-api/authentication/authentication/#for-server-side-applications
"""

def get_auth_header(self) -> Mapping[str, Any]:
    """Return the parent's auth header with a "Harvest-Account-ID" entry added.

    The account id is read from the "account_id" key of the connector config.
    """
    auth_header = dict(super().get_auth_header())
    auth_header["Harvest-Account-ID"] = self._connector_config["account_id"]
    return auth_header
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,8 @@ def get_authenticator(config):
credentials = config.get("credentials", {})
if credentials and "client_id" in credentials:
return HarvestOauth2Authenticator(
config,
token_refresh_endpoint="https://id.getharvest.com/api/v2/oauth2/token",
client_id=credentials.get("client_id"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why were these removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The client_id parameter does not exist in my SingleUseRefreshTokenOauth2Authenticator implementation. The client_id is retrieved from the configuration with dpath

client_secret=credentials.get("client_secret"),
refresh_token=credentials.get("refresh_token"),
account_id=config["account_id"],
)

api_token = credentials.get("api_token", config.get("api_token"))
Expand Down