Support Appends with TimeTransform Partitions #784

Merged 22 commits on May 31, 2024
2 changes: 1 addition & 1 deletion pyiceberg/partitioning.py
@@ -387,7 +387,7 @@ def partition(self) -> Record: # partition key transformed with iceberg interna
for raw_partition_field_value in self.raw_partition_field_values:
partition_fields = self.partition_spec.source_id_to_fields_map[raw_partition_field_value.field.source_id]
if len(partition_fields) != 1:
- raise ValueError("partition_fields must contain exactly one field.")
+ raise ValueError(f"Cannot have redundant partitions: {partition_fields}")
partition_field = partition_fields[0]
iceberg_typed_key_values[partition_field.name] = partition_record_value(
partition_field=partition_field,
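For reference, the redundant-partition case this message guards against is a spec that derives more than one partition field from the same source column, so the source_id_to_fields_map lookup returns multiple entries. A minimal sketch of such a spec (field IDs and names illustrative):

from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.transforms import BucketTransform, IdentityTransform

# Both partition fields read from source column 1, so
# source_id_to_fields_map[1] holds two entries and the lookup above raises.
spec = PartitionSpec(
    PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="id"),
    PartitionField(source_id=1, field_id=1001, transform=BucketTransform(num_buckets=16), name="id_bucket"),
)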
67 changes: 29 additions & 38 deletions pyiceberg/table/__init__.py
@@ -392,10 +392,11 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

- supported_transforms = {IdentityTransform}
- if not all(type(field.transform) in supported_transforms for field in self.table_metadata.spec().fields):
+ if unsupported_partitions := [
+     field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform
+ ]:
raise ValueError(
-     f"All transforms are not supported, expected: {supported_transforms}, but get: {[str(field) for field in self.table_metadata.spec().fields if field.transform not in supported_transforms]}."
+     f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
)

_check_schema_compatible(self._table.schema(), other_schema=df.schema)
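With this check, append accepts any partition spec whose transforms implement the PyArrow path (identity, plus the year/month/day/hour transforms added below in transforms.py). A rough end-to-end sketch, assuming a catalog configured under the name "default"; table and column identifiers are illustrative:

from datetime import datetime

import pyarrow as pa

from pyiceberg.catalog import load_catalog
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import DayTransform
from pyiceberg.types import NestedField, TimestampType

catalog = load_catalog("default")  # assumed to be configured elsewhere
schema = Schema(NestedField(field_id=1, name="ts", field_type=TimestampType(), required=False))
spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=DayTransform(), name="ts_day"))

tbl = catalog.create_table("db.events", schema=schema, partition_spec=spec)
# Previously rejected because DayTransform is not IdentityTransform; now supported.
tbl.append(pa.table({"ts": pa.array([datetime(2024, 1, 1), datetime(2024, 1, 2)], type=pa.timestamp("us"))}))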
@@ -3643,33 +3644,6 @@ class TablePartition:
arrow_table_partition: pa.Table


- def _get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]:
-     order = "ascending" if not reverse else "descending"
-     null_placement = "at_start" if reverse else "at_end"
-     return {"sort_keys": [(column_name, order) for column_name in partition_columns], "null_placement": null_placement}
-
-
- def group_by_partition_scheme(arrow_table: pa.Table, partition_columns: list[str]) -> pa.Table:
-     """Given a table, sort it by current partition scheme."""
-     # only works for identity for now
-     sort_options = _get_partition_sort_order(partition_columns, reverse=False)
-     sorted_arrow_table = arrow_table.sort_by(sorting=sort_options["sort_keys"], null_placement=sort_options["null_placement"])
-     return sorted_arrow_table
-
-
- def get_partition_columns(
-     spec: PartitionSpec,
-     schema: Schema,
- ) -> list[str]:
-     partition_cols = []
-     for partition_field in spec.fields:
-         column_name = schema.find_column_name(partition_field.source_id)
-         if not column_name:
-             raise ValueError(f"{partition_field=} could not be found in {schema}.")
-         partition_cols.append(column_name)
-     return partition_cols


def _get_table_partitions(
arrow_table: pa.Table,
partition_spec: PartitionSpec,
@@ -3724,13 +3698,30 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
"""
import pyarrow as pa

- partition_columns = get_partition_columns(spec=spec, schema=schema)
- arrow_table = group_by_partition_scheme(arrow_table, partition_columns)
-
- reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True)
- reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist()
-
- slice_instructions: list[dict[str, Any]] = []
+ partition_columns: List[Tuple[PartitionField, NestedField]] = [
+     (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
+ ]
+ partition_values_table = pa.table({
+     str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
+     for partition, field in partition_columns
+ })
+
+ # Sort by partitions
+ sort_indices = pa.compute.sort_indices(
+     partition_values_table,
+     sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
+     null_placement="at_end",
+ ).to_pylist()
+ arrow_table = arrow_table.take(sort_indices)
+
+ # Get slice_instructions to group by partitions
+ partition_values_table = partition_values_table.take(sort_indices)
+ reversed_indices = pa.compute.sort_indices(
+     partition_values_table,
+     sort_keys=[(col, "descending") for col in partition_values_table.column_names],
+     null_placement="at_start",
+ ).to_pylist()
+ slice_instructions: List[Dict[str, Any]] = []
last = len(reversed_indices)
reversed_indices_size = len(reversed_indices)
ptr = 0
@@ -3741,6 +3732,6 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
last = reversed_indices[ptr]
ptr = ptr + group_size

- table_partitions: list[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
+ table_partitions: List[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)

return table_partitions
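The descending re-sort above is the core trick of this grouping: a stable descending sort over the already-ascending partition-values table puts the first row index of each group at the head of its run, so group offsets and lengths fall out of simple pointer arithmetic, with no hash-based group-by. A self-contained sketch of the same idea on toy data:

import pyarrow as pa
import pyarrow.compute as pc

tbl = pa.table({"part": [2, 1, 2, 1, 3]})

# 1. Order rows so equal partition values are adjacent.
asc = pc.sort_indices(tbl, sort_keys=[("part", "ascending")], null_placement="at_end").to_pylist()
tbl = tbl.take(asc)  # part = [1, 1, 2, 2, 3]

# 2. The descending sort is stable, so each group's first occurrence leads its run:
#    rev = [4, 2, 3, 0, 1] -> group offsets 4, 2, 0.
rev = pc.sort_indices(tbl, sort_keys=[("part", "descending")], null_placement="at_start").to_pylist()

slices, last, ptr = [], len(rev), 0
while ptr < len(rev):
    group_size = last - rev[ptr]
    slices.append({"offset": rev[ptr], "length": group_size})
    last = rev[ptr]
    ptr += group_size

print(slices)  # [{'offset': 4, 'length': 1}, {'offset': 2, 'length': 2}, {'offset': 0, 'length': 2}]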
99 changes: 98 additions & 1 deletion pyiceberg/transforms.py
@@ -20,7 +20,7 @@
from abc import ABC, abstractmethod
from enum import IntEnum
from functools import singledispatch
- from typing import Any, Callable, Generic, Optional, TypeVar
+ from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar
from typing import Literal as LiteralType
from uuid import UUID

@@ -82,6 +82,9 @@
from pyiceberg.utils.parsing import ParseNumberFromBrackets
from pyiceberg.utils.singleton import Singleton

if TYPE_CHECKING:
import pyarrow as pa

S = TypeVar("S")
T = TypeVar("T")

@@ -175,6 +178,13 @@ def __eq__(self, other: Any) -> bool:
return self.root == other.root
return False

@property
def supports_pyarrow_transform(self) -> bool:
return False

@abstractmethod
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...
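Callers are expected to use these two hooks together: test supports_pyarrow_transform on each partition field's transform, then call pyarrow_transform(source_type) to obtain an array-to-array callable. A small usage sketch with illustrative values:

from datetime import date

import pyarrow as pa

from pyiceberg.transforms import DayTransform
from pyiceberg.types import DateType

transform = DayTransform()
if transform.supports_pyarrow_transform:
    to_partition = transform.pyarrow_transform(DateType())
    print(to_partition(pa.array([date(1970, 1, 2), date(2024, 1, 1)])))  # expected: [1, 19723]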


class BucketTransform(Transform[S, int]):
"""Base Transform class to transform a value into a bucket partition value.
@@ -290,6 +300,9 @@ def __repr__(self) -> str:
"""Return the string representation of the BucketTransform class."""
return f"BucketTransform(num_buckets={self._num_buckets})"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
raise NotImplementedError()


class TimeResolution(IntEnum):
YEAR = 6
@@ -349,6 +362,10 @@ def dedup_name(self) -> str:
def preserves_order(self) -> bool:
return True

@property
def supports_pyarrow_transform(self) -> bool:
return True


class YearTransform(TimeTransform[S]):
"""Transforms a datetime value into a year value.
@@ -391,6 +408,21 @@ def __repr__(self) -> str:
"""Return the string representation of the YearTransform class."""
return "YearTransform()"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
import pyarrow as pa
import pyarrow.compute as pc

if isinstance(source, DateType):
epoch = datetime.EPOCH_DATE
elif isinstance(source, TimestampType):
epoch = datetime.EPOCH_TIMESTAMP
elif isinstance(source, TimestamptzType):
epoch = datetime.EPOCH_TIMESTAMPTZ
else:
raise ValueError(f"Cannot apply year transform for type: {source}")

return lambda v: pc.years_between(pa.scalar(epoch), v) if v is not None else None
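As a sanity check of the semantics: pc.years_between returns the calendar-year difference, which is exactly Iceberg's year partition value of years since 1970. A sketch, with expected outputs in comments:

from datetime import date

import pyarrow as pa
import pyarrow.compute as pc

epoch = pa.scalar(date(1970, 1, 1))
print(pc.years_between(epoch, pa.array([date(1969, 12, 31), date(2024, 1, 31)])))  # expected: [-1, 54]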


class MonthTransform(TimeTransform[S]):
"""Transforms a datetime value into a month value.
@@ -433,6 +465,27 @@ def __repr__(self) -> str:
"""Return the string representation of the MonthTransform class."""
return "MonthTransform()"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
import pyarrow as pa
import pyarrow.compute as pc

if isinstance(source, DateType):
epoch = datetime.EPOCH_DATE
elif isinstance(source, TimestampType):
epoch = datetime.EPOCH_TIMESTAMP
elif isinstance(source, TimestamptzType):
epoch = datetime.EPOCH_TIMESTAMPTZ
else:
raise ValueError(f"Cannot apply month transform for type: {source}")

def month_func(v: pa.Array) -> pa.Array:
return pc.add(
pc.multiply(pc.years_between(pa.scalar(epoch), v), pa.scalar(12)),
pc.add(pc.month(v), pa.scalar(-1)),
)

return lambda v: month_func(v) if v is not None else None
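The arithmetic mirrors Iceberg's month ordinal: full calendar years since the epoch times 12, plus the zero-based month of the value. For example, 2024-02 is (2024 - 1970) * 12 + (2 - 1) = 649 months after 1970-01. A sketch replaying the same expression:

from datetime import date

import pyarrow as pa
import pyarrow.compute as pc

epoch = pa.scalar(date(1970, 1, 1))
v = pa.array([date(2024, 2, 1)])
months = pc.add(pc.multiply(pc.years_between(epoch, v), pa.scalar(12)), pc.add(pc.month(v), pa.scalar(-1)))
print(months)  # expected: [649]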


class DayTransform(TimeTransform[S]):
"""Transforms a datetime value into a day value.
@@ -478,6 +531,21 @@ def __repr__(self) -> str:
"""Return the string representation of the DayTransform class."""
return "DayTransform()"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
import pyarrow as pa
import pyarrow.compute as pc

if isinstance(source, DateType):
epoch = datetime.EPOCH_DATE
elif isinstance(source, TimestampType):
epoch = datetime.EPOCH_TIMESTAMP
elif isinstance(source, TimestamptzType):
epoch = datetime.EPOCH_TIMESTAMPTZ
else:
raise ValueError(f"Cannot apply day transform for type: {source}")

return lambda v: pc.days_between(pa.scalar(epoch), v) if v is not None else None
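The day transform follows the same shape; pc.days_between yields the epoch-day ordinal Iceberg expects. A sketch, with expected output in the comment:

from datetime import date

import pyarrow as pa
import pyarrow.compute as pc

epoch = pa.scalar(date(1970, 1, 1))
print(pc.days_between(epoch, pa.array([date(1970, 1, 2), date(2024, 1, 1)])))  # expected: [1, 19723]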


class HourTransform(TimeTransform[S]):
"""Transforms a datetime value into an hour value.
@@ -515,6 +583,19 @@ def __repr__(self) -> str:
"""Return the string representation of the HourTransform class."""
return "HourTransform()"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
import pyarrow as pa
import pyarrow.compute as pc

if isinstance(source, TimestampType):
epoch = datetime.EPOCH_TIMESTAMP
elif isinstance(source, TimestamptzType):
epoch = datetime.EPOCH_TIMESTAMPTZ
else:
raise ValueError(f"Cannot apply hour transform for type: {source}")

return lambda v: pc.hours_between(pa.scalar(epoch), v) if v is not None else None
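Hours are computed against the timestamp epoch; note the transform deliberately rejects DateType, since a date has no hour component. A sketch, with expected output in the comment:

from datetime import datetime

import pyarrow as pa
import pyarrow.compute as pc

epoch = pa.scalar(datetime(1970, 1, 1))
print(pc.hours_between(epoch, pa.array([datetime(2024, 1, 1, 6, 0, 0)])))  # expected: [473358]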


def _base64encode(buffer: bytes) -> str:
"""Convert bytes to base64 string."""
@@ -585,6 +666,13 @@ def __repr__(self) -> str:
"""Return the string representation of the IdentityTransform class."""
return "IdentityTransform()"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
return lambda v: v

@property
def supports_pyarrow_transform(self) -> bool:
return True


class TruncateTransform(Transform[S, S]):
"""A transform for truncating a value to a specified width.
@@ -725,6 +813,9 @@ def __repr__(self) -> str:
"""Return the string representation of the TruncateTransform class."""
return f"TruncateTransform(width={self._width})"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
raise NotImplementedError()


@singledispatch
def _human_string(value: Any, _type: IcebergType) -> str:
@@ -807,6 +898,9 @@ def __repr__(self) -> str:
"""Return the string representation of the UnknownTransform class."""
return f"UnknownTransform(transform={repr(self._transform)})"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
raise NotImplementedError()


class VoidTransform(Transform[S, None], Singleton):
"""A transform that always returns None."""
@@ -835,6 +929,9 @@ def __repr__(self) -> str:
"""Return the string representation of the VoidTransform class."""
return "VoidTransform()"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
raise NotImplementedError()


def _truncate_number(
name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]]
43 changes: 43 additions & 0 deletions tests/conftest.py
@@ -2158,3 +2158,46 @@ def arrow_table_with_only_nulls(pa_schema: "pa.Schema") -> "pa.Table":
import pyarrow as pa

return pa.Table.from_pylist([{}, {}], schema=pa_schema)


@pytest.fixture(scope="session")
def arrow_table_date_timestamps() -> "pa.Table":
"""PyArrow table with only date, timestamp, and timestamptz values."""
import pyarrow as pa

return pa.Table.from_pydict(
{
"date": [date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1), date(2024, 2, 1), None],
"timestamp": [
datetime(2023, 12, 31, 0, 0, 0),
datetime(2024, 1, 1, 0, 0, 0),
datetime(2024, 1, 31, 0, 0, 0),
datetime(2024, 2, 1, 0, 0, 0),
datetime(2024, 2, 1, 6, 0, 0),
None,
],
"timestamptz": [
datetime(2023, 12, 31, 0, 0, 0, tzinfo=timezone.utc),
datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
datetime(2024, 1, 31, 0, 0, 0, tzinfo=timezone.utc),
datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc),
datetime(2024, 2, 1, 6, 0, 0, tzinfo=timezone.utc),
None,
],
},
schema=pa.schema([
("date", pa.date32()),
("timestamp", pa.timestamp(unit="us")),
("timestamptz", pa.timestamp(unit="us", tz="UTC")),
]),
)


@pytest.fixture(scope="session")
def arrow_table_date_timestamps_schema() -> Schema:
"""Schema for the PyArrow table with only date, timestamp, and timestamptz fields."""
return Schema(
NestedField(field_id=1, name="date", field_type=DateType(), required=False),
NestedField(field_id=2, name="timestamp", field_type=TimestampType(), required=False),
NestedField(field_id=3, name="timestamptz", field_type=TimestamptzType(), required=False),
)
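A test over these fixtures might append the table to a month-partitioned table and count one data file per distinct month, including the null group. A sketch, assuming a test-catalog fixture (here named catalog) like those used elsewhere in the suite; the table identifier is illustrative:

from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.transforms import MonthTransform


def test_append_with_month_transform(catalog, arrow_table_date_timestamps, arrow_table_date_timestamps_schema) -> None:
    # source_id=2 is the "timestamp" field in the schema fixture above.
    spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=MonthTransform(), name="timestamp_month"))
    tbl = catalog.create_table(
        "default.date_timestamps",  # illustrative identifier
        schema=arrow_table_date_timestamps_schema,
        partition_spec=spec,
    )
    tbl.append(arrow_table_date_timestamps)
    # Dec 2023, Jan 2024, Feb 2024 and the null row give four partition groups.
    assert len(tbl.scan().plan_files()) == 4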