Skip to content

Commit

Permalink
CDK: Emit control message on config mutation (#19428)
Browse files Browse the repository at this point in the history
  • Loading branch information
alafanechere authored Nov 29, 2022
1 parent b78fdad commit 3b15f9e
Show file tree
Hide file tree
Showing 8 changed files with 368 additions and 7 deletions.
3 changes: 3 additions & 0 deletions airbyte-cdk/python/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog

## 0.11.0
Declare a new authenticator `SingleUseRefreshTokenOauth2Authenticator` that can perform connector configuration mutation and emit `AirbyteControlMessage.ConnectorConfig`.

## 0.10.0
Low-code: Add `start_from_page` option to a PageIncrement class

Expand Down
76 changes: 76 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/config_observation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from __future__ import ( # Postpones evaluation of type hints; without it, annotating with ConfigObserver before the class is defined raises NameError: name 'ConfigObserver' is not defined
annotations,
)

import time
from typing import Any, List, MutableMapping

from airbyte_cdk.models import AirbyteControlConnectorConfigMessage, AirbyteControlMessage, AirbyteMessage, OrchestratorType, Type


class ObservedDict(dict):
    """Dictionary that notifies an observer whenever one of its items is assigned.

    Nested mappings, and mappings found directly inside list values, are wrapped in
    ObservedDict too, so an item assignment at any depth reaches the same observer.
    Only item assignment (``d[key] = value``) is intercepted; the other mutators
    inherited from dict (``update``, ``pop``, ``del d[key]``, ...) are not observed.
    """

    def __init__(self, non_observed_mapping: MutableMapping, observer: "ConfigObserver", update_on_unchanged_value=True) -> None:
        """
        Args:
            non_observed_mapping (MutableMapping): The mapping to observe. Its content is copied; neither the mapping nor the lists it holds are mutated.
            observer (ConfigObserver): Object whose update() method is called on observed assignments.
            update_on_unchanged_value (bool): When True (default) every assignment notifies the observer.
                When False, only assignments that add a key or change the stored value do. The setting is propagated to nested ObservedDicts.
        """
        self.observer = observer
        self.update_on_unchanged_value = update_on_unchanged_value
        # dict.__init__ does not go through __setitem__, so building the dict never notifies the observer.
        super().__init__({key: self._observe(value) for key, value in non_observed_mapping.items()})

    def _observe(self, value: Any) -> Any:
        """Return value with every mapping (the value itself, or elements of a list value) wrapped in an ObservedDict."""
        if isinstance(value, MutableMapping):
            return ObservedDict(value, self.observer, self.update_on_unchanged_value)
        if isinstance(value, list):
            # Build a new list instead of assigning into the caller's list.
            return [
                ObservedDict(element, self.observer, self.update_on_unchanged_value) if isinstance(element, MutableMapping) else element
                for element in value
            ]
        return value

    def __setitem__(self, item: Any, value: Any):
        """Override dict.__setitem__ by:
        1. Observing the new value if it is a dict or a list containing dicts
        2. Calling the observer update, either always (update_on_unchanged_value) or only when the key is new or its value changed
        """
        # Track key presence separately: self.get() alone cannot tell a missing key from a stored None.
        is_new_key = item not in self
        previous_value = self.get(item)
        value = self._observe(value)
        super().__setitem__(item, value)
        if self.update_on_unchanged_value or is_new_key or value != previous_value:
            self.observer.update()


class ConfigObserver:
    """Observer attached to an ObservedDict connector configuration.

    Every call to update prints a CONNECTOR_CONFIG control message, carrying the observed config, on stdout.
    """

    def set_config(self, config: ObservedDict) -> None:
        """Register the configuration that the emitted control messages will carry."""
        self.config = config

    def update(self) -> None:
        """Notification hook: emit a control message for the registered configuration."""
        self._emit_airbyte_control_message()

    def _emit_airbyte_control_message(self) -> None:
        """Build the CONTROL Airbyte message for the registered config and print it as JSON."""
        # time.time() is in seconds; the message carries milliseconds.
        emitted_at_ms = time.time() * 1000
        connector_config_message = AirbyteControlConnectorConfigMessage(config=self.config)
        control = AirbyteControlMessage(
            type=OrchestratorType.CONNECTOR_CONFIG,
            emitted_at=emitted_at_ms,
            connectorConfig=connector_config_message,
        )
        message = AirbyteMessage(type=Type.CONTROL, control=control)
        print(message.json(exclude_unset=True))


def observe_connector_config(non_observed_connector_config: MutableMapping[str, Any]):
    """Wrap a connector configuration in an ObservedDict bound to a new ConfigObserver.

    Args:
        non_observed_connector_config (MutableMapping[str, Any]): The plain connector configuration.
    Raises:
        ValueError: If the configuration is already an ObservedDict.
    Returns:
        ObservedDict: The observed configuration, also registered on its observer.
    """
    already_observed = isinstance(non_observed_connector_config, ObservedDict)
    if already_observed:
        raise ValueError("This connector configuration is already observed")
    observer = ConfigObserver()
    observed_config = ObservedDict(non_observed_connector_config, observer)
    # The observer needs the observed config back so it can serialize it on each update.
    observer.set_config(observed_config)
    return observed_config
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from .oauth import Oauth2Authenticator
from .oauth import Oauth2Authenticator, SingleUseRefreshTokenOauth2Authenticator
from .token import BasicHttpAuthenticator, MultipleTokenAuthenticator, TokenAuthenticator

__all__ = ["Oauth2Authenticator", "TokenAuthenticator", "MultipleTokenAuthenticator", "BasicHttpAuthenticator"]
# Names exported by this package; every entry is imported from .oauth or .token above.
__all__ = [
"Oauth2Authenticator",
"SingleUseRefreshTokenOauth2Authenticator",
"TokenAuthenticator",
"MultipleTokenAuthenticator",
"BasicHttpAuthenticator",
]
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,19 @@ def build_refresh_request_body(self) -> Mapping[str, Any]:

return payload

def _get_refresh_access_token_response(self):
    """POST the refresh request body to the token refresh endpoint and return the decoded JSON response.

    An HTTP error status is surfaced through response.raise_for_status().
    """
    refresh_url = self.get_token_refresh_endpoint()
    refresh_body = self.build_refresh_request_body()
    refresh_response = requests.request(method="POST", url=refresh_url, data=refresh_body)
    refresh_response.raise_for_status()
    return refresh_response.json()

def refresh_access_token(self) -> Tuple[str, int]:
    """
    Requests a new access token and returns it with its lifespan in seconds

    :return: a tuple of (access_token, token_lifespan_in_seconds)
    :raises Exception: if the refresh request fails or the response lacks one of the expected fields
    """
    try:
        # The HTTP call is delegated to _get_refresh_access_token_response only:
        # issuing the request here as well would hit the token endpoint twice per refresh.
        response_json = self._get_refresh_access_token_response()
        return response_json[self.get_access_token_name()], response_json[self.get_expires_in_name()]
    except Exception as e:
        raise Exception(f"Error while refreshing access token: {e}") from e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from typing import Any, List, Mapping
from typing import Any, List, Mapping, Sequence, Tuple

import dpath
import pendulum
from airbyte_cdk.config_observation import observe_connector_config
from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import AbstractOauth2Authenticator


class Oauth2Authenticator(AbstractOauth2Authenticator):
"""
Generates OAuth2.0 access tokens from an OAuth2.0 refresh token and client credentials.
The generated access token is attached to each request via the Authorization header.
If a connector_config is provided, any mutation of its value within the scope of this class will emit an AirbyteControlConnectorConfigMessage.
"""

def __init__(
Expand Down Expand Up @@ -80,3 +83,126 @@ def access_token(self) -> str:
@access_token.setter
def access_token(self, value: str):
    """Store value as the current access token (kept in the private _access_token attribute)."""
    self._access_token = value


class SingleUseRefreshTokenOauth2Authenticator(Oauth2Authenticator):
    """
    Authenticator that should be used for APIs implementing single-use refresh tokens:
    when refreshing the access token, some APIs return a new refresh token that needs to be used in the next refresh flow.
    This authenticator updates the configuration with the new refresh token by emitting an Airbyte control message from an observed mutation.
    By default this authenticator expects a connector config with a "credentials" field with the following nested fields: client_id, client_secret, refresh_token.
    This behavior can be changed by defining custom config paths (using dpath paths) in the client_id_config_path, client_secret_config_path, refresh_token_config_path constructor arguments.
    """

    def __init__(
        self,
        connector_config: Mapping[str, Any],
        token_refresh_endpoint: str,
        scopes: List[str] = None,
        token_expiry_date: pendulum.DateTime = None,
        access_token_name: str = "access_token",
        expires_in_name: str = "expires_in",
        refresh_token_name: str = "refresh_token",
        refresh_request_body: Mapping[str, Any] = None,
        grant_type: str = "refresh_token",
        client_id_config_path: Sequence[str] = ("credentials", "client_id"),
        client_secret_config_path: Sequence[str] = ("credentials", "client_secret"),
        refresh_token_config_path: Sequence[str] = ("credentials", "refresh_token"),
    ):
        """
        Args:
            connector_config (Mapping[str, Any]): The full connector configuration
            token_refresh_endpoint (str): Full URL to the token refresh endpoint
            scopes (List[str], optional): List of OAuth scopes to pass in the refresh token request body. Defaults to None.
            token_expiry_date (pendulum.DateTime, optional): Datetime at which the current token will expire. Defaults to None.
            access_token_name (str, optional): Name of the access token field, used to parse the refresh token response. Defaults to "access_token".
            expires_in_name (str, optional): Name of the field that characterizes when the current access token will expire, used to parse the refresh token response. Defaults to "expires_in".
            refresh_token_name (str, optional): Name of the refresh token field, used to parse the refresh token response. Defaults to "refresh_token".
            refresh_request_body (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request body. Defaults to None.
            grant_type (str, optional): OAuth grant type. Defaults to "refresh_token".
            client_id_config_path (Sequence[str]): Dpath to the client_id field in the connector configuration. Defaults to ("credentials", "client_id").
            client_secret_config_path (Sequence[str]): Dpath to the client_secret field in the connector configuration. Defaults to ("credentials", "client_secret").
            refresh_token_config_path (Sequence[str]): Dpath to the refresh_token field in the connector configuration. Defaults to ("credentials", "refresh_token").
        Raises:
            ValueError: If the connector configuration holds no value under one of the configured paths.
        """
        self._client_id_config_path = client_id_config_path
        self._client_secret_config_path = client_secret_config_path
        self._refresh_token_config_path = refresh_token_config_path
        self._refresh_token_name = refresh_token_name
        # Observe the config first: the getters used below (and by validation) read from the observed copy.
        self._connector_config = observe_connector_config(connector_config)
        self._validate_connector_config()
        super().__init__(
            token_refresh_endpoint,
            self.get_client_id(),
            self.get_client_secret(),
            self.get_refresh_token(),
            scopes,
            token_expiry_date,
            access_token_name,
            expires_in_name,
            refresh_request_body,
            grant_type,
        )

    def _validate_connector_config(self):
        """Validates the defined getters for configuration values are returning values.
        Raises:
            ValueError: Raised if a getter finds no key at its path, or finds an empty value there.
        """
        for field_path, getter, parameter_name in [
            (self._client_id_config_path, self.get_client_id, "client_id_config_path"),
            (self._client_secret_config_path, self.get_client_secret, "client_secret_config_path"),
            (self._refresh_token_config_path, self.get_refresh_token, "refresh_token_config_path"),
        ]:
            error_message = f"This authenticator expects a value under the {field_path} field path. Please check your configuration structure or change the {parameter_name} value at initialization of this authenticator."
            try:
                value = getter()
            except KeyError as error:
                raise ValueError(error_message) from error
            # Explicit check instead of `assert`: asserts are stripped under `python -O`
            # and would raise AssertionError rather than the documented ValueError.
            if not value:
                raise ValueError(error_message)

    def get_refresh_token_name(self) -> str:
        """Name of the refresh token field in the refresh token response."""
        return self._refresh_token_name

    def get_client_id(self) -> str:
        """Read the client id from the observed connector config."""
        return dpath.util.get(self._connector_config, self._client_id_config_path)

    def get_client_secret(self) -> str:
        """Read the client secret from the observed connector config."""
        return dpath.util.get(self._connector_config, self._client_secret_config_path)

    def get_refresh_token(self) -> str:
        """Read the current refresh token from the observed connector config."""
        return dpath.util.get(self._connector_config, self._refresh_token_config_path)

    def set_refresh_token(self, new_refresh_token: str):
        """Set the new refresh token value. The mutation of the connector_config object will emit an Airbyte control message.
        Args:
            new_refresh_token (str): The new refresh token value.
        """
        dpath.util.set(self._connector_config, self._refresh_token_config_path, new_refresh_token)

    def get_access_token(self) -> str:
        """Retrieve new access and refresh token if the access token has expired.
        The new refresh token is persisted with the set_refresh_token function
        Returns:
            str: The current access_token, updated if it was previously expired.
        """
        if self.token_has_expired():
            # Take the timestamp before the request so the computed expiry date is never later than the real one.
            t0 = pendulum.now()
            new_access_token, access_token_expires_in, new_refresh_token = self.refresh_access_token()
            self.access_token = new_access_token
            self.set_token_expiry_date(t0.add(seconds=access_token_expires_in))
            self.set_refresh_token(new_refresh_token)
        return self.access_token

    def refresh_access_token(self) -> Tuple[str, int, str]:
        """Request new tokens from the refresh endpoint.
        Returns:
            Tuple[str, int, str]: (access_token, token_lifespan_in_seconds, refresh_token) parsed from the refresh response.
        Raises:
            Exception: If the request fails or the response lacks one of the expected fields.
        """
        try:
            response_json = self._get_refresh_access_token_response()
            return (
                response_json[self.get_access_token_name()],
                response_json[self.get_expires_in_name()],
                response_json[self.get_refresh_token_name()],
            )
        except Exception as e:
            raise Exception(f"Error while refreshing access token and refresh token: {e}") from e
2 changes: 1 addition & 1 deletion airbyte-cdk/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setup(
name="airbyte-cdk",
version="0.10.0",
version="0.11.0",
description="A framework for writing Airbyte Connectors.",
long_description=README,
long_description_content_type="text/markdown",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

import json
import logging

import pendulum
import pytest
import requests
from airbyte_cdk.config_observation import ObservedDict
from airbyte_cdk.sources.streams.http.requests_native_auth import (
BasicHttpAuthenticator,
MultipleTokenAuthenticator,
Oauth2Authenticator,
SingleUseRefreshTokenOauth2Authenticator,
TokenAuthenticator,
)
from requests import Response
Expand Down Expand Up @@ -175,6 +179,73 @@ def test_auth_call_method(self, mocker):
assert {"Authorization": "Bearer access_token"} == prepared_request.headers


class TestSingleUseRefreshTokenOauth2Authenticator:
    """Unit tests for SingleUseRefreshTokenOauth2Authenticator."""

    @pytest.fixture
    def connector_config(self):
        """Connector config holding the credentials under the default "credentials" key."""
        credentials = {
            "access_token": "my_access_token",
            "refresh_token": "my_refresh_token",
            "client_id": "my_client_id",
            "client_secret": "my_client_secret",
        }
        return {"credentials": credentials}

    @pytest.fixture
    def invalid_connector_config(self):
        """Connector config without the "credentials" key."""
        return {"no_credentials_key": "foo"}

    def test_init(self, connector_config):
        """The authenticator keeps an observed copy of the connector config."""
        authenticator = SingleUseRefreshTokenOauth2Authenticator(connector_config, token_refresh_endpoint="foobar")
        assert isinstance(authenticator._connector_config, ObservedDict)

    def test_init_with_invalid_config(self, invalid_connector_config):
        """A config missing the expected credential paths is rejected at construction."""
        with pytest.raises(ValueError):
            SingleUseRefreshTokenOauth2Authenticator(invalid_connector_config, token_refresh_endpoint="foobar")

    def test_get_access_token(self, capsys, mocker, connector_config):
        """An expired token triggers a refresh, a config mutation and a control message; a valid token triggers nothing."""
        authenticator = SingleUseRefreshTokenOauth2Authenticator(connector_config, token_refresh_endpoint="foobar")
        authenticator.refresh_access_token = mocker.Mock(return_value=("new_access_token", 42, "new_refresh_token"))

        # Expired token: refresh happens and the new config is emitted on stdout.
        authenticator.token_has_expired = mocker.Mock(return_value=True)
        access_token = authenticator.get_access_token()
        emitted_output = capsys.readouterr().out
        airbyte_message = json.loads(emitted_output)
        expected_new_config = connector_config.copy()
        expected_new_config["credentials"]["refresh_token"] = "new_refresh_token"
        assert airbyte_message["control"]["connectorConfig"]["config"] == expected_new_config
        assert authenticator.access_token == access_token == "new_access_token"
        assert authenticator.get_refresh_token() == "new_refresh_token"
        assert authenticator.get_token_expiry_date() > pendulum.now()

        # Valid token: nothing is emitted and the access token is unchanged.
        authenticator.token_has_expired = mocker.Mock(return_value=False)
        access_token = authenticator.get_access_token()
        emitted_output = capsys.readouterr().out
        assert not emitted_output
        assert authenticator.access_token == access_token == "new_access_token"

    def test_refresh_access_token(self, mocker, connector_config):
        """The three values are parsed from the refresh response using the configured field names."""
        authenticator = SingleUseRefreshTokenOauth2Authenticator(connector_config, token_refresh_endpoint="foobar")
        refresh_response = {
            authenticator.get_access_token_name(): "new_access_token",
            authenticator.get_expires_in_name(): 42,
            authenticator.get_refresh_token_name(): "new_refresh_token",
        }
        authenticator._get_refresh_access_token_response = mocker.Mock(return_value=refresh_response)
        assert authenticator.refresh_access_token() == ("new_access_token", 42, "new_refresh_token")


def mock_request(method, url, data):
if url == "refresh_end":
return resp
Expand Down
Loading

0 comments on commit 3b15f9e

Please sign in to comment.