Skip to content

Commit

Permalink
[airbyte-cdk] Decouple request_options_provider from datetime_based_c…
Browse files Browse the repository at this point in the history
…ursor + concurrent_cursor features for low-code (#45413)
  • Loading branch information
brianjlai authored Sep 17, 2024
1 parent 32b8648 commit 199a807
Show file tree
Hide file tree
Showing 12 changed files with 1,030 additions and 45 deletions.
2 changes: 2 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

from .sources.declarative.requesters.request_option import RequestOption, RequestOptionType

from .sources.declarative.requesters.request_options.default_request_options_provider import DefaultRequestOptionsProvider
from .sources.declarative.requesters.request_options.interpolated_request_input_provider import InterpolatedRequestInputProvider
from .sources.declarative.requesters.requester import HttpMethod
from .sources.declarative.retrievers import SimpleRetriever
Expand Down Expand Up @@ -133,6 +134,7 @@
"DeclarativeStream",
"Decoder",
"DefaultPaginator",
"DefaultRequestOptionsProvider",
"DpathExtractor",
"FieldPointer",
"HttpMethod",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,12 @@
StopConditionPaginationStrategyDecorator,
)
from airbyte_cdk.sources.declarative.requesters.request_option import RequestOptionType
from airbyte_cdk.sources.declarative.requesters.request_options import InterpolatedRequestOptionsProvider
from airbyte_cdk.sources.declarative.requesters.request_options import (
DatetimeBasedRequestOptionsProvider,
DefaultRequestOptionsProvider,
InterpolatedRequestOptionsProvider,
RequestOptionsProvider,
)
from airbyte_cdk.sources.declarative.requesters.request_path import RequestPath
from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod
from airbyte_cdk.sources.declarative.retrievers import AsyncRetriever, SimpleRetriever, SimpleRetrieverTestReadDecorator
Expand Down Expand Up @@ -653,6 +658,40 @@ def create_declarative_stream(self, model: DeclarativeStreamModel, config: Confi
"per_partition_cursor": combined_slicers if isinstance(combined_slicers, PerPartitionCursor) else None,
"is_global_substream_cursor": isinstance(combined_slicers, GlobalSubstreamCursor),
}

if model.incremental_sync and isinstance(model.incremental_sync, DatetimeBasedCursorModel):
cursor_model = model.incremental_sync

end_time_option = (
RequestOption(
inject_into=RequestOptionType(cursor_model.end_time_option.inject_into.value),
field_name=cursor_model.end_time_option.field_name,
parameters=cursor_model.parameters or {},
)
if cursor_model.end_time_option
else None
)
start_time_option = (
RequestOption(
inject_into=RequestOptionType(cursor_model.start_time_option.inject_into.value),
field_name=cursor_model.start_time_option.field_name,
parameters=cursor_model.parameters or {},
)
if cursor_model.start_time_option
else None
)

request_options_provider = DatetimeBasedRequestOptionsProvider(
start_time_option=start_time_option,
end_time_option=end_time_option,
partition_field_start=cursor_model.partition_field_end,
partition_field_end=cursor_model.partition_field_end,
config=config,
parameters=model.parameters or {},
)
else:
request_options_provider = None

transformations = []
if model.transformations:
for transformation_model in model.transformations:
Expand All @@ -663,6 +702,7 @@ def create_declarative_stream(self, model: DeclarativeStreamModel, config: Confi
name=model.name,
primary_key=primary_key,
stream_slicer=combined_slicers,
request_options_provider=request_options_provider,
stop_condition_on_cursor=stop_condition_on_cursor,
client_side_incremental_sync=client_side_incremental_sync,
transformations=transformations,
Expand Down Expand Up @@ -1126,6 +1166,7 @@ def create_simple_retriever(
name: str,
primary_key: Optional[Union[str, List[str], List[List[str]]]],
stream_slicer: Optional[StreamSlicer],
request_options_provider: Optional[RequestOptionsProvider] = None,
stop_condition_on_cursor: bool = False,
client_side_incremental_sync: Optional[Dict[str, Any]] = None,
transformations: List[RecordTransformation],
Expand All @@ -1140,11 +1181,21 @@ def create_simple_retriever(
client_side_incremental_sync=client_side_incremental_sync,
)
url_base = model.requester.url_base if hasattr(model.requester, "url_base") else requester.get_url_base()
stream_slicer = stream_slicer or SinglePartitionRouter(parameters={})

# Define cursor only if per partition or common incremental support is needed
cursor = stream_slicer if isinstance(stream_slicer, DeclarativeCursor) else None

if not isinstance(stream_slicer, DatetimeBasedCursor) or type(stream_slicer) is not DatetimeBasedCursor:
# Many of the custom component implementations of DatetimeBasedCursor override get_request_params() (or other methods).
# Because we're decoupling RequestOptionsProvider from the Cursor, custom components will eventually need to reimplement
# their own RequestOptionsProvider. However, right now the existing StreamSlicer/Cursor still can act as the SimpleRetriever's
# request_options_provider
request_options_provider = stream_slicer or DefaultRequestOptionsProvider(parameters={})
elif not request_options_provider:
request_options_provider = DefaultRequestOptionsProvider(parameters={})

stream_slicer = stream_slicer or SinglePartitionRouter(parameters={})

cursor_used_for_stop_condition = cursor if stop_condition_on_cursor else None
paginator = (
self._create_component_from_model(
Expand All @@ -1168,6 +1219,7 @@ def create_simple_retriever(
requester=requester,
record_selector=record_selector,
stream_slicer=stream_slicer,
request_option_provider=request_options_provider,
cursor=cursor,
config=config,
maximum_number_of_slices=self._limit_slices_fetched or 5,
Expand All @@ -1181,6 +1233,7 @@ def create_simple_retriever(
requester=requester,
record_selector=record_selector,
stream_slicer=stream_slicer,
request_option_provider=request_options_provider,
cursor=cursor,
config=config,
ignore_stream_slicer_parameters_on_paginated_requests=ignore_stream_slicer_parameters_on_paginated_requests,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from airbyte_cdk.sources.declarative.requesters.request_options.datetime_based_request_options_provider import (
DatetimeBasedRequestOptionsProvider,
)
from airbyte_cdk.sources.declarative.requesters.request_options.default_request_options_provider import DefaultRequestOptionsProvider
from airbyte_cdk.sources.declarative.requesters.request_options.interpolated_request_options_provider import (
InterpolatedRequestOptionsProvider,
)
from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import RequestOptionsProvider

__all__ = ["InterpolatedRequestOptionsProvider", "RequestOptionsProvider"]
__all__ = ["DatetimeBasedRequestOptionsProvider", "DefaultRequestOptionsProvider", "InterpolatedRequestOptionsProvider", "RequestOptionsProvider"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#

from dataclasses import InitVar, dataclass
from typing import Any, Mapping, MutableMapping, Optional, Union

from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType
from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import RequestOptionsProvider
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState


@dataclass
class DatetimeBasedRequestOptionsProvider(RequestOptionsProvider):
"""
Request options provider that extracts fields from the stream_slice and injects them into the respective location in the
outbound request being made
"""

config: Config
parameters: InitVar[Mapping[str, Any]]
start_time_option: Optional[RequestOption] = None
end_time_option: Optional[RequestOption] = None
partition_field_start: Optional[str] = None
partition_field_end: Optional[str] = None

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._partition_field_start = InterpolatedString.create(self.partition_field_start or "start_time", parameters=parameters)
self._partition_field_end = InterpolatedString.create(self.partition_field_end or "end_time", parameters=parameters)

def get_request_params(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.request_parameter, stream_slice)

def get_request_headers(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.header, stream_slice)

def get_request_body_data(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Union[Mapping[str, Any], str]:
return self._get_request_options(RequestOptionType.body_data, stream_slice)

def get_request_body_json(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.body_json, stream_slice)

def _get_request_options(self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice]) -> Mapping[str, Any]:
options: MutableMapping[str, Any] = {}
if not stream_slice:
return options
if self.start_time_option and self.start_time_option.inject_into == option_type:
options[self.start_time_option.field_name.eval(config=self.config)] = stream_slice.get( # type: ignore # field_name is always casted to an interpolated string
self._partition_field_start.eval(self.config)
)
if self.end_time_option and self.end_time_option.inject_into == option_type:
options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get(self._partition_field_end.eval(self.config)) # type: ignore # field_name is always casted to an interpolated string
return options
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#

from dataclasses import InitVar, dataclass
from typing import Any, Mapping, Optional, Union

from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import RequestOptionsProvider
from airbyte_cdk.sources.types import StreamSlice, StreamState


@dataclass
class DefaultRequestOptionsProvider(RequestOptionsProvider):
"""
Request options provider that extracts fields from the stream_slice and injects them into the respective location in the
outbound request being made
"""

parameters: InitVar[Mapping[str, Any]]

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
pass

def get_request_params(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
return {}

def get_request_headers(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
return {}

def get_request_body_data(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Union[Mapping[str, Any], str]:
return {}

def get_request_body_json(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
return {}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import json
from dataclasses import InitVar, dataclass, field
from functools import partial
Expand All @@ -16,6 +17,7 @@
from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import SinglePartitionRouter
from airbyte_cdk.sources.declarative.requesters.paginators.no_pagination import NoPagination
from airbyte_cdk.sources.declarative.requesters.paginators.paginator import Paginator
from airbyte_cdk.sources.declarative.requesters.request_options import DefaultRequestOptionsProvider, RequestOptionsProvider
from airbyte_cdk.sources.declarative.requesters.requester import Requester
from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import StreamSlicer
Expand Down Expand Up @@ -61,6 +63,7 @@ class SimpleRetriever(Retriever):
_primary_key: str = field(init=False, repr=False, default="")
paginator: Optional[Paginator] = None
stream_slicer: StreamSlicer = field(default_factory=lambda: SinglePartitionRouter(parameters={}))
request_option_provider: RequestOptionsProvider = field(default_factory=lambda: DefaultRequestOptionsProvider(parameters={}))
cursor: Optional[DeclarativeCursor] = None
ignore_stream_slicer_parameters_on_paginated_requests: bool = False

Expand Down Expand Up @@ -158,7 +161,7 @@ def _request_params(
stream_slice,
next_page_token,
self._paginator.get_request_params,
self.stream_slicer.get_request_params,
self.request_option_provider.get_request_params,
)
if isinstance(params, str):
raise ValueError("Request params cannot be a string")
Expand All @@ -184,7 +187,7 @@ def _request_body_data(
stream_slice,
next_page_token,
self._paginator.get_request_body_data,
self.stream_slicer.get_request_body_data,
self.request_option_provider.get_request_body_data,
)

def _request_body_json(
Expand All @@ -203,7 +206,7 @@ def _request_body_json(
stream_slice,
next_page_token,
self._paginator.get_request_body_json,
self.stream_slicer.get_request_body_json,
self.request_option_provider.get_request_body_json,
)
if isinstance(body_json, str):
raise ValueError("Request body json cannot be a string")
Expand Down Expand Up @@ -231,21 +234,21 @@ def _parse_response(
) -> Iterable[Record]:
if not response:
self._last_response = None
return []

self._last_response = response
record_generator = self.record_selector.select_records(
response=response,
stream_state=stream_state,
records_schema=records_schema,
stream_slice=stream_slice,
next_page_token=next_page_token,
)
self._last_page_size = 0
for record in record_generator:
self._last_page_size += 1
self._last_record = record
yield record
yield from []
else:
self._last_response = response
record_generator = self.record_selector.select_records(
response=response,
stream_state=stream_state,
records_schema=records_schema,
stream_slice=stream_slice,
next_page_token=next_page_token,
)
self._last_page_size = 0
for record in record_generator:
self._last_page_size += 1
self._last_record = record
yield record

@property # type: ignore
def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
Expand Down
Loading

0 comments on commit 199a807

Please sign in to comment.