diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 56f2242514..7016316d93 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -174,6 +174,7 @@
 MAP_KEY_NAME = "key"
 MAP_VALUE_NAME = "value"
 DOC = "doc"
+UTC_ALIASES = {"UTC", "+00:00", "Etc/UTC", "Z"}
 
 T = TypeVar("T")
 
@@ -937,7 +938,7 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
             else:
                 raise TypeError(f"Unsupported precision for timestamp type: {primitive.unit}")
 
-            if primitive.tz == "UTC" or primitive.tz == "+00:00":
+            if primitive.tz in UTC_ALIASES:
                 return TimestamptzType()
             elif primitive.tz is None:
                 return TimestampType()
@@ -1073,7 +1074,7 @@ def _task_to_record_batches(
                     arrow_table = pa.Table.from_batches([batch])
                     arrow_table = arrow_table.filter(pyarrow_filter)
                     batch = arrow_table.to_batches()[0]
-            yield to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True)
+            yield _to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True)
             current_index += len(batch)
@@ -1278,7 +1279,7 @@ def project_batches(
             total_row_count += len(batch)
 
 
-def to_requested_schema(
+def _to_requested_schema(
     requested_schema: Schema,
     file_schema: Schema,
     batch: pa.RecordBatch,
@@ -1296,16 +1297,17 @@
 class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
-    file_schema: Schema
+    _file_schema: Schema
     _include_field_ids: bool
+    _downcast_ns_timestamp_to_us: bool
 
     def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False, include_field_ids: bool = False) -> None:
-        self.file_schema = file_schema
+        self._file_schema = file_schema
         self._include_field_ids = include_field_ids
-        self.downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
+        self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
 
     def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
-        file_field = self.file_schema.find_field(field.field_id)
+        file_field = self._file_schema.find_field(field.field_id)
 
         if field.field_type.is_primitive:
             if field.field_type != file_field.field_type:
@@ -1313,14 +1315,31 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
                     schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids)
                 )
             elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)) != values.type:
-                # Downcasting of nanoseconds to microseconds
-                if (
-                    pa.types.is_timestamp(target_type)
-                    and target_type.unit == "us"
-                    and pa.types.is_timestamp(values.type)
-                    and values.type.unit == "ns"
-                ):
-                    return values.cast(target_type, safe=False)
+                if field.field_type == TimestampType():
+                    # Downcasting of nanoseconds to microseconds
+                    if (
+                        pa.types.is_timestamp(target_type)
+                        and not target_type.tz
+                        and pa.types.is_timestamp(values.type)
+                        and not values.type.tz
+                    ):
+                        if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us:
+                            return values.cast(target_type, safe=False)
+                        elif target_type.unit == "us" and values.type.unit in {"s", "ms"}:
+                            return values.cast(target_type)
+                    raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}")
+                elif field.field_type == TimestamptzType():
+                    if (
+                        pa.types.is_timestamp(target_type)
+                        and target_type.tz == "UTC"
+                        and pa.types.is_timestamp(values.type)
+                        and values.type.tz in UTC_ALIASES
+                    ):
+                        if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us:
+                            return values.cast(target_type, safe=False)
+                        elif target_type.unit == "us" and values.type.unit in {"s", "ms", "us"}:
+                            return values.cast(target_type)
+                    raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}")
         return values
 
     def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
@@ -1971,7 +1990,7 @@ def write_parquet(task: WriteTask) -> DataFile:
         downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
         batches = [
-            to_requested_schema(
+            _to_requested_schema(
                 requested_schema=file_schema,
                 file_schema=table_schema,
                 batch=batch,
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 62440c4773..b43dc3206b 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -484,10 +484,6 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
         _check_schema_compatible(
             self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self._table.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)
 
         manifest_merge_enabled = PropertyUtil.property_as_bool(
             self.table_metadata.properties,
@@ -545,10 +541,6 @@ def overwrite(
         _check_schema_compatible(
             self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self._table.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)
 
         self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)
diff --git a/tests/conftest.py b/tests/conftest.py
index 95e1128af6..6b1a2b43e2 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2382,10 +2382,122 @@ def arrow_table_date_timestamps() -> "pa.Table":
 
 
 @pytest.fixture(scope="session")
-def arrow_table_date_timestamps_schema() -> Schema:
-    """Pyarrow table Schema with only date, timestamp and timestamptz values."""
+def table_date_timestamps_schema() -> Schema:
+    """Iceberg table Schema with only date, timestamp and timestamptz values."""
     return Schema(
         NestedField(field_id=1, name="date", field_type=DateType(), required=False),
         NestedField(field_id=2, name="timestamp", field_type=TimestampType(), required=False),
         NestedField(field_id=3, name="timestamptz", field_type=TimestamptzType(), required=False),
     )
+
+
+@pytest.fixture(scope="session")
+def arrow_table_schema_with_all_timestamp_precisions() -> "pa.Schema":
+    """Pyarrow Schema with all supported timestamp types."""
+    import pyarrow as pa
+
+    return pa.schema([
+        ("timestamp_s", pa.timestamp(unit="s")),
+        ("timestamptz_s", pa.timestamp(unit="s", tz="UTC")),
+        ("timestamp_ms", pa.timestamp(unit="ms")),
+        ("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")),
+        ("timestamp_us", pa.timestamp(unit="us")),
+        ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamp_ns", pa.timestamp(unit="ns")),
+        ("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")),
+        ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")),
+        ("timestamptz_ns_z", pa.timestamp(unit="ns", tz="Z")),
+        ("timestamptz_s_0000", pa.timestamp(unit="s", tz="+00:00")),
+    ])
+
+
+@pytest.fixture(scope="session")
+def arrow_table_with_all_timestamp_precisions(arrow_table_schema_with_all_timestamp_precisions: "pa.Schema") -> "pa.Table":
+    """Pyarrow table with all supported timestamp types."""
+    import pandas as pd
+    import pyarrow as pa
+
+    test_data = pd.DataFrame({
+        "timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
+        "timestamptz_s": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+        "timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
+        "timestamptz_ms": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+        "timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
+        "timestamptz_us": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+        "timestamp_ns": [
+            pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6),
+            None,
+            pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7),
+        ],
+        "timestamptz_ns": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+        "timestamptz_us_etc_utc": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+        "timestamptz_ns_z": [
+            pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6, tz="UTC"),
+            None,
+            pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7, tz="UTC"),
+        ],
+        "timestamptz_s_0000": [
+            datetime(2023, 1, 1, 19, 25, 1, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 1, tzinfo=timezone.utc),
+        ],
+    })
+    return pa.Table.from_pandas(test_data, schema=arrow_table_schema_with_all_timestamp_precisions)
+
+
+@pytest.fixture(scope="session")
+def arrow_table_schema_with_all_microseconds_timestamp_precisions() -> "pa.Schema":
+    """Pyarrow Schema with all timestamp columns at microsecond precision."""
+    import pyarrow as pa
+
+    return pa.schema([
+        ("timestamp_s", pa.timestamp(unit="us")),
+        ("timestamptz_s", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamp_ms", pa.timestamp(unit="us")),
+        ("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamp_us", pa.timestamp(unit="us")),
+        ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamp_ns", pa.timestamp(unit="us")),
+        ("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamptz_ns_z", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamptz_s_0000", pa.timestamp(unit="us", tz="UTC")),
+    ])
+
+
+@pytest.fixture(scope="session")
+def table_schema_with_all_microseconds_timestamp_precision() -> Schema:
+    """Iceberg table Schema with timestamp and timestamptz columns at microsecond precision."""
+    return Schema(
+        NestedField(field_id=1, name="timestamp_s", field_type=TimestampType(), required=False),
+        NestedField(field_id=2, name="timestamptz_s", field_type=TimestamptzType(), required=False),
+        NestedField(field_id=3, name="timestamp_ms", field_type=TimestampType(), required=False),
+        NestedField(field_id=4, name="timestamptz_ms", field_type=TimestamptzType(), required=False),
+        NestedField(field_id=5, name="timestamp_us", field_type=TimestampType(), required=False),
+        NestedField(field_id=6, name="timestamptz_us", field_type=TimestamptzType(), required=False),
+        NestedField(field_id=7, name="timestamp_ns", field_type=TimestampType(), required=False),
+        NestedField(field_id=8, name="timestamptz_ns", field_type=TimestamptzType(), required=False),
+        NestedField(field_id=9, name="timestamptz_us_etc_utc", field_type=TimestamptzType(), required=False),
+        NestedField(field_id=10, name="timestamptz_ns_z", field_type=TimestamptzType(), required=False),
+        NestedField(field_id=11, name="timestamptz_s_0000", field_type=TimestamptzType(), required=False),
+    )
diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py
index 984c7d1175..b8fd6d0926 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -570,6 +570,7 @@ def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_ca
     assert table_schema == arrow_schema_large
 
 
+@pytest.mark.integration
 def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
     nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType()))
 
diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py
index 12da9c928b..b199f00210 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -461,7 +461,7 @@ def test_append_transform_partition_verify_partitions_count(
     session_catalog: Catalog,
     spark: SparkSession,
     arrow_table_date_timestamps: pa.Table,
-    arrow_table_date_timestamps_schema: Schema,
+    table_date_timestamps_schema: Schema,
     transform: Transform[Any, Any],
     expected_partitions: Set[Any],
     format_version: int,
@@ -469,7 +469,7 @@ def test_append_transform_partition_verify_partitions_count(
 ) -> None:
     # Given
     part_col = "timestamptz"
     identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}"
-    nested_field = arrow_table_date_timestamps_schema.find_field(part_col)
+    nested_field = table_date_timestamps_schema.find_field(part_col)
     partition_spec = PartitionSpec(
         PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col),
     )
@@ -481,7 +481,7 @@ def test_append_transform_partition_verify_partitions_count(
         properties={"format-version": str(format_version)},
         data=[arrow_table_date_timestamps],
         partition_spec=partition_spec,
-        schema=arrow_table_date_timestamps_schema,
+        schema=table_date_timestamps_schema,
     )
 
     # Then
@@ -510,20 +510,20 @@ def test_append_multiple_partitions(
     session_catalog: Catalog,
     spark: SparkSession,
     arrow_table_date_timestamps: pa.Table,
-    arrow_table_date_timestamps_schema: Schema,
+    table_date_timestamps_schema: Schema,
     format_version: int,
 ) -> None:
     # Given
     identifier = f"default.arrow_table_v{format_version}_with_multiple_partitions"
     partition_spec = PartitionSpec(
         PartitionField(
-            source_id=arrow_table_date_timestamps_schema.find_field("date").field_id,
+            source_id=table_date_timestamps_schema.find_field("date").field_id,
             field_id=1001,
             transform=YearTransform(),
             name="date_year",
         ),
         PartitionField(
-            source_id=arrow_table_date_timestamps_schema.find_field("timestamptz").field_id,
+            source_id=table_date_timestamps_schema.find_field("timestamptz").field_id,
             field_id=1000,
             transform=HourTransform(),
             name="timestamptz_hour",
@@ -537,7 +537,7 @@ def test_append_multiple_partitions(
         properties={"format-version": str(format_version)},
         data=[arrow_table_date_timestamps],
         partition_spec=partition_spec,
-        schema=arrow_table_date_timestamps_schema,
+        schema=table_date_timestamps_schema,
     )
 
     # Then
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index af626718f7..41bc6fb5bf 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -18,11 +18,12 @@
 import math
 import os
 import time
-from datetime import date, datetime, timezone
+from datetime import date, datetime
 from pathlib import Path
 from typing import Any, Dict
 from urllib.parse import urlparse
 
+import pandas as pd
 import pyarrow as pa
 import pyarrow.parquet as pq
 import pytest
@@ -977,69 +978,43 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null
 
 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
-def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: Catalog, format_version: int) -> None:
+def test_write_all_timestamp_precision(
+    mocker: MockerFixture,
+    spark: SparkSession,
+    session_catalog: Catalog,
+    format_version: int,
+    arrow_table_schema_with_all_timestamp_precisions: pa.Schema,
+    arrow_table_with_all_timestamp_precisions: pa.Table,
+    arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema,
+) -> None:
     identifier = "default.table_all_timestamp_precision"
-    arrow_table_schema_with_all_timestamp_precisions = pa.schema([
-        ("timestamp_s", pa.timestamp(unit="s")),
-        ("timestamptz_s", pa.timestamp(unit="s", tz="UTC")),
-        ("timestamp_ms", pa.timestamp(unit="ms")),
-        ("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")),
-        ("timestamp_us", pa.timestamp(unit="us")),
-        ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
-        ("timestamp_ns", pa.timestamp(unit="ns")),
-        ("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")),
-    ])
-    TEST_DATA_WITH_NULL = {
-        "timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
-        "timestamptz_s": [
-            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
-            None,
-            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
-        ],
-        "timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
-        "timestamptz_ms": [
-            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
-            None,
-            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
-        ],
-        "timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
-        "timestamptz_us": [
-            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
-            None,
-            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
-        ],
-        "timestamp_ns": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
-        "timestamptz_ns": [
-            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
-            None,
-            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
-        ],
-    }
-    input_arrow_table = pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=arrow_table_schema_with_all_timestamp_precisions)
     mocker.patch.dict(os.environ, values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"})
 
     tbl = _create_table(
         session_catalog,
         identifier,
         {"format-version": format_version},
-        data=[input_arrow_table],
+        data=[arrow_table_with_all_timestamp_precisions],
         schema=arrow_table_schema_with_all_timestamp_precisions,
     )
 
-    tbl.overwrite(input_arrow_table)
+    tbl.overwrite(arrow_table_with_all_timestamp_precisions)
     written_arrow_table = tbl.scan().to_arrow()
-    expected_schema_in_all_us = pa.schema([
-        ("timestamp_s", pa.timestamp(unit="us")),
-        ("timestamptz_s", pa.timestamp(unit="us", tz="UTC")),
-        ("timestamp_ms", pa.timestamp(unit="us")),
("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")), - ("timestamp_us", pa.timestamp(unit="us")), - ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")), - ("timestamp_ns", pa.timestamp(unit="us")), - ("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")), - ]) - assert written_arrow_table.schema == expected_schema_in_all_us - assert written_arrow_table == input_arrow_table.cast(expected_schema_in_all_us) + assert written_arrow_table.schema == arrow_table_schema_with_all_microseconds_timestamp_precisions + assert written_arrow_table == arrow_table_with_all_timestamp_precisions.cast( + arrow_table_schema_with_all_microseconds_timestamp_precisions, safe=False + ) + lhs = spark.table(f"{identifier}").toPandas() + rhs = written_arrow_table.to_pandas() + + for column in written_arrow_table.column_names: + for left, right in zip(lhs[column].to_list(), rhs[column].to_list()): + if pd.isnull(left): + assert pd.isnull(right) + else: + # Check only upto microsecond precision since Spark loaded dtype is timezone unaware + # and supports upto microsecond precision + assert left.timestamp() == right.timestamp(), f"Difference in column {column}: {left} != {right}" @pytest.mark.integration diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 326eeff195..37198b7edb 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -65,6 +65,7 @@ _determine_partitions, _primitive_to_physical, _read_deletes, + _to_requested_schema, bin_pack_arrow_table, expression_to_pyarrow, project_table, @@ -1889,3 +1890,35 @@ def test_identity_partition_on_multi_columns() -> None: ("n_legs", "ascending"), ("animal", "ascending"), ]) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", "ascending"), ("animal", "ascending")]) + + +def test__to_requested_schema_timestamps( + arrow_table_schema_with_all_timestamp_precisions: pa.Schema, + arrow_table_with_all_timestamp_precisions: pa.Table, + arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema, + table_schema_with_all_microseconds_timestamp_precision: Schema, +) -> None: + requested_schema = table_schema_with_all_microseconds_timestamp_precision + file_schema = requested_schema + batch = arrow_table_with_all_timestamp_precisions.to_batches()[0] + result = _to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=True, include_field_ids=False) + + expected = arrow_table_with_all_timestamp_precisions.cast( + arrow_table_schema_with_all_microseconds_timestamp_precisions, safe=False + ).to_batches()[0] + assert result == expected + + +def test__to_requested_schema_timestamps_without_downcast_raises_exception( + arrow_table_schema_with_all_timestamp_precisions: pa.Schema, + arrow_table_with_all_timestamp_precisions: pa.Table, + arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema, + table_schema_with_all_microseconds_timestamp_precision: Schema, +) -> None: + requested_schema = table_schema_with_all_microseconds_timestamp_precision + file_schema = requested_schema + batch = arrow_table_with_all_timestamp_precisions.to_batches()[0] + with pytest.raises(ValueError) as exc_info: + _to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=False, include_field_ids=False) + + assert "Unsupported schema projection from timestamp[ns] to timestamp[us]" in str(exc_info.value)