diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/declarative_stream.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/declarative_stream.py
index d30d833c8d5b..09ce080c8ae4 100644
--- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/declarative_stream.py
+++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/declarative_stream.py
@@ -6,7 +6,7 @@
 from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union
 
 from airbyte_cdk.models import SyncMode
-from airbyte_cdk.sources.declarative.incremental import GlobalSubstreamCursor, PerPartitionCursor
+from airbyte_cdk.sources.declarative.incremental import GlobalSubstreamCursor, PerPartitionCursor, PerPartitionWithGlobalCursor
 from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
 from airbyte_cdk.sources.declarative.migrations.state_migration import StateMigration
 from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever
@@ -200,7 +200,7 @@ def _get_checkpoint_reader(
         cursor = self.get_cursor()
         checkpoint_mode = self._checkpoint_mode
 
-        if isinstance(cursor, (GlobalSubstreamCursor, PerPartitionCursor)):
+        if isinstance(cursor, (GlobalSubstreamCursor, PerPartitionCursor, PerPartitionWithGlobalCursor)):
            self.has_multiple_slices = True
            return CursorBasedCheckpointReader(
                stream_slices=mappings_or_slices,
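The hunk above widens an isinstance dispatch: a cursor type missing from the tuple silently falls through to a different checkpoint reader. A minimal sketch of that failure mode, with hypothetical class names standing in for the CDK types:

```python
class GlobalCursor: ...
class PerPartitionCursor: ...
class CombinedCursor: ...  # plays the role of PerPartitionWithGlobalCursor


def pick_reader(cursor: object) -> str:
    # Before the fix, the combined cursor was absent from this tuple, so a
    # stream using it never took the multi-slice, cursor-based checkpoint path.
    if isinstance(cursor, (GlobalCursor, PerPartitionCursor, CombinedCursor)):
        return "cursor-based, multiple slices"
    return "default"


assert pick_reader(CombinedCursor()) == "cursor-based, multiple slices"
```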
diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py
index e0e7654fe476..f0b8da5bac8c 100644
--- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py
+++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py
@@ -202,10 +202,10 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
             self._lookback_window = self._timer.finish()
         self._stream_cursor.close_slice(StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), *args)
 
-    def get_stream_state(self) -> StreamState:
+    def get_stream_state(self, partition: Optional[StreamSlice] = None) -> StreamState:
         state: dict[str, Any] = {"state": self._stream_cursor.get_stream_state()}
 
-        parent_state = self._partition_router.get_stream_state()
+        parent_state = self._partition_router.get_stream_state(partition=partition)
         if parent_state:
             state["parent_state"] = parent_state
 
diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py
index 86236ec92230..a01ac39a381e 100644
--- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py
+++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py
@@ -55,6 +55,7 @@ def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRou
         self._cursor_per_partition: OrderedDict[str, DeclarativeCursor] = OrderedDict()
         self._over_limit = 0
         self._partition_serializer = PerPartitionKeySerializer()
+        self._current_partition = None
 
     def stream_slices(self) -> Iterable[StreamSlice]:
         slices = self._partition_router.stream_slices()
@@ -153,7 +154,7 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
                 f"we should only update state for partitions that were emitted during `stream_slices`"
             )
 
-    def get_stream_state(self) -> StreamState:
+    def get_stream_state(self, partition: Optional[Mapping[str, Any]] = None) -> StreamState:
         states = []
         for partition_tuple, cursor in self._cursor_per_partition.items():
             cursor_state = cursor.get_stream_state()
@@ -166,7 +167,7 @@ def get_stream_state(self) -> StreamState:
             )
 
         state: dict[str, Any] = {"states": states}
-        parent_state = self._partition_router.get_stream_state()
+        parent_state = self._partition_router.get_stream_state(partition=partition)
         if parent_state:
             state["parent_state"] = parent_state
         return state
diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py
index bc169263735f..e0014d941d1f 100644
--- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py
+++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py
@@ -1,7 +1,6 @@
 #
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
-
 from typing import Any, Iterable, Mapping, Optional, Union
 
 from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor
@@ -66,6 +65,7 @@ def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRou
         self._per_partition_cursor = PerPartitionCursor(cursor_factory, partition_router)
         self._global_cursor = GlobalSubstreamCursor(stream_cursor, partition_router)
         self._use_global_cursor = False
+        self._current_partition = None
 
     def _get_active_cursor(self) -> Union[PerPartitionCursor, GlobalSubstreamCursor]:
         return self._global_cursor if self._use_global_cursor else self._per_partition_cursor
@@ -76,11 +76,13 @@ def stream_slices(self) -> Iterable[StreamSlice]:
         # Iterate through partitions and process slices
         for partition, is_last_partition in iterate_with_last_flag(self._partition_router.stream_slices()):
             # Generate slices for the current cursor and handle the last slice using the flag
+            self._current_partition = partition.partition
             for slice, is_last_slice in iterate_with_last_flag(
                 self._get_active_cursor().generate_slices_from_partition(partition=partition)
             ):
                 self._global_cursor.register_slice(is_last_slice and is_last_partition)
                 yield slice
+        self._current_partition = None
 
     def set_initial_state(self, stream_state: StreamState) -> None:
         """
@@ -107,9 +109,9 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
 
     def get_stream_state(self) -> StreamState:
         final_state = {"use_global_cursor": self._use_global_cursor}
-        final_state.update(self._global_cursor.get_stream_state())
+        final_state.update(self._global_cursor.get_stream_state(partition=self._current_partition))
         if not self._use_global_cursor:
-            final_state.update(self._per_partition_cursor.get_stream_state())
+            final_state.update(self._per_partition_cursor.get_stream_state(partition=self._current_partition))
 
         return final_state
 
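The `stream_slices` change above tracks which partition is in flight so that a state request issued mid-iteration can be answered with that partition's parent-state snapshot rather than the newest one. A self-contained sketch of the pattern (illustrative names, not the CDK API):

```python
from typing import Any, Dict, Iterable, Optional


class SliceTracker:
    """Remembers the partition whose slices are currently being yielded."""

    def __init__(self) -> None:
        self.current_partition: Optional[Dict[str, Any]] = None

    def stream_slices(self, partitions: Iterable[Dict[str, Any]]) -> Iterable[Dict[str, Any]]:
        for partition in partitions:
            # A checkpoint emitted while this partition is in flight should be
            # answered with this partition's snapshot.
            self.current_partition = partition
            yield {"partition": partition, "cursor_slice": {}}
        # Iteration finished: fall back to the latest (global) parent state.
        self.current_partition = None


tracker = SliceTracker()
for s in tracker.stream_slices([{"id": 1}, {"id": 2}]):
    assert tracker.current_partition == s["partition"]
assert tracker.current_partition is None
```

Resetting the marker to `None` after the loop is what lets the final state message report the fully advanced parent state.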
diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py
index 9eac1f6bb66e..815aa7aff215 100644
--- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py
+++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py
@@ -3,6 +3,7 @@
 #
 
 import copy
 import logging
+from collections import OrderedDict
 from dataclasses import InitVar, dataclass
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Union
@@ -12,6 +13,7 @@
 from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
 from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
 from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType
+from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import PerPartitionKeySerializer
 from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState
 from airbyte_cdk.utils import AirbyteTracedException
 
@@ -70,6 +72,8 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
             raise ValueError("SubstreamPartitionRouter needs at least 1 parent stream")
         self._parameters = parameters
         self._parent_state: Dict[str, Any] = {}
+        self._partition_serializer = PerPartitionKeySerializer()
+        self._parent_state_to_partition: Dict[str, Dict[str, Any]] = OrderedDict()
 
     def get_request_params(
         self,
@@ -173,18 +177,26 @@ def stream_slices(self) -> Iterable[StreamSlice]:
                 # Add extra fields
                 extracted_extra_fields = self._extract_extra_fields(parent_record, extra_fields)
 
-                yield StreamSlice(
+                stream_slice = StreamSlice(
                     partition={partition_field: partition_value, "parent_slice": parent_partition or {}},
                     cursor_slice={},
                     extra_fields=extracted_extra_fields,
                 )
 
                 if incremental_dependency:
-                    self._parent_state[parent_stream.name] = copy.deepcopy(parent_stream.state)
+                    partition_key = self._partition_serializer.to_partition_key(stream_slice.partition)
+                    self._parent_state_to_partition[partition_key] = copy.deepcopy(self._parent_state)
+
+                yield stream_slice
+
+                if incremental_dependency:
+                    parent_state = copy.deepcopy(parent_stream.state)
+                    self._parent_state[parent_stream.name] = parent_state
 
         # A final parent state update and yield of records is needed, so we don't skip records for the final parent slice
         if incremental_dependency:
-            self._parent_state[parent_stream.name] = copy.deepcopy(parent_stream.state)
+            parent_state = copy.deepcopy(parent_stream.state)
+            self._parent_state[parent_stream.name] = parent_state
 
     def _extract_extra_fields(
         self, parent_record: Mapping[str, Any] | AirbyteMessage, extra_fields: Optional[List[List[str]]] = None
@@ -244,7 +256,7 @@ def set_initial_state(self, stream_state: StreamState) -> None:
             parent_config.stream.state = parent_state.get(parent_config.stream.name, {})
             self._parent_state[parent_config.stream.name] = parent_config.stream.state
 
-    def get_stream_state(self) -> Optional[Mapping[str, StreamState]]:
+    def get_stream_state(self, partition: Optional[Mapping[str, Any]] = None) -> Optional[Mapping[str, StreamState]]:
         """
         Get the state of the parent streams.
 
@@ -261,7 +273,11 @@ def get_stream_state(self) -> Optional[Mapping[str, StreamState]]:
             }
         }
         """
-        return copy.deepcopy(self._parent_state)
+        if not partition:
+            return copy.deepcopy(self._parent_state)
+        else:
+            partition_key = self._partition_serializer.to_partition_key(partition)
+            return self._parent_state_to_partition.get(partition_key)
 
     @property
     def logger(self) -> logging.Logger:
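This router change carries the actual fix: before yielding each slice, the router snapshots the parent state observed so far, keyed by a deterministic serialization of the slice's partition, and `get_stream_state(partition=...)` returns that snapshot instead of the most recent parent state. A runnable sketch of the bookkeeping, using `json.dumps(sort_keys=True)` as a stand-in for `PerPartitionKeySerializer`:

```python
import copy
import json
from collections import OrderedDict
from typing import Any, Dict, Mapping, Optional


class ParentStateSnapshots:
    """Parent state as it was before each partition's slice was yielded."""

    def __init__(self) -> None:
        self._snapshots: "OrderedDict[str, Dict[str, Any]]" = OrderedDict()

    @staticmethod
    def _to_key(partition: Mapping[str, Any]) -> str:
        # Deterministic encoding: the same partition always maps to the same key.
        return json.dumps(dict(partition), sort_keys=True)

    def record(self, partition: Mapping[str, Any], parent_state: Mapping[str, Any]) -> None:
        self._snapshots[self._to_key(partition)] = copy.deepcopy(dict(parent_state))

    def get(self, partition: Mapping[str, Any]) -> Optional[Dict[str, Any]]:
        return self._snapshots.get(self._to_key(partition))


snapshots = ParentStateSnapshots()
snapshots.record({"post_id": 1}, {"posts": {"updated_at": "2024-01-05T00:00:00Z"}})
snapshots.record({"post_id": 2}, {"posts": {"updated_at": "2024-01-25T00:00:00Z"}})
# A checkpoint taken while post 1 was in flight must not advance the parent
# cursor past records the sync has not yet consumed.
assert snapshots.get({"post_id": 1}) == {"posts": {"updated_at": "2024-01-05T00:00:00Z"}}
```

Snapshotting the state before the slice (rather than after) is the invariant that makes an interrupted sync resumable without skipping parent records.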
diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/partition_routers/test_parent_state_stream.py b/airbyte-cdk/python/unit_tests/sources/declarative/partition_routers/test_parent_state_stream.py
index 488ab737db4e..b8809bbd93f2 100644
--- a/airbyte-cdk/python/unit_tests/sources/declarative/partition_routers/test_parent_state_stream.py
+++ b/airbyte-cdk/python/unit_tests/sources/declarative/partition_routers/test_parent_state_stream.py
@@ -350,6 +350,13 @@ def _run_read(
             "votes": [{"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"}],
         },
     ),
+    # Fetch votes for comment 12 of post 1
+    (
+        "https://api.example.com/community/posts/1/comments/12/votes?per_page=100&start_time=2024-01-15T00:00:00Z",
+        {
+            "votes": [],
+        },
+    ),
     # Fetch votes for comment 20 of post 2
     (
         "https://api.example.com/community/posts/2/comments/20/votes?per_page=100&start_time=2024-01-12T00:00:00Z",
@@ -453,10 +460,44 @@ def test_incremental_parent_state(test_name, manifest, mock_requests, expected_r
         for url, response in mock_requests:
             m.get(url, json=response)
 
+        # Run the initial read
         output = _run_read(manifest, config, _stream_name, initial_state)
         output_data = [message.record.data for message in output if message.record]
 
+        # Assert that output_data equals expected_records
         assert output_data == expected_records
+
+        # Collect the intermediate states and records produced before each state
+        cumulative_records = []
+        intermediate_states = []
+        for message in output:
+            if message.type.value == "RECORD":
+                record_data = message.record.data
+                cumulative_records.append(record_data)
+            elif message.type.value == "STATE":
+                # Record the state and the records produced before this state
+                state = message.state
+                records_before_state = cumulative_records.copy()
+                intermediate_states.append((state, records_before_state))
+
+        # For each intermediate state, perform another read starting from that state
+        for state, records_before_state in intermediate_states[:-1]:
+            output_intermediate = _run_read(manifest, config, _stream_name, [state])
+            records_from_state = [message.record.data for message in output_intermediate if message.record]
+
+            # Combine records produced before the state with records from the new read
+            cumulative_records_state = records_before_state + records_from_state
+
+            # Duplicates may occur because the state matches the cursor of the last record, causing it to be re-emitted in the next sync.
+            cumulative_records_state_deduped = list({orjson.dumps(record): record for record in cumulative_records_state}.values())
+
+            # Compare the cumulative records with the expected records
+            expected_records_set = list({orjson.dumps(record): record for record in expected_records}.values())
+            assert sorted(cumulative_records_state_deduped, key=lambda x: orjson.dumps(x)) == sorted(
+                expected_records_set, key=lambda x: orjson.dumps(x)
+            ), f"Records mismatch with intermediate state {state}. Expected {expected_records}, got {cumulative_records_state_deduped}"
+
+        # Assert that the final state matches the expected state
         final_state = [orjson.loads(orjson.dumps(message.state.stream.stream_state)) for message in output if message.state]
         assert final_state[-1] == expected_state
 
@@ -690,14 +731,14 @@ def test_incremental_parent_state_no_slices(test_name, manifest, mock_requests,
     (
         "https://api.example.com/community/posts/2/comments?per_page=100",
         {
-            "comments": [],
+            "comments": [{"id": 20, "post_id": 2, "updated_at": "2024-01-22T00:00:00Z"}],
             "next_page": "https://api.example.com/community/posts/2/comments?per_page=100&page=2",
         },
     ),
     # Fetch the second page of comments for post 2
     (
         "https://api.example.com/community/posts/2/comments?per_page=100&page=2",
-        {"comments": []},
+        {"comments": [{"id": 21, "post_id": 2, "updated_at": "2024-01-21T00:00:00Z"}]},
     ),
     # Fetch the first page of votes for comment 20 of post 2
     (
         "https://api.example.com/community/posts/2/comments/20/votes?per_page=100&start_time=2024-01-12T00:00:00Z",
@@ -712,7 +753,7 @@ def test_incremental_parent_state_no_slices(test_name, manifest, mock_requests,
     # Fetch the first page of comments for post 3
     (
         "https://api.example.com/community/posts/3/comments?per_page=100",
-        {"comments": []},
+        {"comments": [{"id": 30, "post_id": 3, "updated_at": "2024-01-09T00:00:00Z"}]},
     ),
     # Fetch the first page of votes for comment 30 of post 3
     (
@@ -789,44 +830,10 @@ def test_incremental_parent_state_no_records(test_name, manifest, mock_requests,
         for url, response in mock_requests:
             m.get(url, json=response)
 
-        # Run the initial read
         output = _run_read(manifest, config, _stream_name, initial_state)
         output_data = [message.record.data for message in output if message.record]
 
-        # Assert that output_data equals expected_records
         assert output_data == expected_records
-
-        # Collect the intermediate states and records produced before each state
-        cumulative_records = []
-        intermediate_states = []
-        for message in output:
-            if message.type.value == "RECORD":
-                record_data = message.record.data
-                cumulative_records.append(record_data)
-            elif message.type.value == "STATE":
-                # Record the state and the records produced before this state
-                state = message.state
-                records_before_state = cumulative_records.copy()
-                intermediate_states.append((state, records_before_state))
-
-        # For each intermediate state, perform another read starting from that state
-        for state, records_before_state in intermediate_states[:-1]:
-            output_intermediate = _run_read(manifest, config, _stream_name, [state])
-            records_from_state = [message.record.data for message in output_intermediate if message.record]
-
-            # Combine records produced before the state with records from the new read
-            cumulative_records_state = records_before_state + records_from_state
-
-            # Duplicates may occur because the state matches the cursor of the last record, causing it to be re-emitted in the next sync.
-            cumulative_records_state_deduped = list({orjson.dumps(record): record for record in cumulative_records_state}.values())
-
-            # Compare the cumulative records with the expected records
-            expected_records_set = list({orjson.dumps(record): record for record in expected_records}.values())
-            assert sorted(cumulative_records_state_deduped, key=lambda x: orjson.dumps(x)) == sorted(
-                expected_records_set, key=lambda x: orjson.dumps(x)
-            ), f"Records mismatch with intermediate state {state}. Expected {expected_records}, got {cumulative_records_state_deduped}"
-
-        # Assert that the final state matches the expected state
         final_state = [orjson.loads(orjson.dumps(message.state.stream.stream_state)) for message in output if message.state]
         assert final_state[-1] == expected_state
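The intermediate-state assertions deleted from `test_incremental_parent_state_no_records` here were moved into `test_incremental_parent_state` above. They depend on deduplicating records by their serialized form, because resuming from a checkpoint legitimately re-emits the record whose cursor value equals the checkpoint. The dedup idiom in isolation (same `orjson` trick as the test):

```python
import orjson

records = [{"id": 1}, {"id": 1}, {"id": 2}]
# Key each record by its serialized bytes; later duplicates overwrite earlier
# ones and dict values() preserve insertion order.
deduped = list({orjson.dumps(r): r for r in records}.values())
assert deduped == [{"id": 1}, {"id": 2}]
```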