Add PerPartitionWithGlobalCursor

tolik0 committed Sep 6, 2024
1 parent 950ea06 commit 8daaacb
Showing 3 changed files with 197 additions and 3 deletions.
airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py
@@ -55,6 +55,8 @@ def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRou
# The dict is ordered to ensure that once the maximum number of partitions is reached,
# the oldest partitions can be efficiently removed, maintaining the most recent partitions.
self._cursor_per_partition: OrderedDict[str, DeclarativeCursor] = OrderedDict()
self._over_limit = 0
self._partition_serializer = PerPartitionKeySerializer()

def stream_slices(self) -> Iterable[StreamSlice]:
@@ -76,8 +78,12 @@ def _ensure_partition_limit(self) -> None:
Ensure the maximum number of partitions is not exceeded. If so, the oldest added partition will be dropped.
"""
while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1:
self._over_limit += 1
oldest_partition = self._cursor_per_partition.popitem(last=False)[0] # Remove the oldest partition
logging.warning(f"The maximum number of partitions has been reached. Dropping the oldest partition: {oldest_partition}.")
logging.warning(f"The maximum number of partitions has been reached. Dropping the oldest partition: {oldest_partition}. Over limit: {self._over_limit}.")

def limit_reached(self) -> bool:
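# The limit is considered reached once more partitions have been dropped than the cursor is allowed to retain.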
return self._over_limit > self.DEFAULT_MAX_PARTITIONS_NUMBER

def set_initial_state(self, stream_state: StreamState) -> None:
"""
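For orientation, here is a minimal, self-contained sketch of the eviction behavior this hunk adds. `PartitionLimiter`, `MAX_PARTITIONS`, and the driver loop are illustrative stand-ins, not CDK API; the CDK's real `DEFAULT_MAX_PARTITIONS_NUMBER` is far larger.

```python
from collections import OrderedDict

MAX_PARTITIONS = 3  # illustrative stand-in for DEFAULT_MAX_PARTITIONS_NUMBER


class PartitionLimiter:
    def __init__(self) -> None:
        self._cursor_per_partition: "OrderedDict[str, dict]" = OrderedDict()
        self._over_limit = 0

    def add(self, key: str, cursor: dict) -> None:
        self._cursor_per_partition[key] = cursor
        self._ensure_partition_limit()

    def _ensure_partition_limit(self) -> None:
        # Evict the oldest partitions once the cap is exceeded, counting every eviction.
        while len(self._cursor_per_partition) > MAX_PARTITIONS - 1:
            self._over_limit += 1
            self._cursor_per_partition.popitem(last=False)

    def limit_reached(self) -> bool:
        # True once evictions outnumber the partitions we are allowed to keep.
        return self._over_limit > MAX_PARTITIONS


limiter = PartitionLimiter()
for i in range(8):
    limiter.add(f"p{i}", {})
print(limiter.limit_reached())  # True: 6 evictions > MAX_PARTITIONS (3)
```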
airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py
@@ -0,0 +1,179 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from typing import Any, Iterable, Mapping, Optional, Union

from airbyte_cdk import DatetimeBasedCursor
from airbyte_cdk.sources.declarative.incremental import GlobalSubstreamCursor
from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import PerPartitionCursor, CursorFactory

from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState

class PerPartitionWithGlobalCursor(DeclarativeCursor):
"""
Manages state for streams with multiple partitions, with an optional fallback to a global cursor when specific conditions are met.
This cursor is designed to handle cases where a stream is partitioned, allowing state to be managed per partition. However, once the fallback condition is met (the per-partition cursor has discarded more partitions than it is allowed to retain, i.e., roughly twice the configured maximum has been seen; see `limit_reached`), the cursor falls back to a single global state.
## Overview
Given a stream with many partitions, it is crucial to maintain a state per partition to avoid data loss or duplication. This class provides a mechanism to handle such cases and ensures that the stream's state is accurately maintained.
## State Management
- **Partition-Based State**: Manages state individually for each partition, ensuring that each partition's data is processed correctly and independently.
- **Global Fallback**: Switches to a global cursor when a predefined condition is met (e.g., the number of records in a partition exceeds a certain threshold). This ensures that the system can handle cases where partition-based state management is no longer efficient or viable.
## Example State Structure
```json
{
    "states": [
        {"partition": {"partition_key": "partition_1"}, "cursor": {"cursor_field": "2021-01-15"}},
        {"partition": {"partition_key": "partition_2"}, "cursor": {"cursor_field": "2021-02-14"}}
    ],
    "state": {
        "cursor_field": "2021-02-15"
    },
    "use_global_cursor": false
}
```
"""

def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRouter, stream_cursor: DatetimeBasedCursor):
self._per_partition_cursor = PerPartitionCursor(cursor_factory, partition_router)
self._global_cursor = GlobalSubstreamCursor(stream_cursor, partition_router)
self._use_global_cursor = False

def stream_slices(self) -> Iterable[StreamSlice]:
if self._use_global_cursor:
yield from self._global_cursor.stream_slices()
else:
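# Let the global cursor wrap the per-partition slices so it can track overall progress (e.g., detect the final slice).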
yield from self._global_cursor.generate_slices_from_generator(self._per_partition_cursor.stream_slices())

def set_initial_state(self, stream_state: StreamState) -> None:
"""
Set the initial state for the cursors.
"""
self._use_global_cursor = stream_state.get("use_global_cursor", False)

self._global_cursor.set_initial_state(stream_state)
self._per_partition_cursor.set_initial_state(stream_state)

def observe(self, stream_slice: StreamSlice, record: Record) -> None:
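# Flip to the global cursor permanently once the per-partition cursor has dropped too many partitions to be trustworthy.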
if not self._use_global_cursor and self._per_partition_cursor.limit_reached():
self._use_global_cursor = True

if not self._use_global_cursor:
self._per_partition_cursor.observe(stream_slice, record)
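# The global cursor observes every record regardless of mode, so its state is ready if we fall back later.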
self._global_cursor.observe(stream_slice, record)

def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
self._per_partition_cursor.close_slice(stream_slice, *args)
self._global_cursor.close_slice(stream_slice, *args)

def get_stream_state(self) -> StreamState:
final_state = {"use_global_cursor": self._use_global_cursor}
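# Persist which cursor is authoritative so a resumed sync starts in the same mode.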

final_state.update(self._global_cursor.get_stream_state())
if not self._use_global_cursor:
final_state.update(self._per_partition_cursor.get_stream_state())

return final_state

def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
if self._use_global_cursor:
return self._global_cursor.select_state(stream_slice)
else:
return self._per_partition_cursor.select_state(stream_slice)

def get_request_params(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
if self._use_global_cursor:
return self._global_cursor.get_request_params(
stream_state=stream_state,
stream_slice=stream_slice,
next_page_token=next_page_token,
)
else:
return self._per_partition_cursor.get_request_params(
stream_state=stream_state,
stream_slice=stream_slice,
next_page_token=next_page_token,
)

def get_request_headers(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
if self._use_global_cursor:
return self._global_cursor.get_request_headers(
stream_state=stream_state,
stream_slice=stream_slice,
next_page_token=next_page_token,
)
else:
return self._per_partition_cursor.get_request_headers(
stream_state=stream_state,
stream_slice=stream_slice,
next_page_token=next_page_token,
)

def get_request_body_data(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Union[Mapping[str, Any], str]:
if self._use_global_cursor:
return self._global_cursor.get_request_body_data(
stream_state=stream_state,
stream_slice=stream_slice,
next_page_token=next_page_token,
)
else:
return self._per_partition_cursor.get_request_body_data(
stream_state=stream_state,
stream_slice=stream_slice,
next_page_token=next_page_token,
)


def get_request_body_json(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
if self._use_global_cursor:
return self._global_cursor.get_request_body_json(
stream_state=stream_state,
stream_slice=stream_slice,
next_page_token=next_page_token,
)
else:
return self._per_partition_cursor.get_request_body_json(
stream_state=stream_state,
stream_slice=stream_slice,
next_page_token=next_page_token,
)

def should_be_synced(self, record: Record) -> bool:
return self._global_cursor.should_be_synced(record) or self._per_partition_cursor.should_be_synced(record)

def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
return self._global_cursor.is_greater_than_or_equal(first, second)
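To make the new class's lifecycle concrete, a hedged usage sketch follows. `cursor_factory`, `partition_router`, `stream_cursor`, and `read_records` are assumed placeholders, not real CDK setup; in practice they are built by the model factory shown in the next file.

```python
# Illustrative wiring only; the names below are placeholders.
cursor = PerPartitionWithGlobalCursor(
    cursor_factory=cursor_factory,      # placeholder: builds one cursor per partition
    partition_router=partition_router,  # placeholder: yields the stream's partitions
    stream_cursor=stream_cursor,        # placeholder: DatetimeBasedCursor backing the global state
)

# Resume from a prior sync; "use_global_cursor" decides which cursor is authoritative.
cursor.set_initial_state({
    "use_global_cursor": False,
    "states": [
        {"partition": {"partition_key": "partition_1"}, "cursor": {"cursor_field": "2021-01-15"}},
    ],
    "state": {"cursor_field": "2021-02-15"},
})

for stream_slice in cursor.stream_slices():
    for record in read_records(stream_slice):  # placeholder reader
        cursor.observe(stream_slice, record)   # may flip to the global cursor mid-sync
    cursor.close_slice(stream_slice)

state = cursor.get_stream_state()  # always contains "use_global_cursor" and the global "state"
```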
airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
@@ -39,6 +39,7 @@
PerPartitionCursor,
ResumableFullRefreshCursor,
)
from airbyte_cdk.sources.declarative.incremental.per_partition_with_global import PerPartitionWithGlobalCursor
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
from airbyte_cdk.sources.declarative.migrations.legacy_to_per_partition_state_migration import LegacyToPerPartitionStateMigration
@@ -630,7 +631,7 @@ def create_declarative_stream(self, model: DeclarativeStreamModel, config: Confi
and hasattr(model.incremental_sync, "is_client_side_incremental")
and model.incremental_sync.is_client_side_incremental
):
supported_slicers = (DatetimeBasedCursor, GlobalSubstreamCursor, PerPartitionCursor)
supported_slicers = (DatetimeBasedCursor, GlobalSubstreamCursor, PerPartitionCursor, PerPartitionWithGlobalCursor)
if combined_slicers and not isinstance(combined_slicers, supported_slicers):
raise ValueError("Unsupported Slicer is used. PerPartitionCursor should be used here instead")
client_side_incremental_sync = {
@@ -703,12 +704,20 @@ def _merge_stream_slicers(self, model: DeclarativeStreamModel, config: Config) -
cursor_component = self._create_component_from_model(model=incremental_sync_model, config=config)
return GlobalSubstreamCursor(stream_cursor=cursor_component, partition_router=stream_slicer)
else:
return PerPartitionCursor(
cursor_component = self._create_component_from_model(model=incremental_sync_model, config=config)
return PerPartitionWithGlobalCursor(
cursor_factory=CursorFactory(
lambda: self._create_component_from_model(model=incremental_sync_model, config=config),
),
partition_router=stream_slicer,
stream_cursor=cursor_component,
)
# return PerPartitionCursor(
# cursor_factory=CursorFactory(
# lambda: self._create_component_from_model(model=incremental_sync_model, config=config),
# ),
# partition_router=stream_slicer,
# )
elif model.incremental_sync:
return self._create_component_from_model(model=model.incremental_sync, config=config) if model.incremental_sync else None
elif stream_slicer:
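A condensed restatement of the branch this hunk rewires — a sketch only, assuming a `create` helper standing in for `self._create_component_from_model` and a `global_cursor_requested` flag standing in for the model's global-substream-cursor option; the surrounding cases are omitted.

```python
# Sketch of the new slicer selection for partitioned incremental streams
# (not the verbatim factory code).
if incremental_sync_model and stream_slicer:
    if global_cursor_requested:  # stand-in for the global_substream_cursor flag
        return GlobalSubstreamCursor(
            stream_cursor=create(incremental_sync_model),
            partition_router=stream_slicer,
        )
    # New default: per-partition state that can degrade gracefully
    # to a single global cursor when there are too many partitions.
    return PerPartitionWithGlobalCursor(
        cursor_factory=CursorFactory(lambda: create(incremental_sync_model)),
        partition_router=stream_slicer,
        stream_cursor=create(incremental_sync_model),
    )
```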
