From dceedfac4ec072ee4da99bf02dc93c1d27be45a9 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Thu, 11 Jul 2024 20:32:14 -0400 Subject: [PATCH] Check if schema is compatible in `add_files` API (#907) Co-authored-by: Fokko Driesprong --- pyiceberg/io/pyarrow.py | 45 ++++++++++++++ pyiceberg/table/__init__.py | 62 ++++--------------- tests/integration/test_add_files.py | 85 +++++++++++++++++++------- tests/io/test_pyarrow.py | 91 ++++++++++++++++++++++++++++ tests/table/test_init.py | 92 ----------------------------- 5 files changed, 211 insertions(+), 164 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 142e9e5f08..56f2242514 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -2032,6 +2032,49 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[ return bin_packed_record_batches +def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> None: + """ + Check if the `table_schema` is compatible with `other_schema`. + + Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type. + + Raises: + ValueError: If the schemas are not compatible. + """ + name_mapping = table_schema.name_mapping + try: + task_schema = pyarrow_to_schema( + other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us + ) + except ValueError as e: + other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us) + additional_names = set(other_schema.column_names) - set(table_schema.column_names) + raise ValueError( + f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)." + ) from e + + if table_schema.as_struct() != task_schema.as_struct(): + from rich.console import Console + from rich.table import Table as RichTable + + console = Console(record=True) + + rich_table = RichTable(show_header=True, header_style="bold") + rich_table.add_column("") + rich_table.add_column("Table field") + rich_table.add_column("Dataframe field") + + for lhs in table_schema.fields: + try: + rhs = task_schema.find_field(lhs.field_id) + rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs)) + except ValueError: + rich_table.add_row("❌", str(lhs), "Missing") + + console.print(rich_table) + raise ValueError(f"Mismatch in fields:\n{console.export_text()}") + + def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]: for file_path in file_paths: input_file = io.new_input(file_path) @@ -2043,6 +2086,8 @@ def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_ f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids" ) schema = table_metadata.schema() + _check_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema()) + statistics = data_file_statistics_from_parquet_metadata( parquet_metadata=parquet_metadata, stats_columns=compute_statistics_plan(schema, table_metadata.properties), diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 5342d37053..62440c4773 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -73,7 +73,7 @@ manifest_evaluator, ) from pyiceberg.io import FileIO, OutputFile, load_file_io -from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table +from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files, expression_to_pyarrow, project_table from pyiceberg.manifest import ( POSITIONAL_DELETE_SCHEMA, DataFile, @@ -166,54 +166,8 @@ ALWAYS_TRUE = AlwaysTrue() TABLE_ROOT_ID = -1 -DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write" _JAVA_LONG_MAX = 9223372036854775807 - - -def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") -> None: - """ - Check if the `table_schema` is compatible with `other_schema`. - - Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type. - - Raises: - ValueError: If the schemas are not compatible. - """ - from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema - - downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False - name_mapping = table_schema.name_mapping - try: - task_schema = pyarrow_to_schema( - other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us - ) - except ValueError as e: - other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us) - additional_names = set(other_schema.column_names) - set(table_schema.column_names) - raise ValueError( - f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)." - ) from e - - if table_schema.as_struct() != task_schema.as_struct(): - from rich.console import Console - from rich.table import Table as RichTable - - console = Console(record=True) - - rich_table = RichTable(show_header=True, header_style="bold") - rich_table.add_column("") - rich_table.add_column("Table field") - rich_table.add_column("Dataframe field") - - for lhs in table_schema.fields: - try: - rhs = task_schema.find_field(lhs.field_id) - rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs)) - except ValueError: - rich_table.add_row("❌", str(lhs), "Missing") - - console.print(rich_table) - raise ValueError(f"Mismatch in fields:\n{console.export_text()}") +DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write" class TableProperties: @@ -526,8 +480,10 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) raise ValueError( f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}." ) - - _check_schema_compatible(self._table.schema(), other_schema=df.schema) + downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False + _check_schema_compatible( + self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us + ) # cast if the two schemas are compatible but not equal table_arrow_schema = self._table.schema().as_arrow() if table_arrow_schema != df.schema: @@ -585,8 +541,10 @@ def overwrite( raise ValueError( f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}." ) - - _check_schema_compatible(self._table.schema(), other_schema=df.schema) + downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False + _check_schema_compatible( + self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us + ) # cast if the two schemas are compatible but not equal table_arrow_schema = self._table.schema().as_arrow() if table_arrow_schema != df.schema: diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py index 825d17e924..984c7d1175 100644 --- a/tests/integration/test_add_files.py +++ b/tests/integration/test_add_files.py @@ -17,6 +17,7 @@ # pylint:disable=redefined-outer-name import os +import re from datetime import date from typing import Iterator @@ -463,6 +464,57 @@ def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Cat assert summary["snapshot_prop_a"] == "test_prop_a" +@pytest.mark.integration +def test_add_files_fails_on_schema_mismatch(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None: + identifier = f"default.table_schema_mismatch_fails_v{format_version}" + + tbl = _create_table(session_catalog, identifier, format_version) + WRONG_SCHEMA = pa.schema([ + ("foo", pa.bool_()), + ("bar", pa.string()), + ("baz", pa.string()), # should be integer + ("qux", pa.date32()), + ]) + file_path = f"s3://warehouse/default/table_schema_mismatch_fails/v{format_version}/test.parquet" + # write parquet files + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=WRONG_SCHEMA) as writer: + writer.write_table( + pa.Table.from_pylist( + [ + { + "foo": True, + "bar": "bar_string", + "baz": "123", + "qux": date(2024, 3, 7), + }, + { + "foo": True, + "bar": "bar_string", + "baz": "124", + "qux": date(2024, 3, 7), + }, + ], + schema=WRONG_SCHEMA, + ) + ) + + expected = """Mismatch in fields: +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ✅ │ 1: foo: optional boolean │ 1: foo: optional boolean │ +| ✅ │ 2: bar: optional string │ 2: bar: optional string │ +│ ❌ │ 3: baz: optional int │ 3: baz: optional string │ +│ ✅ │ 4: qux: optional date │ 4: qux: optional date │ +└────┴──────────────────────────┴──────────────────────────┘ +""" + + with pytest.raises(ValueError, match=expected): + tbl.add_files(file_paths=[file_path]) + + @pytest.mark.integration def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None: identifier = f"default.unpartitioned_with_large_types{format_version}" @@ -518,7 +570,7 @@ def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_ca assert table_schema == arrow_schema_large -def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None: +def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None: nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType())) nanoseconds_schema = pa.schema([ @@ -549,25 +601,18 @@ def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_versi partition_spec=PartitionSpec(), ) - file_paths = [f"s3://warehouse/default/test_timestamp_tz/v{format_version}/test-{i}.parquet" for i in range(5)] + file_path = f"s3://warehouse/default/test_timestamp_tz/v{format_version}/test.parquet" # write parquet files - for file_path in file_paths: - fo = tbl.io.new_output(file_path) - with fo.create(overwrite=True) as fos: - with pq.ParquetWriter(fos, schema=nanoseconds_schema) as writer: - writer.write_table(arrow_table) + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=nanoseconds_schema) as writer: + writer.write_table(arrow_table) # add the parquet files as data files - tbl.add_files(file_paths=file_paths) - - assert tbl.scan().to_arrow() == pa.concat_tables( - [ - arrow_table.cast( - pa.schema([ - ("quux", pa.timestamp("us", tz="UTC")), - ]), - safe=False, - ) - ] - * 5 - ) + with pytest.raises( + TypeError, + match=re.escape( + "Iceberg does not yet support 'ns' timestamp precision. Use 'downcast-ns-timestamp-to-us-on-write' configuration property to automatically downcast 'ns' to 'us' on write." + ), + ): + tbl.add_files(file_paths=[file_path]) diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 1b9468993c..326eeff195 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -60,6 +60,7 @@ PyArrowFile, PyArrowFileIO, StatsAggregator, + _check_schema_compatible, _ConvertToArrowSchema, _determine_partitions, _primitive_to_physical, @@ -1722,6 +1723,96 @@ def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None: assert len(list(bin_packed)) == 5 +def test_schema_mismatch_type(table_schema_simple: Schema) -> None: + other_schema = pa.schema(( + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.decimal128(18, 6), nullable=False), + pa.field("baz", pa.bool_(), nullable=True), + )) + + expected = r"""Mismatch in fields: +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ +│ ❌ │ 2: bar: required int │ 2: bar: required decimal\(18, 6\) │ +│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ +└────┴──────────────────────────┴─────────────────────────────────┘ +""" + + with pytest.raises(ValueError, match=expected): + _check_schema_compatible(table_schema_simple, other_schema) + + +def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None: + other_schema = pa.schema(( + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=True), + pa.field("baz", pa.bool_(), nullable=True), + )) + + expected = """Mismatch in fields: +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ +│ ❌ │ 2: bar: required int │ 2: bar: optional int │ +│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ +└────┴──────────────────────────┴──────────────────────────┘ +""" + + with pytest.raises(ValueError, match=expected): + _check_schema_compatible(table_schema_simple, other_schema) + + +def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: + other_schema = pa.schema(( + pa.field("foo", pa.string(), nullable=True), + pa.field("baz", pa.bool_(), nullable=True), + )) + + expected = """Mismatch in fields: +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ +│ ❌ │ 2: bar: required int │ Missing │ +│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ +└────┴──────────────────────────┴──────────────────────────┘ +""" + + with pytest.raises(ValueError, match=expected): + _check_schema_compatible(table_schema_simple, other_schema) + + +def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None: + other_schema = pa.schema(( + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=True), + pa.field("baz", pa.bool_(), nullable=True), + pa.field("new_field", pa.date32(), nullable=True), + )) + + expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)." + + with pytest.raises(ValueError, match=expected): + _check_schema_compatible(table_schema_simple, other_schema) + + +def test_schema_downcast(table_schema_simple: Schema) -> None: + # large_string type is compatible with string type + other_schema = pa.schema(( + pa.field("foo", pa.large_string(), nullable=True), + pa.field("bar", pa.int32(), nullable=False), + pa.field("baz", pa.bool_(), nullable=True), + )) + + try: + _check_schema_compatible(table_schema_simple, other_schema) + except Exception: + pytest.fail("Unexpected Exception raised when calling `_check_schema`") + + def test_partition_for_demo() -> None: test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())]) test_schema = Schema( diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 31a8bbf444..7a5ea86d7a 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -19,7 +19,6 @@ from copy import copy from typing import Any, Dict -import pyarrow as pa import pytest from pydantic import ValidationError from sortedcontainers import SortedList @@ -63,7 +62,6 @@ TableIdentifier, UpdateSchema, _apply_table_update, - _check_schema_compatible, _match_deletes_to_data_file, _TableMetadataUpdateContext, update_table_metadata, @@ -1122,96 +1120,6 @@ def test_correct_schema() -> None: assert "Snapshot not found: -1" in str(exc_info.value) -def test_schema_mismatch_type(table_schema_simple: Schema) -> None: - other_schema = pa.schema(( - pa.field("foo", pa.string(), nullable=True), - pa.field("bar", pa.decimal128(18, 6), nullable=False), - pa.field("baz", pa.bool_(), nullable=True), - )) - - expected = r"""Mismatch in fields: -┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ -┃ ┃ Table field ┃ Dataframe field ┃ -┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ -│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ -│ ❌ │ 2: bar: required int │ 2: bar: required decimal\(18, 6\) │ -│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ -└────┴──────────────────────────┴─────────────────────────────────┘ -""" - - with pytest.raises(ValueError, match=expected): - _check_schema_compatible(table_schema_simple, other_schema) - - -def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None: - other_schema = pa.schema(( - pa.field("foo", pa.string(), nullable=True), - pa.field("bar", pa.int32(), nullable=True), - pa.field("baz", pa.bool_(), nullable=True), - )) - - expected = """Mismatch in fields: -┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ -┃ ┃ Table field ┃ Dataframe field ┃ -┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ -│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ -│ ❌ │ 2: bar: required int │ 2: bar: optional int │ -│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ -└────┴──────────────────────────┴──────────────────────────┘ -""" - - with pytest.raises(ValueError, match=expected): - _check_schema_compatible(table_schema_simple, other_schema) - - -def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: - other_schema = pa.schema(( - pa.field("foo", pa.string(), nullable=True), - pa.field("baz", pa.bool_(), nullable=True), - )) - - expected = """Mismatch in fields: -┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ -┃ ┃ Table field ┃ Dataframe field ┃ -┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ -│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ -│ ❌ │ 2: bar: required int │ Missing │ -│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ -└────┴──────────────────────────┴──────────────────────────┘ -""" - - with pytest.raises(ValueError, match=expected): - _check_schema_compatible(table_schema_simple, other_schema) - - -def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None: - other_schema = pa.schema(( - pa.field("foo", pa.string(), nullable=True), - pa.field("bar", pa.int32(), nullable=True), - pa.field("baz", pa.bool_(), nullable=True), - pa.field("new_field", pa.date32(), nullable=True), - )) - - expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)." - - with pytest.raises(ValueError, match=expected): - _check_schema_compatible(table_schema_simple, other_schema) - - -def test_schema_downcast(table_schema_simple: Schema) -> None: - # large_string type is compatible with string type - other_schema = pa.schema(( - pa.field("foo", pa.large_string(), nullable=True), - pa.field("bar", pa.int32(), nullable=False), - pa.field("baz", pa.bool_(), nullable=True), - )) - - try: - _check_schema_compatible(table_schema_simple, other_schema) - except Exception: - pytest.fail("Unexpected Exception raised when calling `_check_schema`") - - def test_table_properties(example_table_metadata_v2: Dict[str, Any]) -> None: # metadata properties are all strings for k, v in example_table_metadata_v2["properties"].items():