Skip to content

Commit

Permalink
Move state management from partition router to cursor classes
Browse files Browse the repository at this point in the history
  • Loading branch information
tolik0 committed Oct 24, 2024
1 parent 1fb01a9 commit 2b38e0e
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import threading
import time
from typing import Any, Iterable, Mapping, Optional, TypeVar, Union
from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, Union

from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
Expand All @@ -14,22 +14,35 @@
T = TypeVar("T")


def iterate_with_last_flag_and_state(
    generator: Iterable[T], get_stream_state_func: Callable[[], Any]
) -> Iterable[tuple[T, bool, Any]]:
    """
    Iterate over ``generator`` and yield ``(element, is_last, state)`` tuples.

    ``get_stream_state_func`` is a zero-argument callable; it is NOT passed the
    element. It is invoked once right after each element is pulled from the
    underlying iterator, and once more after the iterator is exhausted. Each
    non-last element is paired with the snapshot taken right after it was
    pulled; the last element is paired with the snapshot taken after
    exhaustion.

    Args:
        generator: The iterable to iterate over.
        get_stream_state_func: Zero-argument callable returning the current state.

    Returns:
        An iterator of ``(element, is_last, state)`` tuples. Nothing is yielded
        (and ``get_stream_state_func`` is never called) when ``generator`` is empty.
    """
    iterator = iter(generator)

    # Keep the try body minimal: only an exhausted *input* means "empty".
    # A StopIteration leaking out of the state callback must not be mistaken
    # for an empty input and silently swallowed.
    try:
        current = next(iterator)
    except StopIteration:
        return  # Empty input: yield nothing

    state = get_stream_state_func()

    # One-element look-ahead: an element is only known to be the last one
    # once the iterator fails to produce a successor.
    for next_item in iterator:
        yield current, False, state
        current = next_item
        state = get_stream_state_func()

    # The iterator is exhausted at this point; take a fresh snapshot so the
    # last element is paired with the state observed after exhaustion.
    yield current, True, get_stream_state_func()


class Timer:
Expand Down Expand Up @@ -74,6 +87,7 @@ def __init__(self, stream_cursor: DatetimeBasedCursor, partition_router: Partiti
self._lookback_window: Optional[int] = None
self._current_partition: Optional[Mapping[str, Any]] = None
self._last_slice: bool = False
self._parent_state: Optional[Mapping[str, Any]] = None

def start_slices_generation(self) -> None:
self._timer.start()
Expand All @@ -100,12 +114,11 @@ def stream_slices(self) -> Iterable[StreamSlice]:
)

self.start_slices_generation()
for slice, last in iterate_with_last_flag(slice_generator):
self._current_partition = slice.partition
for slice, last, state in iterate_with_last_flag_and_state(slice_generator, self._partition_router.get_stream_state):
self._parent_state = state
self.register_slice(last)
yield slice
self._current_partition = None
self._last_slice = True
self._parent_state = self._partition_router.get_stream_state()

def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[StreamSlice]:
slice_generator = (
Expand Down Expand Up @@ -210,11 +223,8 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
def get_stream_state(self, partition: Optional[Mapping[str, Any]] = None, last: bool = True) -> StreamState:
state: dict[str, Any] = {"state": self._stream_cursor.get_stream_state()}

parent_state = self._partition_router.get_stream_state(
partition=partition or self._current_partition, last=self._last_slice or last
)
if parent_state:
state["parent_state"] = parent_state
if self._parent_state:
state["parent_state"] = self._parent_state

if self._lookback_window is not None:
state["lookback_window"] = self._lookback_window
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRou
self._cursor_per_partition: OrderedDict[str, DeclarativeCursor] = OrderedDict()
self._over_limit = 0
self._partition_serializer = PerPartitionKeySerializer()
self._current_partition: Optional[Mapping[str, Any]] = None
self._last_partition: bool = False

def stream_slices(self) -> Iterable[StreamSlice]:
slices = self._partition_router.stream_slices()
for partition in slices:
yield from self.generate_slices_from_partition(partition)
self._last_partition = True

def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[StreamSlice]:
# Ensure the maximum number of partitions is not exceeded
Expand Down Expand Up @@ -154,7 +155,7 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
f"we should only update state for partitions that were emitted during `stream_slices`"
)

def get_stream_state(self, partition: Optional[Mapping[str, Any]] = None, last: bool = True) -> StreamState:
def get_stream_state(self, partition: Optional[Mapping[str, Any]] = None, last: bool = False) -> StreamState:
states = []
for partition_tuple, cursor in self._cursor_per_partition.items():
cursor_state = cursor.get_stream_state()
Expand All @@ -167,7 +168,7 @@ def get_stream_state(self, partition: Optional[Mapping[str, Any]] = None, last:
)
state: dict[str, Any] = {"states": states}

parent_state = self._partition_router.get_stream_state(partition=partition, last=last)
parent_state = self._partition_router.get_stream_state(last=last or self._last_partition)
if parent_state:
state["parent_state"] = parent_state
return state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import GlobalSubstreamCursor, iterate_with_last_flag
from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import GlobalSubstreamCursor, iterate_with_last_flag_and_state
from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import CursorFactory, PerPartitionCursor
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState
Expand Down Expand Up @@ -67,6 +67,7 @@ def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRou
self._use_global_cursor = False
self._current_partition: Optional[Mapping[str, Any]] = None
self._last_slice: bool = False
self._parent_state: Optional[Mapping[str, Any]] = None

def _get_active_cursor(self) -> Union[PerPartitionCursor, GlobalSubstreamCursor]:
return self._global_cursor if self._use_global_cursor else self._per_partition_cursor
Expand All @@ -75,23 +76,25 @@ def stream_slices(self) -> Iterable[StreamSlice]:
self._global_cursor.start_slices_generation()

# Iterate through partitions and process slices
for partition, is_last_partition in iterate_with_last_flag(self._partition_router.stream_slices()):
for partition, is_last_partition, parent_state in iterate_with_last_flag_and_state(self._partition_router.stream_slices(), self._partition_router.get_stream_state):
# Generate slices for the current cursor and handle the last slice using the flag
self._current_partition = partition.partition
for slice, is_last_slice in iterate_with_last_flag(
self._get_active_cursor().generate_slices_from_partition(partition=partition)
self._parent_state = parent_state
for slice, is_last_slice, _ in iterate_with_last_flag_and_state(
self._get_active_cursor().generate_slices_from_partition(partition=partition), lambda: None
):

self._global_cursor.register_slice(is_last_slice and is_last_partition)
yield slice
self._current_partition = None
self._last_slice = True
self._parent_state = self._partition_router.get_stream_state()

def set_initial_state(self, stream_state: StreamState) -> None:
"""
Set the initial state for the cursors.
"""
self._use_global_cursor = stream_state.get("use_global_cursor", False)

self._parent_state = stream_state.get("parent_state", {})

self._global_cursor.set_initial_state(stream_state)
self._per_partition_cursor.set_initial_state(stream_state)

Expand All @@ -115,6 +118,9 @@ def get_stream_state(self) -> StreamState:
if not self._use_global_cursor:
final_state.update(self._per_partition_cursor.get_stream_state(partition=self._current_partition, last=self._last_slice))

if self._parent_state:
final_state["parent_state"] = self._parent_state

return final_state

def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class CartesianProductStreamSlicer(PartitionRouter):

stream_slicers: List[PartitionRouter]
parameters: InitVar[Mapping[str, Any]]
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._initial_parent_state = {}

def get_request_params(
self,
Expand Down Expand Up @@ -136,15 +138,19 @@ def set_initial_state(self, stream_state: StreamState) -> None:
}
}
"""
if "parent_state" in stream_state:
self._initial_parent_state = stream_state["parent_state"]

for stream_slicer in self.stream_slicers:
stream_slicer.set_initial_state(stream_state)

def get_stream_state(self, partition: Optional[Mapping[str, Any]] = None, last: bool = False) -> Optional[Mapping[str, StreamState]]:
def get_stream_state(self, last: bool = False) -> Optional[Mapping[str, StreamState]]:
"""
Get the state of the parent streams.
This method returns the combined parent states from all stream slicers. If a stream slicer does not have parent streams,
this will be skipped due to the default StreamSlicer implementation.
This method returns the combined parent states from all stream slicers. It currently retrieves the final state only for the last partition processed. If a stream slicer does not have parent streams, this will be skipped due to the default StreamSlicer implementation.
TODO: Can be improved by tracking the state of every stream slicer and updating the state of the last stream slicer when all the partitions for other slicers have been produced.
Returns:
Optional[Mapping[str, StreamState]]: The current state of the parent streams in a dictionary format.
Expand All @@ -158,10 +164,13 @@ def get_stream_state(self, partition: Optional[Mapping[str, Any]] = None, last:
}
}
"""
if not last:
return self._initial_parent_state

combined_state: dict[str, StreamState] = {}
for s in self.stream_slicers:
# Getting the initial state of the stream slicer
parent_state = s.get_stream_state(partition=None, last=last)
parent_state = s.get_stream_state(last=last)
if parent_state:
combined_state.update(parent_state)
return combined_state
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def set_initial_state(self, stream_state: StreamState) -> None:
"""
pass

def get_stream_state(self, partition: Optional[Mapping[str, Any]] = None, last: bool = False) -> Optional[Mapping[str, StreamState]]:
def get_stream_state(self, last: bool = False) -> Optional[Mapping[str, StreamState]]:
"""
ListPartitionRouter doesn't have parent streams
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def set_initial_state(self, stream_state: StreamState) -> None:
"""

@abstractmethod
def get_stream_state(self, partition: Optional[Mapping[str, Any]] = None, last: bool = False) -> Optional[Mapping[str, StreamState]]:
def get_stream_state(self, last: bool = False) -> Optional[Mapping[str, StreamState]]:
"""
Get the state of the parent streams.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def set_initial_state(self, stream_state: StreamState) -> None:
"""
pass

def get_stream_state(self, partition: Optional[Mapping[str, Any]] = None, last: bool = False) -> Optional[Mapping[str, StreamState]]:
def get_stream_state(self, last: bool = False) -> Optional[Mapping[str, StreamState]]:
"""
SinglePartitionRouter doesn't have parent streams
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ class SubstreamPartitionRouter(PartitionRouter):
Attributes:
parent_stream_configs (List[ParentStreamConfig]): parent streams to iterate over and their config
State Management:
- The state for each parent stream is stored in `_parent_state`, which is a dictionary mapping
stream names to their current states.
- The initial state is stored in `_initial_parent_state`, set during the first assignment via `set_initial_state`.
- The `_parent_state_to_partition` dictionary maps partition keys to parent states,
facilitating state retrieval based on partitions.
"""

parent_stream_configs: List[ParentStreamConfig]
Expand All @@ -69,16 +76,12 @@ class SubstreamPartitionRouter(PartitionRouter):
MAX_PARTITIONS = 2 # Limit for the number of partitions
# Currently, there is a limitation of two partitions due to the logic of the global cursor,
# which identifies what slice is last and stores one slice in memory. Once substreams are added to concurrent CDK,
# we can expand this limit and update the logic for deleting processed partitions.
# we can increase this limit and update the logic for deleting processed partitions.

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
if not self.parent_stream_configs:
raise ValueError("SubstreamPartitionRouter needs at least 1 parent stream")
self._parameters = parameters
self._initial_parent_state: Dict[str, Any] = {}
self._parent_state: Dict[str, Any] = {}
self._partition_serializer = PerPartitionKeySerializer()
self._parent_state_to_partition: OrderedDict[str, Dict[str, Any]] = OrderedDict()

def get_request_params(
self,
Expand Down Expand Up @@ -153,8 +156,6 @@ def stream_slices(self) -> Iterable[StreamSlice]:
if parent_stream_config.extra_fields:
extra_fields = [[field_path_part.eval(self.config) for field_path_part in field_path] for field_path in parent_stream_config.extra_fields] # type: ignore # extra_fields is always casted to an interpolated string

incremental_dependency = parent_stream_config.incremental_dependency

# read_stateless() assumes the parent is not concurrent. This is currently okay since the concurrent CDK does
# not support either substreams or RFR, but something that needs to be considered once we do
for parent_record in parent_stream.read_only_records():
Expand Down Expand Up @@ -188,26 +189,8 @@ def stream_slices(self) -> Iterable[StreamSlice]:
extra_fields=extracted_extra_fields,
)

if incremental_dependency:
partition_key = self._partition_serializer.to_partition_key(stream_slice.partition)

# Limit the number of states to two and remove the oldest if necessary
if len(self._parent_state_to_partition) >= self.MAX_PARTITIONS:
self._parent_state_to_partition.popitem(last=False) # Remove the oldest entry

self._parent_state_to_partition[partition_key] = copy.deepcopy(self._parent_state)

yield stream_slice

if incremental_dependency:
parent_state = copy.deepcopy(parent_stream.state)
self._parent_state[parent_stream.name] = parent_state

# A final parent state update and yield of records is needed, so we don't skip records for the final parent slice
if incremental_dependency:
parent_state = copy.deepcopy(parent_stream.state)
self._parent_state[parent_stream.name] = parent_state

def _extract_extra_fields(
self, parent_record: Mapping[str, Any] | AirbyteMessage, extra_fields: Optional[List[List[str]]] = None
) -> Mapping[str, Any]:
Expand Down Expand Up @@ -264,10 +247,8 @@ def set_initial_state(self, stream_state: StreamState) -> None:
for parent_config in self.parent_stream_configs:
if parent_config.incremental_dependency:
parent_config.stream.state = parent_state.get(parent_config.stream.name, {})
self._initial_parent_state[parent_config.stream.name] = parent_config.stream.state
self._parent_state = copy.deepcopy(self._initial_parent_state)

def get_stream_state(self, partition: Optional[Mapping[str, Any]] = None, last: bool = True) -> Optional[Mapping[str, StreamState]]:
def get_stream_state(self, last: bool = True) -> Optional[Mapping[str, StreamState]]:
"""
Get the state of the parent streams.
Expand All @@ -284,14 +265,11 @@ def get_stream_state(self, partition: Optional[Mapping[str, Any]] = None, last:
}
}
"""
if not partition:
if last:
return copy.deepcopy(self._parent_state)
else:
return copy.deepcopy(self._initial_parent_state)
else:
partition_key = self._partition_serializer.to_partition_key(partition)
return self._parent_state_to_partition.get(partition_key)
parent_state = {}
for parent_config in self.parent_stream_configs:
if parent_config.incremental_dependency:
parent_state[parent_config.stream.name] = copy.deepcopy(parent_config.stream.state)
return parent_state

@property
def logger(self) -> logging.Logger:
Expand Down

0 comments on commit 2b38e0e

Please sign in to comment.