support PyArrow timestamptz with Etc/UTC (#910)
Co-authored-by: Fokko Driesprong <[email protected]>
sungwy and Fokko authored Jul 12, 2024
1 parent f6d56e9 commit 32e8f88
Showing 7 changed files with 218 additions and 86 deletions.
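For orientation, a minimal sketch (not part of the commit; it only mirrors the UTC_ALIASES constant introduced in pyiceberg/io/pyarrow.py below, everything else is illustrative) of the behavior this change enables: PyArrow timestamp columns whose timezone is any common spelling of UTC are now mapped to Iceberg's timestamptz, rather than only the literal "UTC" and "+00:00" strings.

import pyarrow as pa

# Mirrors the constant added in this commit.
UTC_ALIASES = {"UTC", "+00:00", "Etc/UTC", "Z"}

for tz in ("UTC", "+00:00", "Etc/UTC", "Z", "America/New_York", None):
    arrow_type = pa.timestamp("us", tz=tz)
    if arrow_type.tz in UTC_ALIASES:
        verdict = "Iceberg timestamptz (UTC-adjusted)"
    elif arrow_type.tz is None:
        verdict = "Iceberg timestamp (no timezone)"
    else:
        verdict = "treated as an unsupported timezone"
    print(f"{str(tz):>16} -> {verdict}")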
51 changes: 35 additions & 16 deletions pyiceberg/io/pyarrow.py
@@ -174,6 +174,7 @@
MAP_KEY_NAME = "key"
MAP_VALUE_NAME = "value"
DOC = "doc"
UTC_ALIASES = {"UTC", "+00:00", "Etc/UTC", "Z"}

T = TypeVar("T")

@@ -937,7 +938,7 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
else:
raise TypeError(f"Unsupported precision for timestamp type: {primitive.unit}")

if primitive.tz == "UTC" or primitive.tz == "+00:00":
if primitive.tz in UTC_ALIASES:
return TimestamptzType()
elif primitive.tz is None:
return TimestampType()
@@ -1073,7 +1074,7 @@ def _task_to_record_batches(
arrow_table = pa.Table.from_batches([batch])
arrow_table = arrow_table.filter(pyarrow_filter)
batch = arrow_table.to_batches()[0]
yield to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True)
yield _to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True)
current_index += len(batch)


@@ -1278,7 +1279,7 @@ def project_batches(
total_row_count += len(batch)


def to_requested_schema(
def _to_requested_schema(
requested_schema: Schema,
file_schema: Schema,
batch: pa.RecordBatch,
@@ -1296,31 +1297,49 @@ def to_requested_schema(


class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
file_schema: Schema
_file_schema: Schema
_include_field_ids: bool
_downcast_ns_timestamp_to_us: bool

def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False, include_field_ids: bool = False) -> None:
self.file_schema = file_schema
self._file_schema = file_schema
self._include_field_ids = include_field_ids
self.downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us

def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
file_field = self.file_schema.find_field(field.field_id)
file_field = self._file_schema.find_field(field.field_id)

if field.field_type.is_primitive:
if field.field_type != file_field.field_type:
return values.cast(
schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids)
)
elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)) != values.type:
# Downcasting of nanoseconds to microseconds
if (
pa.types.is_timestamp(target_type)
and target_type.unit == "us"
and pa.types.is_timestamp(values.type)
and values.type.unit == "ns"
):
return values.cast(target_type, safe=False)
if field.field_type == TimestampType():
# Downcasting of nanoseconds to microseconds
if (
pa.types.is_timestamp(target_type)
and not target_type.tz
and pa.types.is_timestamp(values.type)
and not values.type.tz
):
if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us:
return values.cast(target_type, safe=False)
elif target_type.unit == "us" and values.type.unit in {"s", "ms"}:
return values.cast(target_type)
raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}")
elif field.field_type == TimestamptzType():
if (
pa.types.is_timestamp(target_type)
and target_type.tz == "UTC"
and pa.types.is_timestamp(values.type)
and values.type.tz in UTC_ALIASES
):
if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us:
return values.cast(target_type, safe=False)
elif target_type.unit == "us" and values.type.unit in {"s", "ms", "us"}:
return values.cast(target_type)
raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}")
return values

def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
@@ -1970,7 +1989,7 @@ def write_parquet(task: WriteTask) -> DataFile:

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
batches = [
to_requested_schema(
_to_requested_schema(
requested_schema=file_schema,
file_schema=table_schema,
batch=batch,
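To make the new _cast_if_needed branches above concrete, a short hedged example using only PyArrow (the values are made up; this is not code from the commit): second- and millisecond-precision columns may be widened to microseconds, nanoseconds are only downcast when downcast_ns_timestamp_to_us is enabled (and then with safe=False, since sub-microsecond precision is dropped), and any UTC alias is normalized to tz="UTC".

import pyarrow as pa

# A timestamptz stored at second precision with the "Etc/UTC" spelling can be
# cast to the canonical microsecond, tz="UTC" representation:
tz_col = pa.array([1_700_000_000], type=pa.timestamp("s", tz="Etc/UTC"))
print(tz_col.cast(pa.timestamp("us", tz="UTC")))

# Nanosecond values need safe=False because the cast is lossy; the projection
# only performs it when downcast_ns_timestamp_to_us is set:
ns_col = pa.array([1_001], type=pa.timestamp("ns"))
print(ns_col.cast(pa.timestamp("us"), safe=False))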
8 changes: 0 additions & 8 deletions pyiceberg/table/__init__.py
@@ -484,10 +484,6 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
_check_schema_compatible(
self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
# cast if the two schemas are compatible but not equal
table_arrow_schema = self._table.schema().as_arrow()
if table_arrow_schema != df.schema:
df = df.cast(table_arrow_schema)

manifest_merge_enabled = PropertyUtil.property_as_bool(
self.table_metadata.properties,
@@ -545,10 +541,6 @@ def overwrite(
_check_schema_compatible(
self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
# cast if the two schemas are compatible but not equal
table_arrow_schema = self._table.schema().as_arrow()
if table_arrow_schema != df.schema:
df = df.cast(table_arrow_schema)

self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)

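The blanket df.cast(...) removed above is no longer needed because the write path (see the write_parquet hunk earlier) reconciles each record batch through _to_requested_schema. A rough, hypothetical sketch of that per-batch idea in plain PyArrow, with made-up column names; it is not the actual PyIceberg implementation:

import pyarrow as pa

def reconcile_batch(batch: pa.RecordBatch, target: pa.Schema) -> pa.RecordBatch:
    # Cast every column to the requested type; for timestamps this collapses
    # s/ms units and UTC aliases into microseconds with tz="UTC".
    arrays = [batch.column(i).cast(target.field(i).type) for i in range(batch.num_columns)]
    return pa.RecordBatch.from_arrays(arrays, schema=target)

source = pa.RecordBatch.from_pydict({"ts": pa.array([0], type=pa.timestamp("s", tz="Etc/UTC"))})
target = pa.schema([("ts", pa.timestamp("us", tz="UTC"))])
print(reconcile_batch(source, target))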
116 changes: 114 additions & 2 deletions tests/conftest.py
@@ -2382,10 +2382,122 @@ def arrow_table_date_timestamps() -> "pa.Table":


@pytest.fixture(scope="session")
def arrow_table_date_timestamps_schema() -> Schema:
"""Pyarrow table Schema with only date, timestamp and timestamptz values."""
def table_date_timestamps_schema() -> Schema:
"""Iceberg table Schema with only date, timestamp and timestamptz values."""
return Schema(
NestedField(field_id=1, name="date", field_type=DateType(), required=False),
NestedField(field_id=2, name="timestamp", field_type=TimestampType(), required=False),
NestedField(field_id=3, name="timestamptz", field_type=TimestamptzType(), required=False),
)


@pytest.fixture(scope="session")
def arrow_table_schema_with_all_timestamp_precisions() -> "pa.Schema":
"""Pyarrow Schema with all supported timestamp types."""
import pyarrow as pa

return pa.schema([
("timestamp_s", pa.timestamp(unit="s")),
("timestamptz_s", pa.timestamp(unit="s", tz="UTC")),
("timestamp_ms", pa.timestamp(unit="ms")),
("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")),
("timestamp_us", pa.timestamp(unit="us")),
("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
("timestamp_ns", pa.timestamp(unit="ns")),
("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")),
("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")),
("timestamptz_ns_z", pa.timestamp(unit="ns", tz="Z")),
("timestamptz_s_0000", pa.timestamp(unit="s", tz="+00:00")),
])


@pytest.fixture(scope="session")
def arrow_table_with_all_timestamp_precisions(arrow_table_schema_with_all_timestamp_precisions: "pa.Schema") -> "pa.Table":
"""Pyarrow table with all supported timestamp types."""
import pandas as pd
import pyarrow as pa

test_data = pd.DataFrame({
"timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
"timestamptz_s": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
"timestamptz_ms": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
"timestamptz_us": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamp_ns": [
pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6),
None,
pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7),
],
"timestamptz_ns": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamptz_us_etc_utc": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamptz_ns_z": [
pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6, tz="UTC"),
None,
pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7, tz="UTC"),
],
"timestamptz_s_0000": [
datetime(2023, 1, 1, 19, 25, 1, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 1, tzinfo=timezone.utc),
],
})
return pa.Table.from_pandas(test_data, schema=arrow_table_schema_with_all_timestamp_precisions)


@pytest.fixture(scope="session")
def arrow_table_schema_with_all_microseconds_timestamp_precisions() -> "pa.Schema":
"""Pyarrow Schema with all microseconds timestamp."""
import pyarrow as pa

return pa.schema([
("timestamp_s", pa.timestamp(unit="us")),
("timestamptz_s", pa.timestamp(unit="us", tz="UTC")),
("timestamp_ms", pa.timestamp(unit="us")),
("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")),
("timestamp_us", pa.timestamp(unit="us")),
("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
("timestamp_ns", pa.timestamp(unit="us")),
("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")),
("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")),
("timestamptz_ns_z", pa.timestamp(unit="us", tz="UTC")),
("timestamptz_s_0000", pa.timestamp(unit="us", tz="UTC")),
])


@pytest.fixture(scope="session")
def table_schema_with_all_microseconds_timestamp_precision() -> Schema:
"""Iceberg table Schema with only date, timestamp and timestamptz values."""
return Schema(
NestedField(field_id=1, name="timestamp_s", field_type=TimestampType(), required=False),
NestedField(field_id=2, name="timestamptz_s", field_type=TimestamptzType(), required=False),
NestedField(field_id=3, name="timestamp_ms", field_type=TimestampType(), required=False),
NestedField(field_id=4, name="timestamptz_ms", field_type=TimestamptzType(), required=False),
NestedField(field_id=5, name="timestamp_us", field_type=TimestampType(), required=False),
NestedField(field_id=6, name="timestamptz_us", field_type=TimestamptzType(), required=False),
NestedField(field_id=7, name="timestamp_ns", field_type=TimestampType(), required=False),
NestedField(field_id=8, name="timestamptz_ns", field_type=TimestamptzType(), required=False),
NestedField(field_id=9, name="timestamptz_us_etc_utc", field_type=TimestamptzType(), required=False),
NestedField(field_id=10, name="timestamptz_ns_z", field_type=TimestamptzType(), required=False),
NestedField(field_id=11, name="timestamptz_s_0000", field_type=TimestamptzType(), required=False),
)
1 change: 1 addition & 0 deletions tests/integration/test_add_files.py
@@ -570,6 +570,7 @@ def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_ca
assert table_schema == arrow_schema_large


@pytest.mark.integration
def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType()))

14 changes: 7 additions & 7 deletions tests/integration/test_writes/test_partitioned_writes.py
@@ -461,15 +461,15 @@ def test_append_transform_partition_verify_partitions_count(
session_catalog: Catalog,
spark: SparkSession,
arrow_table_date_timestamps: pa.Table,
arrow_table_date_timestamps_schema: Schema,
table_date_timestamps_schema: Schema,
transform: Transform[Any, Any],
expected_partitions: Set[Any],
format_version: int,
) -> None:
# Given
part_col = "timestamptz"
identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}"
nested_field = arrow_table_date_timestamps_schema.find_field(part_col)
nested_field = table_date_timestamps_schema.find_field(part_col)
partition_spec = PartitionSpec(
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col),
)
@@ -481,7 +481,7 @@
properties={"format-version": str(format_version)},
data=[arrow_table_date_timestamps],
partition_spec=partition_spec,
schema=arrow_table_date_timestamps_schema,
schema=table_date_timestamps_schema,
)

# Then
@@ -510,20 +510,20 @@ def test_append_multiple_partitions(
session_catalog: Catalog,
spark: SparkSession,
arrow_table_date_timestamps: pa.Table,
arrow_table_date_timestamps_schema: Schema,
table_date_timestamps_schema: Schema,
format_version: int,
) -> None:
# Given
identifier = f"default.arrow_table_v{format_version}_with_multiple_partitions"
partition_spec = PartitionSpec(
PartitionField(
source_id=arrow_table_date_timestamps_schema.find_field("date").field_id,
source_id=table_date_timestamps_schema.find_field("date").field_id,
field_id=1001,
transform=YearTransform(),
name="date_year",
),
PartitionField(
source_id=arrow_table_date_timestamps_schema.find_field("timestamptz").field_id,
source_id=table_date_timestamps_schema.find_field("timestamptz").field_id,
field_id=1000,
transform=HourTransform(),
name="timestamptz_hour",
@@ -537,7 +537,7 @@
properties={"format-version": str(format_version)},
data=[arrow_table_date_timestamps],
partition_spec=partition_spec,
schema=arrow_table_date_timestamps_schema,
schema=table_date_timestamps_schema,
)

# Then
(Diffs for the remaining two changed files were not rendered on this page.)
