Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add open/closed range arguments for incremental #1991

Draft
wants to merge 1 commit into
base: devel
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions dlt/extract/incremental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
TCursorValue,
LastValueFunc,
OnCursorValueMissing,
TIncrementalRange,
)
from dlt.extract.pipe import Pipe
from dlt.extract.items import SupportsPipe, TTableHintTemplate, ItemTransform
Expand Down Expand Up @@ -111,6 +112,8 @@ class Incremental(ItemTransform[TDataItem], BaseConfiguration, Generic[TCursorVa
row_order: Optional[TSortOrder] = None
allow_external_schedulers: bool = False
on_cursor_value_missing: OnCursorValueMissing = "raise"
range_start: TIncrementalRange = "closed"
range_end: TIncrementalRange = "open"

# incremental acting as empty
EMPTY: ClassVar["Incremental[Any]"] = None
Expand All @@ -126,6 +129,8 @@ def __init__(
row_order: Optional[TSortOrder] = None,
allow_external_schedulers: bool = False,
on_cursor_value_missing: OnCursorValueMissing = "raise",
range_start: TIncrementalRange = "closed",
range_end: TIncrementalRange = "open",
) -> None:
# make sure that path is valid
if cursor_path:
Expand Down Expand Up @@ -159,6 +164,8 @@ def __init__(
self._transformers: Dict[str, IncrementalTransform] = {}
self._bound_pipe: SupportsPipe = None
"""Bound pipe"""
self.range_start = range_start
self.range_end = range_end

@property
def primary_key(self) -> Optional[TTableHintTemplate[TColumnNames]]:
Expand All @@ -185,6 +192,8 @@ def _make_transforms(self) -> None:
self._primary_key,
set(self._cached_state["unique_hashes"]),
self.on_cursor_value_missing,
self.range_start,
self.range_end,
)

@classmethod
Expand Down
35 changes: 26 additions & 9 deletions dlt/extract/incremental/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
IncrementalPrimaryKeyMissing,
IncrementalCursorPathHasValueNone,
)
from dlt.extract.incremental.typing import TCursorValue, LastValueFunc, OnCursorValueMissing
from dlt.extract.incremental.typing import (
TCursorValue,
LastValueFunc,
OnCursorValueMissing,
TIncrementalRange,
)
from dlt.extract.utils import resolve_column_value
from dlt.extract.items import TTableHintTemplate
from dlt.common.schema.typing import TColumnNames
Expand Down Expand Up @@ -57,6 +62,8 @@ def __init__(
primary_key: Optional[TTableHintTemplate[TColumnNames]],
unique_hashes: Set[str],
on_cursor_value_missing: OnCursorValueMissing = "raise",
range_start: TIncrementalRange = "closed",
range_end: TIncrementalRange = "open",
) -> None:
self.resource_name = resource_name
self.cursor_path = cursor_path
Expand All @@ -70,6 +77,8 @@ def __init__(
self.unique_hashes = unique_hashes
self.start_unique_hashes = set(unique_hashes)
self.on_cursor_value_missing = on_cursor_value_missing
self.range_start = range_start
self.range_end = range_end

# compile jsonpath
self._compiled_cursor_path = compile_path(cursor_path)
Expand Down Expand Up @@ -188,10 +197,10 @@ def __call__(
# Filter end value ranges exclusively, so in case of "max" function we remove values >= end_value
if self.end_value is not None:
try:
if (
last_value_func((row_value, self.end_value)) != self.end_value
or last_value_func((row_value,)) == self.end_value
):
if last_value_func((row_value, self.end_value)) != self.end_value:
return None, False, True

if self.range_end == "open" and last_value_func((row_value,)) == self.end_value:
return None, False, True
except Exception as ex:
raise IncrementalCursorInvalidCoercion(
Expand All @@ -218,6 +227,8 @@ def __call__(
) from ex
# new_value is "less" or equal to last_value (the actual max)
if last_value == new_value:
if self.range_start == "open":
return None, False, False
# use func to compute row_value into last_value compatible
processed_row_value = last_value_func((row_value,))
# skip the record that is not a start_value or new_value: that record was already processed
Expand Down Expand Up @@ -311,13 +322,19 @@ def __call__(

if self.last_value_func is max:
compute = pa.compute.max
end_compare = pa.compute.less
last_value_compare = pa.compute.greater_equal
end_compare = pa.compute.less if self.range_end == "open" else pa.compute.less_equal
last_value_compare = (
pa.compute.greater_equal if self.range_start == "closed" else pa.compute.greater
)
new_value_compare = pa.compute.greater
elif self.last_value_func is min:
compute = pa.compute.min
end_compare = pa.compute.greater
last_value_compare = pa.compute.less_equal
end_compare = (
pa.compute.greater if self.range_end == "open" else pa.compute.greater_equal
)
last_value_compare = (
pa.compute.less_equal if self.range_start == "closed" else pa.compute.less
)
new_value_compare = pa.compute.less
else:
raise NotImplementedError(
Expand Down
2 changes: 2 additions & 0 deletions dlt/extract/incremental/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
LastValueFunc = Callable[[Sequence[TCursorValue]], Any]
OnCursorValueMissing = Literal["raise", "include", "exclude"]

TIncrementalRange = Literal["open", "closed"]


class IncrementalColumnState(TypedDict):
initial_value: Optional[Any]
Expand Down
12 changes: 8 additions & 4 deletions dlt/sources/sql_database/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,16 @@ def __init__(
self.end_value = incremental.end_value
self.row_order: TSortOrder = self.incremental.row_order
self.on_cursor_value_missing = self.incremental.on_cursor_value_missing
self.range_start = self.incremental.range_start
self.range_end = self.incremental.range_end
else:
self.cursor_column = None
self.last_value = None
self.end_value = None
self.row_order = None
self.on_cursor_value_missing = None
self.range_start = None
self.range_end = None

def _make_query(self) -> SelectAny:
table = self.table
Expand All @@ -87,11 +91,11 @@ def _make_query(self) -> SelectAny:

# generate where
if last_value_func is max: # Query ordered and filtered according to last_value function
filter_op = operator.ge
filter_op_end = operator.lt
filter_op = operator.ge if self.range_start == "closed" else operator.gt
filter_op_end = operator.lt if self.range_end == "open" else operator.le
elif last_value_func is min:
filter_op = operator.le
filter_op_end = operator.gt
filter_op = operator.le if self.range_start == "closed" else operator.lt
filter_op_end = operator.gt if self.range_end == "open" else operator.ge
else: # Custom last_value, load everything and let incremental handle filtering
return query # type: ignore[no-any-return]

Expand Down
85 changes: 84 additions & 1 deletion tests/extract/test_incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datetime import datetime # noqa: I251
from itertools import chain, count
from time import sleep
from typing import Any, Optional
from typing import Any, Optional, Iterable
from unittest import mock

import duckdb
Expand Down Expand Up @@ -1462,6 +1462,7 @@ def some_data(last_timestamp=dlt.sources.incremental("ts", primary_key=())):

@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS)
def test_apply_hints_incremental(item_type: TestDataItemFormat) -> None:
os.environ["COMPLETED_PROB"] = "1.0" # make it complete immediately
p = dlt.pipeline(pipeline_name=uniq_id(), destination="dummy")
data = [{"created_at": 1}, {"created_at": 2}, {"created_at": 3}]
source_items = data_to_item_format(item_type, data)
Expand Down Expand Up @@ -2586,3 +2587,85 @@ def updated_is_int(updated_at=dlt.sources.incremental("updated_at", initial_valu
pipeline.run(updated_is_int())
assert isinstance(pip_ex.value.__cause__, IncrementalCursorInvalidCoercion)
assert pip_ex.value.__cause__.cursor_path == "updated_at"


@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS)
@pytest.mark.parametrize("last_value_func", [min, max])
def test_start_range_open(item_type: TestDataItemFormat, last_value_func: Any) -> None:
    """With ``range_start="open"`` the row equal to ``initial_value`` is excluded.

    For ``max`` only items strictly greater than the initial value are loaded;
    for ``min`` only items strictly lower. The data is fed in the direction the
    cursor advances so the incremental filter is actually exercised.
    """
    data_range: Iterable[int] = range(1, 12)
    initial_value = 5
    # NOTE: `is` (identity) is the idiomatic test for builtin function objects,
    # matching `last_value_func is max` used elsewhere in the library code.
    if last_value_func is max:
        # Only items higher than the initial value are extracted
        expected_items = list(range(6, 12))
        order_dir = "ASC"
    else:  # last_value_func is min
        data_range = reversed(data_range)
        # Only items lower than the initial value are extracted
        expected_items = list(reversed(range(1, 5)))
        order_dir = "DESC"

    @dlt.resource
    def some_data(
        updated_at: dlt.sources.incremental[int] = dlt.sources.incremental(
            "updated_at",
            initial_value=initial_value,
            range_start="open",
            last_value_func=last_value_func,
        ),
    ) -> Any:
        data = [{"updated_at": i} for i in data_range]
        yield data_to_item_format(item_type, data)

    pipeline = dlt.pipeline(pipeline_name=uniq_id(), destination="duckdb")
    pipeline.run(some_data())

    with pipeline.sql_client() as client:
        items = [
            row[0]
            for row in client.execute_sql(
                f"SELECT updated_at FROM some_data ORDER BY updated_at {order_dir}"
            )
        ]

    # The boundary value (5) must not appear: open start excludes it
    assert items == expected_items


@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS)
@pytest.mark.parametrize("last_value_func", [min, max])
def test_end_range_closed(item_type: TestDataItemFormat, last_value_func: Any) -> None:
    """With ``range_end="closed"`` the row equal to ``end_value`` is included.

    Loads the window [5, 10] inclusive at both ends, for both ``max`` (ascending)
    and ``min`` (descending, with initial/end values swapped).
    """
    values = [5, 10]
    expected_items = list(range(5, 11))
    # NOTE: `is` (identity) is the idiomatic test for builtin function objects,
    # matching `last_value_func is max` used elsewhere in the library code.
    if last_value_func is max:
        order_dir = "ASC"
    else:  # last_value_func is min
        values = list(reversed(values))
        expected_items = list(reversed(expected_items))
        order_dir = "DESC"

    @dlt.resource
    def some_data(
        updated_at: dlt.sources.incremental[int] = dlt.sources.incremental(
            "updated_at",
            initial_value=values[0],
            end_value=values[1],
            range_end="closed",
            last_value_func=last_value_func,
        ),
    ) -> Any:
        data = [{"updated_at": i} for i in range(1, 12)]
        yield data_to_item_format(item_type, data)

    pipeline = dlt.pipeline(pipeline_name=uniq_id(), destination="duckdb")
    pipeline.run(some_data())

    with pipeline.sql_client() as client:
        items = [
            row[0]
            for row in client.execute_sql(
                f"SELECT updated_at FROM some_data ORDER BY updated_at {order_dir}"
            )
        ]

    # Includes values 5-10 inclusive: closed end keeps the end_value row
    assert items == expected_items
Loading
Loading