support PyArrow timestamptz with Etc/UTC #910

Merged: 9 commits, Jul 12, 2024. Changes from all commits are shown below.
51 changes: 35 additions & 16 deletions pyiceberg/io/pyarrow.py
@@ -174,6 +174,7 @@
MAP_KEY_NAME = "key"
MAP_VALUE_NAME = "value"
DOC = "doc"
UTC_ALIASES = {"UTC", "+00:00", "Etc/UTC", "Z"}

T = TypeVar("T")

@@ -937,7 +938,7 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
else:
raise TypeError(f"Unsupported precision for timestamp type: {primitive.unit}")

if primitive.tz == "UTC" or primitive.tz == "+00:00":
if primitive.tz in UTC_ALIASES:
return TimestamptzType()
elif primitive.tz is None:
return TimestampType()
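
As an illustration only (a sketch, not code from this PR), the alias set above lets every spelling of UTC that PyArrow may attach to a timestamp type resolve to an Iceberg timestamptz, while a tz-naive type stays a plain timestamp; the `classify` helper below is hypothetical:

```python
import pyarrow as pa

UTC_ALIASES = {"UTC", "+00:00", "Etc/UTC", "Z"}

def classify(dtype: pa.TimestampType) -> str:
    """Hypothetical helper mirroring the check above."""
    if dtype.tz in UTC_ALIASES:
        return "timestamptz"
    elif dtype.tz is None:
        return "timestamp"
    raise TypeError(f"Unsupported timezone: {dtype.tz}")

for tz in ("UTC", "Etc/UTC", "Z", "+00:00", None):
    print(tz, "->", classify(pa.timestamp("us", tz=tz)))
# All four UTC aliases resolve to "timestamptz"; tz=None resolves to "timestamp".
```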
@@ -1073,7 +1074,7 @@ def _task_to_record_batches(
arrow_table = pa.Table.from_batches([batch])
arrow_table = arrow_table.filter(pyarrow_filter)
batch = arrow_table.to_batches()[0]
yield to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True)
yield _to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True)
current_index += len(batch)


@@ -1278,7 +1279,7 @@ def project_batches(
total_row_count += len(batch)


def to_requested_schema(
def _to_requested_schema(
requested_schema: Schema,
file_schema: Schema,
batch: pa.RecordBatch,
@@ -1296,31 +1297,49 @@


class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
file_schema: Schema
_file_schema: Schema
_include_field_ids: bool
_downcast_ns_timestamp_to_us: bool

def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False, include_field_ids: bool = False) -> None:
self.file_schema = file_schema
self._file_schema = file_schema
self._include_field_ids = include_field_ids
self.downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us

def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
file_field = self.file_schema.find_field(field.field_id)
file_field = self._file_schema.find_field(field.field_id)

if field.field_type.is_primitive:
if field.field_type != file_field.field_type:
return values.cast(
schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids)
)
elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)) != values.type:
# Downcasting of nanoseconds to microseconds
if (
pa.types.is_timestamp(target_type)
and target_type.unit == "us"
and pa.types.is_timestamp(values.type)
and values.type.unit == "ns"
):
return values.cast(target_type, safe=False)
if field.field_type == TimestampType():
[Review comment by the PR author]: stricter and clearer field_type-driven checks

# Downcasting of nanoseconds to microseconds
if (
pa.types.is_timestamp(target_type)
and not target_type.tz
and pa.types.is_timestamp(values.type)
and not values.type.tz
):
if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us:
return values.cast(target_type, safe=False)
elif target_type.unit == "us" and values.type.unit in {"s", "ms"}:
return values.cast(target_type)
raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}")
elif field.field_type == TimestamptzType():
if (
pa.types.is_timestamp(target_type)
and target_type.tz == "UTC"
and pa.types.is_timestamp(values.type)
and values.type.tz in UTC_ALIASES
):
if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us:
return values.cast(target_type, safe=False)
elif target_type.unit == "us" and values.type.unit in {"s", "ms", "us"}:
return values.cast(target_type)
raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}")
return values
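
A rough sketch of the cast this timestamptz branch performs when the Arrow data arrives as Etc/UTC nanoseconds; the flag and the literal alias set below stand in for the visitor's state and are illustrative, not the actual call path:

```python
import pyarrow as pa

# Values stored with an aliased UTC zone at nanosecond precision.
values = pa.array([1_688_000_000_000_000_123], type=pa.timestamp("ns", tz="Etc/UTC"))
target_type = pa.timestamp("us", tz="UTC")  # what TimestamptzType maps to
downcast_ns_timestamp_to_us = True          # stand-in for the visitor flag

if (
    pa.types.is_timestamp(values.type)
    and values.type.tz in {"UTC", "+00:00", "Etc/UTC", "Z"}
    and target_type.tz == "UTC"
):
    if values.type.unit == "ns" and downcast_ns_timestamp_to_us:
        values = values.cast(target_type, safe=False)  # drops sub-microsecond digits
    elif values.type.unit in {"s", "ms", "us"}:
        values = values.cast(target_type)

print(values.type)  # timestamp[us, tz=UTC]
```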

def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
@@ -1971,7 +1990,7 @@ def write_parquet(task: WriteTask) -> DataFile:

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
batches = [
to_requested_schema(
_to_requested_schema(
requested_schema=file_schema,
file_schema=table_schema,
batch=batch,
8 changes: 0 additions & 8 deletions pyiceberg/table/__init__.py
@@ -484,10 +484,6 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
_check_schema_compatible(
self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
# cast if the two schemas are compatible but not equal
table_arrow_schema = self._table.schema().as_arrow()
if table_arrow_schema != df.schema:
df = df.cast(table_arrow_schema)
[Review comment by the PR author]: Removed this cast; _to_requested_schema should be responsible for casting the types to their desired schema, instead of casting it here.
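
A minimal sketch of that idea, with hypothetical names and a single-column batch: rather than casting the whole input table up front, each record batch is recast column by column to the Arrow type the Iceberg field maps to while writing, which also covers inputs such as second-precision or Etc/UTC timestamps:

```python
from datetime import datetime, timezone

import pyarrow as pa

# Input batch: seconds precision, Etc/UTC zone.
ts = pa.array(
    [datetime(2023, 1, 1, 19, 25, tzinfo=timezone.utc), None],
    type=pa.timestamp("s", tz="Etc/UTC"),
)
batch = pa.RecordBatch.from_arrays([ts], names=["ts"])

# Arrow type that the table's TimestamptzType field maps to.
table_arrow_type = pa.timestamp("us", tz="UTC")

# Per-column recast during the write, rather than df.cast() before it.
recast = pa.RecordBatch.from_arrays(
    [batch.column(0).cast(table_arrow_type)], names=["ts"]
)
print(recast.schema)  # ts: timestamp[us, tz=UTC]
```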


manifest_merge_enabled = PropertyUtil.property_as_bool(
self.table_metadata.properties,
@@ -545,10 +541,6 @@ def overwrite(
_check_schema_compatible(
self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
# cast if the two schemas are compatible but not equal
table_arrow_schema = self._table.schema().as_arrow()
if table_arrow_schema != df.schema:
df = df.cast(table_arrow_schema)

self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)

116 changes: 114 additions & 2 deletions tests/conftest.py
@@ -2382,10 +2382,122 @@ def arrow_table_date_timestamps() -> "pa.Table":


@pytest.fixture(scope="session")
def arrow_table_date_timestamps_schema() -> Schema:
"""Pyarrow table Schema with only date, timestamp and timestamptz values."""
def table_date_timestamps_schema() -> Schema:
"""Iceberg table Schema with only date, timestamp and timestamptz values."""
return Schema(
NestedField(field_id=1, name="date", field_type=DateType(), required=False),
NestedField(field_id=2, name="timestamp", field_type=TimestampType(), required=False),
NestedField(field_id=3, name="timestamptz", field_type=TimestamptzType(), required=False),
)


@pytest.fixture(scope="session")
def arrow_table_schema_with_all_timestamp_precisions() -> "pa.Schema":
"""Pyarrow Schema with all supported timestamp types."""
import pyarrow as pa

return pa.schema([
("timestamp_s", pa.timestamp(unit="s")),
("timestamptz_s", pa.timestamp(unit="s", tz="UTC")),
("timestamp_ms", pa.timestamp(unit="ms")),
("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")),
("timestamp_us", pa.timestamp(unit="us")),
("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
("timestamp_ns", pa.timestamp(unit="ns")),
("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")),
("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")),
("timestamptz_ns_z", pa.timestamp(unit="ns", tz="Z")),
("timestamptz_s_0000", pa.timestamp(unit="s", tz="+00:00")),
])


@pytest.fixture(scope="session")
def arrow_table_with_all_timestamp_precisions(arrow_table_schema_with_all_timestamp_precisions: "pa.Schema") -> "pa.Table":
"""Pyarrow table with all supported timestamp types."""
import pandas as pd
import pyarrow as pa

test_data = pd.DataFrame({
"timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
"timestamptz_s": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
"timestamptz_ms": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
"timestamptz_us": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamp_ns": [
pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6),
None,
pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7),
],
"timestamptz_ns": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamptz_us_etc_utc": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamptz_ns_z": [
pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6, tz="UTC"),
None,
pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7, tz="UTC"),
],
"timestamptz_s_0000": [
datetime(2023, 1, 1, 19, 25, 1, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 1, tzinfo=timezone.utc),
],
})
return pa.Table.from_pandas(test_data, schema=arrow_table_schema_with_all_timestamp_precisions)


@pytest.fixture(scope="session")
def arrow_table_schema_with_all_microseconds_timestamp_precisions() -> "pa.Schema":
"""Pyarrow Schema with all microseconds timestamp."""
import pyarrow as pa

return pa.schema([
("timestamp_s", pa.timestamp(unit="us")),
("timestamptz_s", pa.timestamp(unit="us", tz="UTC")),
("timestamp_ms", pa.timestamp(unit="us")),
("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")),
("timestamp_us", pa.timestamp(unit="us")),
("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
("timestamp_ns", pa.timestamp(unit="us")),
("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")),
("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")),
("timestamptz_ns_z", pa.timestamp(unit="us", tz="UTC")),
("timestamptz_s_0000", pa.timestamp(unit="us", tz="UTC")),
])


@pytest.fixture(scope="session")
def table_schema_with_all_microseconds_timestamp_precision() -> Schema:
"""Iceberg table Schema with only date, timestamp and timestamptz values."""
return Schema(
NestedField(field_id=1, name="timestamp_s", field_type=TimestampType(), required=False),
NestedField(field_id=2, name="timestamptz_s", field_type=TimestamptzType(), required=False),
NestedField(field_id=3, name="timestamp_ms", field_type=TimestampType(), required=False),
NestedField(field_id=4, name="timestamptz_ms", field_type=TimestamptzType(), required=False),
NestedField(field_id=5, name="timestamp_us", field_type=TimestampType(), required=False),
NestedField(field_id=6, name="timestamptz_us", field_type=TimestamptzType(), required=False),
NestedField(field_id=7, name="timestamp_ns", field_type=TimestampType(), required=False),
NestedField(field_id=8, name="timestamptz_ns", field_type=TimestamptzType(), required=False),
NestedField(field_id=9, name="timestamptz_us_etc_utc", field_type=TimestamptzType(), required=False),
NestedField(field_id=10, name="timestamptz_ns_z", field_type=TimestamptzType(), required=False),
NestedField(field_id=11, name="timestamptz_s_0000", field_type=TimestamptzType(), required=False),
)
1 change: 1 addition & 0 deletions tests/integration/test_add_files.py
@@ -570,6 +570,7 @@ def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_ca
assert table_schema == arrow_schema_large


@pytest.mark.integration
def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType()))

14 changes: 7 additions & 7 deletions tests/integration/test_writes/test_partitioned_writes.py
@@ -461,15 +461,15 @@ def test_append_transform_partition_verify_partitions_count(
session_catalog: Catalog,
spark: SparkSession,
arrow_table_date_timestamps: pa.Table,
arrow_table_date_timestamps_schema: Schema,
table_date_timestamps_schema: Schema,
transform: Transform[Any, Any],
expected_partitions: Set[Any],
format_version: int,
) -> None:
# Given
part_col = "timestamptz"
identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}"
nested_field = arrow_table_date_timestamps_schema.find_field(part_col)
nested_field = table_date_timestamps_schema.find_field(part_col)
partition_spec = PartitionSpec(
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col),
)
@@ -481,7 +481,7 @@
properties={"format-version": str(format_version)},
data=[arrow_table_date_timestamps],
partition_spec=partition_spec,
schema=arrow_table_date_timestamps_schema,
schema=table_date_timestamps_schema,
)

# Then
@@ -510,20 +510,20 @@ def test_append_multiple_partitions(
session_catalog: Catalog,
spark: SparkSession,
arrow_table_date_timestamps: pa.Table,
arrow_table_date_timestamps_schema: Schema,
table_date_timestamps_schema: Schema,
format_version: int,
) -> None:
# Given
identifier = f"default.arrow_table_v{format_version}_with_multiple_partitions"
partition_spec = PartitionSpec(
PartitionField(
source_id=arrow_table_date_timestamps_schema.find_field("date").field_id,
source_id=table_date_timestamps_schema.find_field("date").field_id,
field_id=1001,
transform=YearTransform(),
name="date_year",
),
PartitionField(
source_id=arrow_table_date_timestamps_schema.find_field("timestamptz").field_id,
source_id=table_date_timestamps_schema.find_field("timestamptz").field_id,
field_id=1000,
transform=HourTransform(),
name="timestamptz_hour",
Expand All @@ -537,7 +537,7 @@ def test_append_multiple_partitions(
properties={"format-version": str(format_version)},
data=[arrow_table_date_timestamps],
partition_spec=partition_spec,
schema=arrow_table_date_timestamps_schema,
schema=table_date_timestamps_schema,
)

# Then