feat(airbyte-cdk): Add Per Partition with Global fallback Cursor #45125

Open · wants to merge 26 commits into base: master
@@ -4,8 +4,8 @@
from datetime import timedelta
from typing import Optional

from airbyte_cdk import StreamSlice
from airbyte_cdk.sources.declarative.async_job.timer import Timer
from airbyte_cdk.sources.types import StreamSlice
maxi297 marked this conversation as resolved.

from .status import AsyncJobStatus

@@ -3,8 +3,8 @@
from abc import abstractmethod
from typing import Any, Iterable, Mapping, Set

from airbyte_cdk import StreamSlice
from airbyte_cdk.sources.declarative.async_job.job import AsyncJob
from airbyte_cdk.sources.types import StreamSlice


class AsyncJobRepository:
@@ -3,9 +3,9 @@
#
import datetime
from dataclasses import InitVar, dataclass
from typing import Any, Iterable, Mapping, Optional
from typing import Any, Iterable, Mapping, Optional, Union

from airbyte_cdk.sources.declarative.incremental import DatetimeBasedCursor, PerPartitionCursor
from airbyte_cdk.sources.declarative.incremental import DatetimeBasedCursor, GlobalSubstreamCursor, PerPartitionWithGlobalCursor
from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState

@@ -50,14 +50,12 @@ class ClientSideIncrementalRecordFilterDecorator(RecordFilter):
def __init__(
self,
date_time_based_cursor: DatetimeBasedCursor,
per_partition_cursor: Optional[PerPartitionCursor] = None,
Contributor

Probably outside the scope of this PR, but could we eventually have just one cursor as a parameter here? I'm trying to understand why we need both cursors; it seems like we could have a single cursor of the Cursor interface, and the filtering code would look like:

    def filter_records(
        self,
        records: Iterable[Mapping[str, Any]],
        stream_state: StreamState,
        stream_slice: Optional[StreamSlice] = None,
        next_page_token: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[Mapping[str, Any]]:
        records = (
            record
            for record in records
            if self._cursor.should_be_synced(record)
        )
        if self.condition:
            records = super().filter_records(
                records=records, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
            )
        yield from records

If we agree that this is a path forward, I'll create an issue for that

Contributor Author

It is not possible right now. The issue is that _substream_cursor doesn't have the methods this filter needs from the cursor, for example select_best_end_datetime and parse_date.
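
To illustrate, the current filtering leans on the DatetimeBasedCursor helpers roughly like this (a simplified sketch, not the exact body of the file):

    # Simplified sketch: the decorator needs the DatetimeBasedCursor helpers to
    # build the datetime bounds, which a bare Cursor interface doesn't expose today.
    def filter_records(self, records, stream_state, stream_slice=None, next_page_token=None):
        state_value = self._get_state_value(stream_state, stream_slice)     # select_state(...) under the hood
        filter_date = self._get_filter_date(state_value)                    # parses the state value into a datetime
        end_date = self._date_time_based_cursor.select_best_end_datetime()  # upper bound of the window
        for record in records:
            record_date = self._date_time_based_cursor.parse_date(record[self._cursor_field])
            if filter_date < record_date <= end_date:
                yield record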

Contributor

I'm not sure I understand: if we use if self._cursor.should_be_synced(record), select_best_end_datetime and parse_date can be private, right?

is_global_substream_cursor: bool = False,
substream_cursor: Optional[Union[PerPartitionWithGlobalCursor, GlobalSubstreamCursor]],
**kwargs: Any,
):
super().__init__(**kwargs)
self._date_time_based_cursor = date_time_based_cursor
self._per_partition_cursor = per_partition_cursor
self.is_global_substream_cursor = is_global_substream_cursor
self._substream_cursor = substream_cursor

@property
def _cursor_field(self) -> str:
@@ -103,15 +101,9 @@ def _get_state_value(self, stream_state: StreamState, stream_slice: StreamSlice)
:param StreamSlice stream_slice: Current Stream slice
:return Optional[str]: cursor_value in case it was found, otherwise None.
"""
if self._per_partition_cursor:
# self._per_partition_cursor is the same object that DeclarativeStream uses to save/update stream_state
partition_state = self._per_partition_cursor.select_state(stream_slice=stream_slice)
return partition_state.get(self._cursor_field) if partition_state else None
state = (self._substream_cursor or self._date_time_based_cursor).select_state(stream_slice)

if self.is_global_substream_cursor:
return stream_state.get("state", {}).get(self._cursor_field) # type: ignore # state is inside a dict for GlobalSubstreamCursor

return stream_state.get(self._cursor_field)
return state.get(self._cursor_field) if state else None

def _get_filter_date(self, state_value: Optional[str]) -> datetime.datetime:
start_date_parsed = self._start_date_from_config
@@ -6,14 +6,19 @@
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import GlobalSubstreamCursor
from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import CursorFactory, PerPartitionCursor
from airbyte_cdk.sources.declarative.incremental.resumable_full_refresh_cursor import ResumableFullRefreshCursor, ChildPartitionResumableFullRefreshCursor
from airbyte_cdk.sources.declarative.incremental.per_partition_with_global import PerPartitionWithGlobalCursor
from airbyte_cdk.sources.declarative.incremental.resumable_full_refresh_cursor import (
ChildPartitionResumableFullRefreshCursor,
ResumableFullRefreshCursor,
)

__all__ = [
"CursorFactory",
"DatetimeBasedCursor",
"DeclarativeCursor",
"GlobalSubstreamCursor",
"PerPartitionCursor",
"PerPartitionWithGlobalCursor",
"ResumableFullRefreshCursor",
"ChildPartitionResumableFullRefreshCursor"
"ChildPartitionResumableFullRefreshCursor",
]
@@ -4,13 +4,33 @@

import threading
import time
from typing import Any, Iterable, Mapping, Optional, Union
from typing import Any, Iterable, Mapping, Optional, TypeVar, Union

from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState

T = TypeVar("T")


def iterate_with_last_flag(generator: Iterable[T]) -> Iterable[tuple[T, bool]]:
"""
Iterates over the given generator and returns a tuple containing the element and a flag
indicating whether it's the last element in the generator. If the generator is empty,
it returns an empty iterator.
"""
iterator = iter(generator)
try:
current = next(iterator)
except StopIteration:
return # Return an empty iterator

for next_item in iterator:
yield current, False
current = next_item
yield current, True
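
# Illustrative usage (not part of this change):
#   list(iterate_with_last_flag(["a", "b", "c"]))  ->  [("a", False), ("b", False), ("c", True)]
#   list(iterate_with_last_flag([]))               ->  []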


class Timer:
"""
@@ -25,7 +45,7 @@ def start(self) -> None:

def finish(self) -> int:
if self._start:
return int((time.perf_counter_ns() - self._start) // 1e9)
return ((time.perf_counter_ns() - self._start) / 1e9).__ceil__()
else:
raise RuntimeError("Global substream cursor timer not started")

@@ -53,6 +73,9 @@ def __init__(self, stream_cursor: DatetimeBasedCursor, partition_router: Partiti
self._all_slices_yielded = False
self._lookback_window: Optional[int] = None

def start_slices_generation(self) -> None:
self._timer.start()

def stream_slices(self) -> Iterable[StreamSlice]:
"""
Generates stream slices, ensuring the last slice is properly flagged and processed.
@@ -68,32 +91,37 @@ def stream_slices(self) -> Iterable[StreamSlice]:
* Setting `self._all_slices_yielded = True`. We do that before actually yielding the last slice as the caller of `stream_slices` might stop iterating at any point and hence the code after `yield` might not be executed
* Yield the last slice. At that point, once there are as many slices yielded as closes, the global slice will be closed too
"""
previous_slice = None

slice_generator = (
StreamSlice(partition=partition, cursor_slice=cursor_slice)
for partition in self._partition_router.stream_slices()
for cursor_slice in self._stream_cursor.stream_slices()
)
self._timer.start()

for slice in slice_generator:
if previous_slice is not None:
# Release the semaphore to indicate that a slice has been yielded
self._slice_semaphore.release()
yield previous_slice
self.start_slices_generation()
for slice, last in iterate_with_last_flag(slice_generator):
self.register_slice(last)
yield slice

# Store the current slice as the previous slice for the next iteration
previous_slice = slice
def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[StreamSlice]:
slice_generator = (
StreamSlice(partition=partition, cursor_slice=cursor_slice) for cursor_slice in self._stream_cursor.stream_slices()
)

# After all slices have been generated, release the semaphore one final time
# and flag that all slices have been yielded
self._slice_semaphore.release()
self._all_slices_yielded = True
yield from slice_generator

def register_slice(self, last: bool) -> None:
"""
Tracks the processing of a stream slice.

# Yield the last slice
if previous_slice is not None:
yield previous_slice
Releases the semaphore for each slice. If it's the last slice (`last=True`),
sets `_all_slices_yielded` to `True` to indicate no more slices will be processed.

Args:
last (bool): True if the current slice is the last in the sequence.
"""
self._slice_semaphore.release()
if last:
self._all_slices_yielded = True

def set_initial_state(self, stream_state: StreamState) -> None:
"""
@@ -125,7 +153,12 @@ def set_initial_state(self, stream_state: StreamState) -> None:
self._lookback_window = stream_state["lookback_window"]
self._inject_lookback_into_stream_cursor(stream_state["lookback_window"])

self._stream_cursor.set_initial_state(stream_state["state"])
if "state" in stream_state:
self._stream_cursor.set_initial_state(stream_state["state"])
elif "states" not in stream_state:
# We assume that `stream_state` is in the old global format
# Example: {"global_state_format_key": "global_state_format_value"}
self._stream_cursor.set_initial_state(stream_state)

# Set parent state for partition routers based on parent streams
self._partition_router.set_initial_state(stream_state)
@@ -11,6 +11,8 @@
from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import PerPartitionKeySerializer
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState

logger = logging.getLogger("airbyte")


class CursorFactory:
def __init__(self, create_function: Callable[[], DeclarativeCursor]):
@@ -22,23 +24,20 @@ def create(self) -> DeclarativeCursor:

class PerPartitionCursor(DeclarativeCursor):
"""
Given a stream has many partitions, it is important to provide a state per partition.

Record | Stream Slice | Last Record | DatetimeCursorBased cursor
-- | -- | -- | --
1 | {"start_time": "2021-01-01","end_time": "2021-01-31","owner_resource": "1"''} | cursor_field: “2021-01-15” | 2021-01-15
2 | {"start_time": "2021-02-01","end_time": "2021-02-28","owner_resource": "1"''} | cursor_field: “2021-02-15” | 2021-02-15
3 | {"start_time": "2021-01-01","end_time": "2021-01-31","owner_resource": "2"''} | cursor_field: “2021-01-03” | 2021-01-03
4 | {"start_time": "2021-02-01","end_time": "2021-02-28","owner_resource": "2"''} | cursor_field: “2021-02-14” | 2021-02-14

Given the following errors, this can lead to some loss or duplication of records:
When | Problem | Affected Record
-- | -- | --
Between record #1 and #2 | Loss | #3
Between record #2 and #3 | Loss | #3, #4
Between record #3 and #4 | Duplication | #1, #2

Therefore, we need to manage state per partition.
Manages state per partition when a stream has many partitions, to prevent data loss or duplication.

**Partition Limitation and Limit Reached Logic**

- **DEFAULT_MAX_PARTITIONS_NUMBER**: The maximum number of partitions to keep in memory (default is 10,000).
- **_cursor_per_partition**: An ordered dictionary that stores cursors for each partition.
- **_over_limit**: A counter that increments each time an oldest partition is removed when the limit is exceeded.

The class ensures that the number of partitions tracked does not exceed the `DEFAULT_MAX_PARTITIONS_NUMBER` to prevent excessive memory usage.

- When the number of partitions exceeds the limit, the oldest partitions are removed from `_cursor_per_partition`, and `_over_limit` is incremented accordingly.
- The `limit_reached` method returns `True` when `_over_limit` exceeds `DEFAULT_MAX_PARTITIONS_NUMBER`, indicating that the global cursor should be used instead of per-partition cursors.

This approach avoids unnecessary switching to a global cursor due to temporary spikes in partition counts, ensuring that switching is only done when a sustained high number of partitions is observed.
"""

DEFAULT_MAX_PARTITIONS_NUMBER = 10000
@@ -54,30 +53,40 @@ def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRou
# The dict is ordered to ensure that once the maximum number of partitions is reached,
# the oldest partitions can be efficiently removed, maintaining the most recent partitions.
self._cursor_per_partition: OrderedDict[str, DeclarativeCursor] = OrderedDict()
self._over_limit = 0
self._partition_serializer = PerPartitionKeySerializer()

def stream_slices(self) -> Iterable[StreamSlice]:
slices = self._partition_router.stream_slices()
for partition in slices:
# Ensure the maximum number of partitions is not exceeded
self._ensure_partition_limit()
yield from self.generate_slices_from_partition(partition)

def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[StreamSlice]:
# Ensure the maximum number of partitions is not exceeded
self._ensure_partition_limit()

cursor = self._cursor_per_partition.get(self._to_partition_key(partition.partition))
if not cursor:
partition_state = self._state_to_migrate_from if self._state_to_migrate_from else self._NO_CURSOR_STATE
cursor = self._create_cursor(partition_state)
self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor
cursor = self._cursor_per_partition.get(self._to_partition_key(partition.partition))
if not cursor:
partition_state = self._state_to_migrate_from if self._state_to_migrate_from else self._NO_CURSOR_STATE
cursor = self._create_cursor(partition_state)
self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor

for cursor_slice in cursor.stream_slices():
yield StreamSlice(partition=partition, cursor_slice=cursor_slice)
for cursor_slice in cursor.stream_slices():
yield StreamSlice(partition=partition, cursor_slice=cursor_slice)

def _ensure_partition_limit(self) -> None:
"""
Ensure the maximum number of partitions is not exceeded. If so, the oldest added partition will be dropped.
"""
while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1:
self._over_limit += 1
oldest_partition = self._cursor_per_partition.popitem(last=False)[0] # Remove the oldest partition
logging.warning(f"The maximum number of partitions has been reached. Dropping the oldest partition: {oldest_partition}.")
logger.warning(
f"The maximum number of partitions has been reached. Dropping the oldest partition: {oldest_partition}. Over limit: {self._over_limit}."
)

def limit_reached(self) -> bool:
return self._over_limit > self.DEFAULT_MAX_PARTITIONS_NUMBER
Contributor

Is this supposed to be self._over_limit? From my understanding, we are incrementing this value while we're streaming the slices. And we only start incrementing the moment the number of cursors exceeds 10,000.

So unless I'm misunderstanding, if we have hypothetically 10,500 cursors, max of 10,000, then the final resulting value of self._over_limit = 500.

And then in limit_reached() we compare 500 > self.DEFAULT_MAX_PARTITIONS_NUMBER which will be false. So in order for this to return true we would need 20,000 partitions.

My question is why do we use _over_limit instead of len(self._cursor_per_partition) when we call limit_reached() from the PerPartitionWithGlobalCursor?

Contributor Author

You're right—we increment self._over_limit each time we drop an old partition after exceeding our DEFAULT_MAX_PARTITIONS_NUMBER (10,000). So if we have 10,500 partitions, we'll drop the 500 oldest ones, and self._over_limit becomes 500.

We compare self._over_limit to DEFAULT_MAX_PARTITIONS_NUMBER in limit_reached(). We only switch to a global cursor when self._over_limit exceeds 10,000 (i.e., we've had to drop over 10,000 partitions). This way, we avoid switching too early due to temporary spikes in partition counts.

Using len(self._cursor_per_partition) wouldn't help here because we've capped it at 10,000—it'll never exceed the limit since we keep removing old partitions to stay within that number.

I updated the doc with this explanation.
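
To make the numbers concrete, here is a tiny standalone sketch of the same counting logic with a small limit (illustrative only; the real code uses DEFAULT_MAX_PARTITIONS_NUMBER = 10,000):

    # Illustrative sketch of _ensure_partition_limit / limit_reached, with MAX = 3.
    from collections import OrderedDict

    MAX_PARTITIONS = 3
    cursors = OrderedDict()
    over_limit = 0

    for i in range(8):  # pretend the partition router yields 8 partitions
        # trim before registering a new partition, as _ensure_partition_limit does
        while len(cursors) > MAX_PARTITIONS - 1:
            over_limit += 1
            cursors.popitem(last=False)  # drop the oldest partition
        cursors[f"partition_{i}"] = object()

    print(len(cursors))                  # 3 -> capped, so the dict size never signals sustained overflow
    print(over_limit)                    # 5 -> how many partitions we had to drop
    print(over_limit > MAX_PARTITIONS)   # True -> only now would we fall back to the global cursor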

Contributor
brianjlai Sep 20, 2024

Ah I see. I think it wasn't clear to me that the intent was to switch to the global cursor state once we've dropped X number of partitions (in this case 10,000). I thought the intent was based solely on returning True once we exceeded 10,000 partitions.

Thanks for updating the docs, but let's also add a small docstring here too. Something along the lines of saying that this method returns True after the number of dropped partitions from state exceeds the default max partitions. This is used to protect against spikes in partition counts.
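
For example, the docstring could read something like this (wording is just a suggestion):

    def limit_reached(self) -> bool:
        """
        Returns True once the number of partitions dropped from state exceeds
        DEFAULT_MAX_PARTITIONS_NUMBER, i.e. after a sustained overflow rather
        than a temporary spike in partition counts. PerPartitionWithGlobalCursor
        uses this signal to fall back to the global cursor.
        """
        return self._over_limit > self.DEFAULT_MAX_PARTITIONS_NUMBER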


def set_initial_state(self, stream_state: StreamState) -> None:
"""
Expand Down Expand Up @@ -121,6 +130,10 @@ def set_initial_state(self, stream_state: StreamState) -> None:
for state in stream_state["states"]:
self._cursor_per_partition[self._to_partition_key(state["partition"])] = self._create_cursor(state["cursor"])

# set default state for missing partitions if it is per partition with fallback to global
if "state" in stream_state:
self._state_to_migrate_from = stream_state["state"]

# Set parent state for partition routers based on parent streams
self._partition_router.set_initial_state(stream_state)
