From c004517b96d2cd909cf36fc18506f1d23c289401 Mon Sep 17 00:00:00 2001 From: "ZhengYu, Xu" Date: Sat, 21 Sep 2024 17:08:38 +0800 Subject: [PATCH] feat: Support `compute` functions to accept ChunkedArray. (#100) --- pyarrow-stubs/compute.pyi | 202 +++++++++++++++++++++++++++----------- 1 file changed, 143 insertions(+), 59 deletions(-) diff --git a/pyarrow-stubs/compute.pyi b/pyarrow-stubs/compute.pyi index 7ea1244..a3c2e0f 100644 --- a/pyarrow-stubs/compute.pyi +++ b/pyarrow-stubs/compute.pyi @@ -137,12 +137,17 @@ NumericOrDurationScalar: TypeAlias = NumericScalar | lib.DurationScalar _NumericOrDurationT = TypeVar("_NumericOrDurationT", bound=NumericOrDurationScalar) NumericOrTemporalScalar: TypeAlias = NumericScalar | TemporalScalar _NumericOrTemporalT = TypeVar("_NumericOrTemporalT", bound=NumericOrTemporalScalar) -NumericArray: TypeAlias = lib.NumericArray +NumericArray: TypeAlias = lib.NumericArray[_ScalarT] | lib.ChunkedArray[_ScalarT] _NumericArrayT = TypeVar("_NumericArrayT", bound=lib.NumericArray) -NumericOrDurationArray: TypeAlias = lib.NumericArray | lib.Array[lib.DurationScalar] +NumericOrDurationArray: TypeAlias = ( + lib.NumericArray | lib.Array[lib.DurationScalar] | lib.ChunkedArray +) _NumericOrDurationArrayT = TypeVar("_NumericOrDurationArrayT", bound=NumericOrDurationArray) -NumericOrTemporalArray: TypeAlias = lib.NumericArray | lib.Array[TemporalScalar] +NumericOrTemporalArray: TypeAlias = ( + lib.NumericArray | lib.Array[TemporalScalar] | lib.ChunkedArray[TemporalScalar] +) _NumericOrTemporalArrayT = TypeVar("_NumericOrTemporalArrayT", bound=NumericOrTemporalArray) +BooleanArray: TypeAlias = lib.BooleanArray | lib.ChunkedArray[lib.BooleanScalar] FloatScalar: typing_extensions.TypeAlias = ( lib.Scalar[lib.Float32Type] | lib.Scalar[lib.Float64Type] @@ -155,13 +160,27 @@ FloatArray: typing_extensions.TypeAlias = ( | lib.NumericArray[lib.DoubleScalar] | lib.NumericArray[lib.Decimal128Scalar] | lib.NumericArray[lib.Decimal256Scalar] + | lib.ChunkedArray[lib.FloatScalar] + | lib.ChunkedArray[lib.DoubleScalar] + | lib.ChunkedArray[lib.Decimal128Scalar] + | lib.ChunkedArray[lib.Decimal256Scalar] ) _FloatArrayT = TypeVar("_FloatArrayT", bound=FloatArray) _StringScalarT = TypeVar("_StringScalarT", bound=StringScalar) -StringArray: TypeAlias = lib.StringArray | lib.LargeStringArray +StringArray: TypeAlias = ( + lib.StringArray + | lib.LargeStringArray + | lib.ChunkedArray[lib.StringScalar] + | lib.ChunkedArray[lib.LargeStringScalar] +) _StringArrayT = TypeVar("_StringArrayT", bound=StringArray) _BinaryScalarT = TypeVar("_BinaryScalarT", bound=BinaryScalar) -BinaryArray: TypeAlias = lib.BinaryArray | lib.LargeBinaryArray +BinaryArray: TypeAlias = ( + lib.BinaryArray + | lib.LargeBinaryArray + | lib.ChunkedArray[lib.BinaryScalar] + | lib.ChunkedArray[lib.LargeBinaryScalar] +) _BinaryArrayT = TypeVar("_BinaryArrayT", bound=BinaryArray) StringOrBinaryScalar: TypeAlias = StringScalar | BinaryScalar _StringOrBinaryScalarT = TypeVar("_StringOrBinaryScalarT", bound=StringOrBinaryScalar) @@ -176,6 +195,12 @@ TemporalArray: TypeAlias = ( | lib.TimestampArray | lib.DurationArray | lib.MonthDayNanoIntervalArray + | lib.ChunkedArray[lib.Date32Scalar] + | lib.ChunkedArray[lib.Date64Scalar] + | lib.ChunkedArray[lib.Time32Scalar] + | lib.ChunkedArray[lib.Time64Scalar] + | lib.ChunkedArray[lib.DurationScalar] + | lib.ChunkedArray[lib.MonthDayNanoIntervalScalar] ) _TemporalArrayT = TypeVar("_TemporalArrayT", bound=TemporalArray) _ScalarT = TypeVar("_ScalarT", bound=lib.Scalar) @@ -186,7 +211,7 @@ _ScalarOrArrayT = TypeVar("_ScalarOrArrayT", bound=lib.Array | lib.Scalar) # ========================= 1.1 functions ========================= def all( - array: lib.BooleanScalar | lib.BooleanArray, + array: lib.BooleanScalar | BooleanArray, /, *, skip_nulls: bool = True, @@ -198,7 +223,7 @@ def all( any = _clone_signature(all) def approximate_median( - array: NumericScalar | lib.NumericArray, + array: NumericScalar | NumericArray, /, *, skip_nulls: bool = True, @@ -207,7 +232,7 @@ def approximate_median( memory_pool: lib.MemoryPool | None = None, ) -> lib.DoubleScalar: ... def count( - array: lib.Array, + array: lib.Array | lib.ChunkedArray, /, mode: Literal["only_valid", "only_null", "all"] = "only_valid", *, @@ -215,7 +240,7 @@ def count( memory_pool: lib.MemoryPool | None = None, ) -> lib.Int64Scalar: ... def count_distinct( - array: lib.Array, + array: lib.Array | lib.ChunkedArray, /, mode: Literal["only_valid", "only_null", "all"] = "only_valid", *, @@ -223,7 +248,7 @@ def count_distinct( memory_pool: lib.MemoryPool | None = None, ) -> lib.Int64Scalar: ... def first( - array: lib.Array[_ScalarT], + array: lib.Array[_ScalarT] | lib.ChunkedArray[_ScalarT], /, *, skip_nulls: bool = True, @@ -232,7 +257,7 @@ def first( memory_pool: lib.MemoryPool | None = None, ) -> _ScalarT: ... def first_last( - array: lib.Array, + array: lib.Array | lib.ChunkedArray, /, *, skip_nulls: bool = True, @@ -241,7 +266,7 @@ def first_last( memory_pool: lib.MemoryPool | None = None, ) -> lib.StructScalar: ... def index( - data: lib.Array, + data: lib.Array | lib.ChunkedArray, value, start: int | None = None, end: int | None = None, @@ -255,7 +280,7 @@ min = _clone_signature(first) min_max = _clone_signature(first_last) def mean( - array: NumericScalar | lib.NumericArray, + array: NumericScalar | NumericArray, /, *, skip_nulls: bool = True, @@ -264,7 +289,7 @@ def mean( memory_pool: lib.MemoryPool | None = None, ) -> lib.DoubleScalar | lib.Decimal128Scalar: ... def mode( - array: NumericScalar | lib.NumericArray, + array: NumericScalar | NumericArray, /, n: int = 1, *, @@ -283,7 +308,7 @@ def product( memory_pool: lib.MemoryPool | None = None, ) -> _ScalarT: ... def quantile( - array: NumericScalar | lib.NumericArray, + array: NumericScalar | NumericArray, /, q: float = 0.5, *, @@ -294,7 +319,7 @@ def quantile( memory_pool: lib.MemoryPool | None = None, ) -> lib.DoubleArray: ... def stddev( - array: NumericScalar | lib.NumericArray, + array: NumericScalar | NumericArray, /, *, ddof: float = 0, @@ -304,7 +329,7 @@ def stddev( memory_pool: lib.MemoryPool | None = None, ) -> lib.DoubleScalar: ... def sum( - array: _NumericScalarT | lib.NumericArray[_NumericScalarT], + array: _NumericScalarT | NumericArray[_NumericScalarT], /, *, skip_nulls: bool = True, @@ -313,7 +338,7 @@ def sum( memory_pool: lib.MemoryPool | None = None, ) -> _NumericScalarT: ... def tdigest( - array: NumericScalar | lib.NumericArray, + array: NumericScalar | NumericArray, /, q: float = 0.5, *, @@ -325,7 +350,7 @@ def tdigest( memory_pool: lib.MemoryPool | None = None, ) -> lib.DoubleArray: ... def variance( - array: NumericScalar | lib.NumericArray, + array: NumericScalar | NumericArray, /, *, ddof: int = 0, @@ -867,8 +892,8 @@ def equal( ) -> lib.BooleanScalar: ... @overload def equal( - x: lib.Scalar | lib.Array, - y: lib.Scalar | lib.Array, + x: lib.Scalar | lib.Array | lib.ChunkedArray, + y: lib.Scalar | lib.Array | lib.ChunkedArray, /, *, memory_pool: lib.MemoryPool | None = None, @@ -919,8 +944,8 @@ def and_( ) -> lib.BooleanScalar: ... @overload def and_( - x: lib.BooleanScalar | lib.BooleanArray, - y: lib.BooleanScalar | lib.BooleanArray, + x: lib.BooleanScalar | BooleanArray, + y: lib.BooleanScalar | BooleanArray, /, *, memory_pool: lib.MemoryPool | None = None, @@ -1026,11 +1051,20 @@ def binary_length( ) -> lib.Int64Scalar: ... @overload def binary_length( - strings: lib.BinaryArray | lib.StringArray, /, *, memory_pool: lib.MemoryPool | None = None + strings: lib.BinaryArray + | lib.StringArray + | lib.ChunkedArray[lib.BinaryScalar] + | lib.ChunkedArray[lib.StringScalar], + /, + *, + memory_pool: lib.MemoryPool | None = None, ) -> lib.Int32Array: ... @overload def binary_length( - strings: lib.LargeBinaryArray | lib.LargeStringArray, + strings: lib.LargeBinaryArray + | lib.LargeStringArray + | lib.ChunkedArray[lib.LargeBinaryScalar] + | lib.ChunkedArray[lib.LargeStringScalar], /, *, memory_pool: lib.MemoryPool | None = None, @@ -1180,11 +1214,14 @@ def utf8_length( ) -> lib.Int64Scalar: ... @overload def utf8_length( - strings: lib.StringArray, /, *, memory_pool: lib.MemoryPool | None = None + strings: lib.StringArray | lib.ChunkedArray[lib.StringScalar], + /, + *, + memory_pool: lib.MemoryPool | None = None, ) -> lib.Int32Array: ... @overload def utf8_length( - strings: lib.LargeStringArray, + strings: lib.LargeStringArray | lib.ChunkedArray[lib.LargeStringScalar], /, *, memory_pool: lib.MemoryPool | None = None, @@ -1562,7 +1599,10 @@ def count_substring( ) -> lib.Int64Scalar: ... @overload def count_substring( - strings: lib.StringArray | lib.BinaryArray, + strings: lib.StringArray + | lib.BinaryArray + | lib.ChunkedArray[lib.StringScalar] + | lib.ChunkedArray[lib.BinaryScalar], /, pattern: str, *, @@ -1572,7 +1612,10 @@ def count_substring( ) -> lib.Int32Array: ... @overload def count_substring( - strings: lib.LargeStringArray | lib.LargeBinaryArray, + strings: lib.LargeStringArray + | lib.LargeBinaryArray + | lib.ChunkedArray[lib.LargeStringScalar] + | lib.ChunkedArray[lib.LargeBinaryScalar], /, pattern: str, *, @@ -1631,7 +1674,7 @@ find_substring_regex = _clone_signature(count_substring) def index_in( values: lib.Scalar, /, - value_set: lib.Array, + value_set: lib.Array | lib.ChunkedArray, *, skip_nulls: bool = False, options: SetLookupOptions | None = None, @@ -1639,9 +1682,9 @@ def index_in( ) -> lib.Int32Scalar: ... @overload def index_in( - values: lib.Array, + values: lib.Array | lib.ChunkedArray, /, - value_set: lib.Array, + value_set: lib.Array | lib.ChunkedArray, *, skip_nulls: bool = False, options: SetLookupOptions | None = None, @@ -1651,7 +1694,7 @@ def index_in( def index_in( values: Expression, /, - value_set: lib.Array, + value_set: lib.Array | lib.ChunkedArray, *, skip_nulls: bool = False, options: SetLookupOptions | None = None, @@ -1661,7 +1704,7 @@ def index_in( def is_in( values: lib.Scalar, /, - value_set: lib.Array, + value_set: lib.Array | lib.ChunkedArray, *, skip_nulls: bool = False, options: SetLookupOptions | None = None, @@ -1669,9 +1712,9 @@ def is_in( ) -> lib.BooleanScalar: ... @overload def is_in( - values: lib.Array, + values: lib.Array | lib.ChunkedArray, /, - value_set: lib.Array, + value_set: lib.Array | lib.ChunkedArray, *, skip_nulls: bool = False, options: SetLookupOptions | None = None, @@ -1681,7 +1724,7 @@ def is_in( def is_in( values: Expression, /, - value_set: lib.Array, + value_set: lib.Array | lib.ChunkedArray, *, skip_nulls: bool = False, options: SetLookupOptions | None = None, @@ -1721,7 +1764,7 @@ def is_null( ) -> lib.BooleanScalar: ... @overload def is_null( - values: lib.Array, + values: lib.Array | lib.ChunkedArray, /, *, nan_is_null: bool = False, @@ -1743,7 +1786,7 @@ def is_valid( ) -> lib.BooleanScalar: ... @overload def is_valid( - values: lib.Array, /, *, memory_pool: lib.MemoryPool | None = None + values: lib.Array | lib.ChunkedArray, /, *, memory_pool: lib.MemoryPool | None = None ) -> lib.BooleanArray: ... @overload def is_valid( @@ -1764,14 +1807,14 @@ def if_else(cond, left, right, /, *, memory_pool: lib.MemoryPool | None = None): @overload def list_value_length( - lists: lib.ListArray | lib.ListViewArray | lib.FixedSizeListArray, + lists: lib.ListArray | lib.ListViewArray | lib.FixedSizeListArray | lib.ChunkedArray, /, *, memory_pool: lib.MemoryPool | None = None, ) -> lib.Int32Array: ... @overload def list_value_length( - lists: lib.LargeListArray | lib.LargeListViewArray, + lists: lib.LargeListArray | lib.LargeListViewArray | lib.ChunkedArray, /, *, memory_pool: lib.MemoryPool | None = None, @@ -1794,7 +1837,7 @@ def make_struct( ) -> lib.StructScalar: ... @overload def make_struct( - *args: lib.Array, + *args: lib.Array | lib.ChunkedArray, field_names: list[str] | tuple[str, ...] = (), field_nullability: bool | None = None, field_metadata: list[lib.KeyValueMetadata] | None = None, @@ -1908,6 +1951,14 @@ def cast( memory_pool: lib.MemoryPool | None = None, ) -> lib.Array[lib.Scalar[_DataTypeT]]: ... @overload +def cast( + arr: lib.ChunkedArray, + target_type: _DataTypeT, + safe: bool | None = None, + options: CastOptions | None = None, + memory_pool: lib.MemoryPool | None = None, +) -> lib.ChunkedArray[lib.Scalar[_DataTypeT]]: ... +@overload def strftime( timestamps: TemporalScalar, /, @@ -2024,7 +2075,12 @@ def hour( ) -> lib.Int64Scalar: ... @overload def hour( - values: lib.TimestampArray | lib.Time32Array | lib.Time64Array, + values: lib.TimestampArray + | lib.Time32Array + | lib.Time64Array + | lib.ChunkedArray[lib.TimestampScalar] + | lib.ChunkedArray[lib.Time32Scalar] + | lib.ChunkedArray[lib.Time64Scalar], /, *, memory_pool: lib.MemoryPool | None = None, @@ -2042,7 +2098,10 @@ def is_dst( ) -> lib.BooleanScalar: ... @overload def is_dst( - values: lib.TimestampArray, /, *, memory_pool: lib.MemoryPool | None = None + values: lib.TimestampArray | lib.ChunkedArray[lib.TimestampScalar], + /, + *, + memory_pool: lib.MemoryPool | None = None, ) -> lib.BooleanArray: ... @overload def is_dst(values: Expression, /, *, memory_pool: lib.MemoryPool | None = None) -> Expression: ... @@ -2052,7 +2111,10 @@ def iso_week( ) -> lib.Int64Scalar: ... @overload def iso_week( - values: lib.TimestampArray, /, *, memory_pool: lib.MemoryPool | None = None + values: lib.TimestampArray | lib.ChunkedArray[lib.TimestampScalar], + /, + *, + memory_pool: lib.MemoryPool | None = None, ) -> lib.Int64Array: ... @overload def iso_week( @@ -2070,7 +2132,12 @@ def is_leap_year( ) -> lib.BooleanScalar: ... @overload def is_leap_year( - values: lib.TimestampArray | lib.Date32Array | lib.Date64Array, + values: lib.TimestampArray + | lib.Date32Array + | lib.Date64Array + | lib.ChunkedArray[lib.TimestampScalar] + | lib.ChunkedArray[lib.Date32Scalar] + | lib.ChunkedArray[lib.Date64Scalar], /, *, memory_pool: lib.MemoryPool | None = None, @@ -2108,7 +2175,7 @@ def week( ) -> lib.Int64Scalar: ... @overload def week( - values: lib.TimestampArray, + values: lib.TimestampArray | lib.ChunkedArray[lib.TimestampScalar], /, *, week_starts_monday: bool = True, @@ -2188,7 +2255,7 @@ def assume_timezone( ) -> lib.TimestampScalar: ... @overload def assume_timezone( - timestamps: lib.TimestampArray, + timestamps: lib.TimestampArray | lib.ChunkedArray[lib.TimestampScalar], /, timezone: str, *, @@ -2214,7 +2281,10 @@ def local_timestamp( ) -> lib.TimestampScalar: ... @overload def local_timestamp( - timestamps: lib.TimestampArray, /, *, memory_pool: lib.MemoryPool | None = None + timestamps: lib.TimestampArray | lib.ChunkedArray[lib.TimestampScalar], + /, + *, + memory_pool: lib.MemoryPool | None = None, ) -> lib.TimestampArray: ... @overload def local_timestamp( @@ -2287,7 +2357,7 @@ def unique(array: _ArrayT, /, *, memory_pool: lib.MemoryPool | None = None) -> _ def unique(array: Expression, /, *, memory_pool: lib.MemoryPool | None = None) -> Expression: ... @overload def value_counts( - array: lib.Array, /, *, memory_pool: lib.MemoryPool | None = None + array: lib.Array | lib.ChunkedArray, /, *, memory_pool: lib.MemoryPool | None = None ) -> lib.StructArray: ... @overload def value_counts( @@ -2298,7 +2368,7 @@ def value_counts( @overload def array_filter( array: _ArrayT, - selection_filter: list[bool] | list[bool | None] | lib.BooleanArray, + selection_filter: list[bool] | list[bool | None] | BooleanArray, /, null_selection_behavior: Literal["drop", "emit_null"] = "drop", *, @@ -2308,7 +2378,7 @@ def array_filter( @overload def array_filter( array: Expression, - selection_filter: list[bool] | list[bool | None] | lib.BooleanArray, + selection_filter: list[bool] | list[bool | None] | BooleanArray, /, null_selection_behavior: Literal["drop", "emit_null"] = "drop", *, @@ -2318,7 +2388,14 @@ def array_filter( @overload def array_take( array: _ArrayT, - indices: list[int] | list[int | None] | lib.Int16Array | lib.Int32Array | lib.Int64Array, + indices: list[int] + | list[int | None] + | lib.Int16Array + | lib.Int32Array + | lib.Int64Array + | lib.ChunkedArray[lib.Int16Scalar] + | lib.ChunkedArray[lib.Int32Scalar] + | lib.ChunkedArray[lib.Int64Scalar], /, *, boundscheck: bool = True, @@ -2328,7 +2405,14 @@ def array_take( @overload def array_take( array: Expression, - indices: list[int] | list[int | None] | lib.Int16Array | lib.Int32Array | lib.Int64Array, + indices: list[int] + | list[int | None] + | lib.Int16Array + | lib.Int32Array + | lib.Int64Array + | lib.ChunkedArray[lib.Int16Scalar] + | lib.ChunkedArray[lib.Int32Scalar] + | lib.ChunkedArray[lib.Int64Scalar], /, *, boundscheck: bool = True, @@ -2368,7 +2452,7 @@ def indices_nonzero( # ========================= 3.5 Sorts and partitions ========================= @overload def array_sort_indices( - array: lib.Array, + array: lib.Array | lib.ChunkedArray, /, order: Literal["ascending", "descending"] = "ascending", *, @@ -2388,7 +2472,7 @@ def array_sort_indices( ) -> Expression: ... @overload def partition_nth_indices( - array: lib.Array, + array: lib.Array | lib.ChunkedArray, /, pivot: int, *, @@ -2407,7 +2491,7 @@ def partition_nth_indices( memory_pool: lib.MemoryPool | None = None, ) -> Expression: ... def rank( - input: lib.Array, + input: lib.Array | lib.ChunkedArray, /, sort_keys: Literal["ascending", "descending"] = "ascending", *, @@ -2418,7 +2502,7 @@ def rank( ) -> lib.UInt64Array: ... @overload def select_k_unstable( - input: lib.Array, + input: lib.Array | lib.ChunkedArray, /, k: int, sort_keys: list[tuple[str, Literal["ascending", "descending"]]], @@ -2537,7 +2621,7 @@ def fill_null_backward(values, /, *, memory_pool: lib.MemoryPool | None = None): def fill_null_forward(values, /, *, memory_pool: lib.MemoryPool | None = None): ... def replace_with_mask( values, - mask: list[bool] | list[bool | None] | lib.BooleanArray, + mask: list[bool] | list[bool | None] | BooleanArray, replacements, /, *,