Align with new SubstreamPartitionRouter
tolik0 committed Oct 22, 2024
1 parent c998f5c commit 58c96f2
Showing 6 changed files with 77 additions and 51 deletions.
@@ -6,7 +6,7 @@
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union

from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.declarative.incremental import GlobalSubstreamCursor, PerPartitionCursor
from airbyte_cdk.sources.declarative.incremental import GlobalSubstreamCursor, PerPartitionCursor, PerPartitionWithGlobalCursor
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
from airbyte_cdk.sources.declarative.migrations.state_migration import StateMigration
from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever
@@ -200,7 +200,7 @@ def _get_checkpoint_reader(
cursor = self.get_cursor()
checkpoint_mode = self._checkpoint_mode

if isinstance(cursor, (GlobalSubstreamCursor, PerPartitionCursor)):
if isinstance(cursor, (GlobalSubstreamCursor, PerPartitionCursor, PerPartitionWithGlobalCursor)):
self.has_multiple_slices = True
return CursorBasedCheckpointReader(
stream_slices=mappings_or_slices,
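For context, the widened `isinstance` check above is what marks the stream as multi-slice and routes checkpointing through `CursorBasedCheckpointReader` for all three substream-aware cursors, including the new `PerPartitionWithGlobalCursor`. A minimal sketch of that dispatch, using only names visible in this hunk (the reader construction itself is abridged in the diff and omitted here):

```python
from airbyte_cdk.sources.declarative.incremental import (
    GlobalSubstreamCursor,
    PerPartitionCursor,
    PerPartitionWithGlobalCursor,
)


def uses_cursor_based_checkpointing(cursor: object) -> bool:
    # Each of these cursors emits one slice per partition (and per cursor window),
    # so checkpoints have to follow the cursor rather than simple page offsets.
    return isinstance(
        cursor, (GlobalSubstreamCursor, PerPartitionCursor, PerPartitionWithGlobalCursor)
    )
```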
@@ -202,10 +202,10 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
self._lookback_window = self._timer.finish()
self._stream_cursor.close_slice(StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), *args)

def get_stream_state(self) -> StreamState:
def get_stream_state(self, partition: Optional[StreamSlice] = None) -> StreamState:
state: dict[str, Any] = {"state": self._stream_cursor.get_stream_state()}

parent_state = self._partition_router.get_stream_state()
parent_state = self._partition_router.get_stream_state(partition=partition)
if parent_state:
state["parent_state"] = parent_state

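The optional `partition` argument added above is threaded straight through to the partition router, so the `parent_state` entry in the returned mapping can now reflect the parent state snapshotted for that specific partition rather than only the latest overall parent state. For orientation, an illustrative shape of the returned state (the field names come from the code above; the stream name and timestamps are invented):

```python
# Illustrative only: rough shape of GlobalSubstreamCursor.get_stream_state(partition=...).
example_state = {
    # the single global datetime cursor shared across partitions
    "state": {"created_at": "2024-01-15T00:00:00Z"},
    # parent stream state snapshot associated with the requested partition
    "parent_state": {"comments": {"updated_at": "2024-01-14T00:00:00Z"}},
}
```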
@@ -55,6 +55,7 @@ def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRou
self._cursor_per_partition: OrderedDict[str, DeclarativeCursor] = OrderedDict()
self._over_limit = 0
self._partition_serializer = PerPartitionKeySerializer()
self._current_partition = None

def stream_slices(self) -> Iterable[StreamSlice]:
slices = self._partition_router.stream_slices()
@@ -153,7 +154,7 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
f"we should only update state for partitions that were emitted during `stream_slices`"
)

def get_stream_state(self) -> StreamState:
def get_stream_state(self, partition: Optional[Mapping[str, Any]] = None) -> StreamState:
states = []
for partition_tuple, cursor in self._cursor_per_partition.items():
cursor_state = cursor.get_stream_state()
@@ -166,7 +167,7 @@ def get_stream_state(self) -> StreamState:
)
state: dict[str, Any] = {"states": states}

parent_state = self._partition_router.get_stream_state()
parent_state = self._partition_router.get_stream_state(partition=partition)
if parent_state:
state["parent_state"] = parent_state
return state
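Both this cursor and the `SubstreamPartitionRouter` change further down rely on `PerPartitionKeySerializer` to key state by partition. The essential contract is that equal partition mappings serialize to the same key regardless of key order; a minimal stand-in (not the CDK implementation) might look like this:

```python
import json
from typing import Any, Mapping


class PartitionKeySerializerSketch:
    """Illustrative stand-in for PerPartitionKeySerializer: equal partitions yield identical keys."""

    def to_partition_key(self, partition: Mapping[str, Any]) -> str:
        # Sorting keys makes the key independent of dict insertion order.
        return json.dumps(partition, sort_keys=True)

    def to_partition(self, key: str) -> Mapping[str, Any]:
        return json.loads(key)


serializer = PartitionKeySerializerSketch()
assert serializer.to_partition_key({"id": 1, "parent_slice": {}}) == serializer.to_partition_key(
    {"parent_slice": {}, "id": 1}
)
```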
@@ -1,7 +1,6 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from typing import Any, Iterable, Mapping, Optional, Union

from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor
@@ -66,6 +65,7 @@ def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRou
self._per_partition_cursor = PerPartitionCursor(cursor_factory, partition_router)
self._global_cursor = GlobalSubstreamCursor(stream_cursor, partition_router)
self._use_global_cursor = False
self._current_partition = None

def _get_active_cursor(self) -> Union[PerPartitionCursor, GlobalSubstreamCursor]:
return self._global_cursor if self._use_global_cursor else self._per_partition_cursor
Expand All @@ -76,11 +76,13 @@ def stream_slices(self) -> Iterable[StreamSlice]:
# Iterate through partitions and process slices
for partition, is_last_partition in iterate_with_last_flag(self._partition_router.stream_slices()):
# Generate slices for the current cursor and handle the last slice using the flag
self._current_partition = partition.partition
for slice, is_last_slice in iterate_with_last_flag(
self._get_active_cursor().generate_slices_from_partition(partition=partition)
):
self._global_cursor.register_slice(is_last_slice and is_last_partition)
yield slice
self._current_partition = None

def set_initial_state(self, stream_state: StreamState) -> None:
"""
@@ -107,9 +109,9 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
def get_stream_state(self) -> StreamState:
final_state = {"use_global_cursor": self._use_global_cursor}

final_state.update(self._global_cursor.get_stream_state())
final_state.update(self._global_cursor.get_stream_state(partition=self._current_partition))
if not self._use_global_cursor:
final_state.update(self._per_partition_cursor.get_stream_state())
final_state.update(self._per_partition_cursor.get_stream_state(partition=self._current_partition))

return final_state

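`iterate_with_last_flag` is a small helper used in the loop above to detect the final slice of the final partition (so the global cursor knows when to close); its implementation is not part of this diff, but a functionally equivalent sketch is shown below. The `_current_partition` bookkeeping wrapped around the inner loop is what later lets `get_stream_state` ask the router for the parent-state snapshot matching the partition currently being processed.

```python
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def iterate_with_last_flag(generator: Iterable[T]) -> Iterator[Tuple[T, bool]]:
    """Yield (item, is_last) pairs; is_last is True only for the final item (sketch, not the CDK code)."""
    iterator = iter(generator)
    try:
        current = next(iterator)
    except StopIteration:
        return
    for upcoming in iterator:
        yield current, False
        current = upcoming
    yield current, True


# Example: only the last element is flagged.
assert list(iterate_with_last_flag(["a", "b", "c"])) == [("a", False), ("b", False), ("c", True)]
```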
@@ -3,6 +3,7 @@
#
import copy
import logging
from collections import OrderedDict
from dataclasses import InitVar, dataclass
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Union

@@ -12,6 +13,7 @@
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType
from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import PerPartitionKeySerializer
from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState
from airbyte_cdk.utils import AirbyteTracedException

@@ -70,6 +72,8 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
raise ValueError("SubstreamPartitionRouter needs at least 1 parent stream")
self._parameters = parameters
self._parent_state: Dict[str, Any] = {}
self._partition_serializer = PerPartitionKeySerializer()
self._parent_state_to_partition: Dict[str, Dict[str, Any]] = OrderedDict()

def get_request_params(
self,
@@ -173,18 +177,26 @@ def stream_slices(self) -> Iterable[StreamSlice]:
# Add extra fields
extracted_extra_fields = self._extract_extra_fields(parent_record, extra_fields)

yield StreamSlice(
stream_slice = StreamSlice(
partition={partition_field: partition_value, "parent_slice": parent_partition or {}},
cursor_slice={},
extra_fields=extracted_extra_fields,
)

if incremental_dependency:
self._parent_state[parent_stream.name] = copy.deepcopy(parent_stream.state)
partition_key = self._partition_serializer.to_partition_key(stream_slice.partition)
self._parent_state_to_partition[partition_key] = copy.deepcopy(self._parent_state)

yield stream_slice

if incremental_dependency:
parent_state = copy.deepcopy(parent_stream.state)
self._parent_state[parent_stream.name] = parent_state

# A final parent state update and yield of records is needed, so we don't skip records for the final parent slice
if incremental_dependency:
self._parent_state[parent_stream.name] = copy.deepcopy(parent_stream.state)
parent_state = copy.deepcopy(parent_stream.state)
self._parent_state[parent_stream.name] = parent_state

def _extract_extra_fields(
self, parent_record: Mapping[str, Any] | AirbyteMessage, extra_fields: Optional[List[List[str]]] = None
@@ -244,7 +256,7 @@ def set_initial_state(self, stream_state: StreamState) -> None:
parent_config.stream.state = parent_state.get(parent_config.stream.name, {})
self._parent_state[parent_config.stream.name] = parent_config.stream.state

def get_stream_state(self) -> Optional[Mapping[str, StreamState]]:
def get_stream_state(self, partition: Optional[Mapping[str, Any]] = None) -> Optional[Mapping[str, StreamState]]:
"""
Get the state of the parent streams.
@@ -261,7 +273,11 @@ def get_stream_state(self) -> Optional[Mapping[str, StreamState]]:
}
}
"""
return copy.deepcopy(self._parent_state)
if not partition:
return copy.deepcopy(self._parent_state)
else:
partition_key = self._partition_serializer.to_partition_key(partition)
return self._parent_state_to_partition.get(partition_key)

@property
def logger(self) -> logging.Logger:
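Putting the router change together: when incremental dependency is enabled, each emitted slice gets the current parent state deep-copied and stored under that slice's partition key, and `get_stream_state(partition=...)` later returns exactly that snapshot (falling back to the latest overall parent state when no partition is given). A self-contained sketch of the pattern, with the serializer, parent streams, and state plumbing reduced to simple stand-ins:

```python
import copy
import json
from collections import OrderedDict
from typing import Any, Dict, Iterable, Mapping, Optional


def to_partition_key(partition: Mapping[str, Any]) -> str:
    # Stand-in for PerPartitionKeySerializer: stable key for equal partitions.
    return json.dumps(partition, sort_keys=True)


class SnapshottingRouterSketch:
    """Illustrative reduction of the SubstreamPartitionRouter change, not the CDK class."""

    def __init__(self, parent_records: Iterable[Mapping[str, Any]]) -> None:
        self._parent_records = list(parent_records)
        self._parent_state: Dict[str, Any] = {}
        self._parent_state_to_partition: Dict[str, Dict[str, Any]] = OrderedDict()

    def stream_slices(self) -> Iterable[Mapping[str, Any]]:
        for record in self._parent_records:
            partition = {"id": record["id"], "parent_slice": {}}
            # Snapshot the parent state associated with this partition before emitting it.
            self._parent_state_to_partition[to_partition_key(partition)] = copy.deepcopy(self._parent_state)
            yield partition
            # In this sketch the parent cursor advances after the slice has been emitted.
            self._parent_state = {"parent": {"updated_at": record["updated_at"]}}

    def get_stream_state(self, partition: Optional[Mapping[str, Any]] = None) -> Optional[Mapping[str, Any]]:
        if not partition:
            return copy.deepcopy(self._parent_state)
        return self._parent_state_to_partition.get(to_partition_key(partition))


router = SnapshottingRouterSketch(
    [{"id": 10, "updated_at": "2024-01-01"}, {"id": 11, "updated_at": "2024-01-02"}]
)
slices = list(router.stream_slices())
# Checkpointing against the first partition must not claim the parent already advanced past it.
assert router.get_stream_state(partition=slices[0]) == {}
assert router.get_stream_state() == {"parent": {"updated_at": "2024-01-02"}}
```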
@@ -350,6 +350,13 @@ def _run_read(
"votes": [{"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"}],
},
),
# Fetch votes for comment 12 of post 1
(
"https://api.example.com/community/posts/1/comments/12/votes?per_page=100&start_time=2024-01-15T00:00:00Z",
{
"votes": [],
},
),
# Fetch votes for comment 20 of post 2
(
"https://api.example.com/community/posts/2/comments/20/votes?per_page=100&start_time=2024-01-12T00:00:00Z",
@@ -453,10 +460,44 @@ def test_incremental_parent_state(test_name, manifest, mock_requests, expected_r
for url, response in mock_requests:
m.get(url, json=response)

# Run the initial read
output = _run_read(manifest, config, _stream_name, initial_state)
output_data = [message.record.data for message in output if message.record]

# Assert that output_data equals expected_records
assert output_data == expected_records

# Collect the intermediate states and records produced before each state
cumulative_records = []
intermediate_states = []
for message in output:
if message.type.value == "RECORD":
record_data = message.record.data
cumulative_records.append(record_data)
elif message.type.value == "STATE":
# Record the state and the records produced before this state
state = message.state
records_before_state = cumulative_records.copy()
intermediate_states.append((state, records_before_state))

# For each intermediate state, perform another read starting from that state
for state, records_before_state in intermediate_states[:-1]:
output_intermediate = _run_read(manifest, config, _stream_name, [state])
records_from_state = [message.record.data for message in output_intermediate if message.record]

# Combine records produced before the state with records from the new read
cumulative_records_state = records_before_state + records_from_state

# Duplicates may occur because the state matches the cursor of the last record, causing it to be re-emitted in the next sync.
cumulative_records_state_deduped = list({orjson.dumps(record): record for record in cumulative_records_state}.values())

# Compare the cumulative records with the expected records
expected_records_set = list({orjson.dumps(record): record for record in expected_records}.values())
assert sorted(cumulative_records_state_deduped, key=lambda x: orjson.dumps(x)) == sorted(
expected_records_set, key=lambda x: orjson.dumps(x)
), f"Records mismatch with intermediate state {state}. Expected {expected_records}, got {cumulative_records_state_deduped}"

# Assert that the final state matches the expected state
final_state = [orjson.loads(orjson.dumps(message.state.stream.stream_state)) for message in output if message.state]
assert final_state[-1] == expected_state
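The dedupe step above keys each record by its serialized form, so a record re-emitted on resume (because the checkpointed cursor equals its cursor value) collapses to a single entry. Isolated as a tiny helper for clarity:

```python
from typing import Any, Dict, List

import orjson


def dedupe_records(records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Duplicates collapse onto one key derived from the serialized record,
    # while dict insertion order preserves the first occurrence's position.
    return list({orjson.dumps(record): record for record in records}.values())


# Example: the boundary record appears in both the pre-checkpoint and resumed reads.
assert dedupe_records([{"id": 100}, {"id": 101}, {"id": 101}, {"id": 102}]) == [
    {"id": 100},
    {"id": 101},
    {"id": 102},
]
```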

@@ -690,14 +731,14 @@ def test_incremental_parent_state_no_slices(test_name, manifest, mock_requests,
(
"https://api.example.com/community/posts/2/comments?per_page=100",
{
"comments": [],
"comments": [{"id": 20, "post_id": 2, "updated_at": "2024-01-22T00:00:00Z"}],
"next_page": "https://api.example.com/community/posts/2/comments?per_page=100&page=2",
},
),
# Fetch the second page of comments for post 2
(
"https://api.example.com/community/posts/2/comments?per_page=100&page=2",
{"comments": []},
{"comments": [{"id": 21, "post_id": 2, "updated_at": "2024-01-21T00:00:00Z"}]},
),
# Fetch the first page of votes for comment 20 of post 2
(
@@ -712,7 +753,7 @@ def test_incremental_parent_state_no_slices(test_name, manifest, mock_requests,
# Fetch the first page of comments for post 3
(
"https://api.example.com/community/posts/3/comments?per_page=100",
{"comments": []},
{"comments": [{"id": 30, "post_id": 3, "updated_at": "2024-01-09T00:00:00Z"}]},
),
# Fetch the first page of votes for comment 30 of post 3
(
@@ -789,44 +830,10 @@ def test_incremental_parent_state_no_records(test_name, manifest, mock_requests,
for url, response in mock_requests:
m.get(url, json=response)

# Run the initial read
output = _run_read(manifest, config, _stream_name, initial_state)
output_data = [message.record.data for message in output if message.record]

# Assert that output_data equals expected_records
assert output_data == expected_records

# Collect the intermediate states and records produced before each state
cumulative_records = []
intermediate_states = []
for message in output:
if message.type.value == "RECORD":
record_data = message.record.data
cumulative_records.append(record_data)
elif message.type.value == "STATE":
# Record the state and the records produced before this state
state = message.state
records_before_state = cumulative_records.copy()
intermediate_states.append((state, records_before_state))

# For each intermediate state, perform another read starting from that state
for state, records_before_state in intermediate_states[:-1]:
output_intermediate = _run_read(manifest, config, _stream_name, [state])
records_from_state = [message.record.data for message in output_intermediate if message.record]

# Combine records produced before the state with records from the new read
cumulative_records_state = records_before_state + records_from_state

# Duplicates may occur because the state matches the cursor of the last record, causing it to be re-emitted in the next sync.
cumulative_records_state_deduped = list({orjson.dumps(record): record for record in cumulative_records_state}.values())

# Compare the cumulative records with the expected records
expected_records_set = list({orjson.dumps(record): record for record in expected_records}.values())
assert sorted(cumulative_records_state_deduped, key=lambda x: orjson.dumps(x)) == sorted(
expected_records_set, key=lambda x: orjson.dumps(x)
), f"Records mismatch with intermediate state {state}. Expected {expected_records}, got {cumulative_records_state_deduped}"

# Assert that the final state matches the expected state
final_state = [orjson.loads(orjson.dumps(message.state.stream.stream_state)) for message in output if message.state]
assert final_state[-1] == expected_state
