Check if schema is compatible in add_files API #907

Merged 10 commits on Jul 12, 2024
45 changes: 45 additions & 0 deletions pyiceberg/io/pyarrow.py
@@ -2032,6 +2032,49 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[
return bin_packed_record_batches


def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> None:
"""
Check if the `table_schema` is compatible with `other_schema`.

Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.

Raises:
ValueError: If the schemas are not compatible.
"""
name_mapping = table_schema.name_mapping
try:
task_schema = pyarrow_to_schema(
other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
except ValueError as e:
other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
additional_names = set(other_schema.column_names) - set(table_schema.column_names)
raise ValueError(
f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
) from e

if table_schema.as_struct() != task_schema.as_struct():
from rich.console import Console
from rich.table import Table as RichTable

console = Console(record=True)

rich_table = RichTable(show_header=True, header_style="bold")
rich_table.add_column("")
rich_table.add_column("Table field")
rich_table.add_column("Dataframe field")

for lhs in table_schema.fields:
try:
rhs = task_schema.find_field(lhs.field_id)
rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
except ValueError:
rich_table.add_row("❌", str(lhs), "Missing")

console.print(rich_table)
raise ValueError(f"Mismatch in fields:\n{console.export_text()}")


def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]:
for file_path in file_paths:
input_file = io.new_input(file_path)
@@ -2043,6 +2086,8 @@ def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_
f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids"
)
schema = table_metadata.schema()
_check_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())
Contributor

My understanding is that if we enable downcast-ns-timestamp-to-us-on-write, we now allow users to add parquet files with TIMESTAMP_NANOS data. My concern here is that we may add parquet files that do not align with the spec, which states that the timestamp/timestamptz types should map to TIMESTAMP_MICROS. Shall we be more restrictive when checking parquet files that will be added directly to the table?
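
For illustration only (not part of this diff), a minimal sketch of the concern, using the pyarrow_to_schema helper referenced in the change above; the one-column schema is a placeholder:

# Sketch only: with the downcast flag enabled, a ns-encoded Arrow schema
# converts cleanly to a plain Iceberg timestamptz, so the schema check alone
# cannot tell that the parquet file stores TIMESTAMP_NANOS instead of the
# TIMESTAMP_MICROS the spec expects.
import pyarrow as pa

from pyiceberg.io.pyarrow import pyarrow_to_schema
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, TimestamptzType

table_schema = Schema(NestedField(1, "quux", TimestamptzType(), required=False))
ns_schema = pa.schema([("quux", pa.timestamp("ns", tz="UTC"))])

# Without the flag this raises ("Iceberg does not yet support 'ns' timestamp
# precision ..."); with it, the ns column maps to a plain timestamptz.
converted = pyarrow_to_schema(
    ns_schema,
    name_mapping=table_schema.name_mapping,
    downcast_ns_timestamp_to_us=True,
)
print(converted)  # contains "1: quux: optional timestamptz"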

Collaborator Author

Thank you for the catch @HonahX - I think you are right. Adding a nanosecond-timestamp file doesn't let Spark Iceberg read the file correctly; instead it results in exceptions like:

ValueError: year 53177 is out of range

I will make downcast_ns_timestamp_to_us_on_write an input argument to _check_schema_compatible, so that we can prevent nanosecond timestamp types from being added through add_files, but continue to support downcasting them in overwrite/append.
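
A rough sketch of the resulting call pattern (it mirrors what the diff ends up doing; the one-column schemas are placeholders):

# Sketch only: the same check, strict for add_files and configurable for
# append/overwrite, using the helper added in pyiceberg/io/pyarrow.py.
import pyarrow as pa

from pyiceberg.io.pyarrow import _check_schema_compatible
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, TimestamptzType

table_schema = Schema(NestedField(1, "quux", TimestamptzType(), required=False))
ns_schema = pa.schema([("quux", pa.timestamp("ns", tz="UTC"))])

# append/overwrite: pass the user's downcast-ns-timestamp-to-us-on-write
# setting; ns data is accepted because it is rewritten as us on write.
_check_schema_compatible(table_schema, ns_schema, downcast_ns_timestamp_to_us=True)

# add_files: leave the flag at its default (False); registering an ns-encoded
# parquet file fails instead of producing a file that violates the spec.
try:
    _check_schema_compatible(table_schema, ns_schema)
except TypeError as exc:
    print(exc)  # Iceberg does not yet support 'ns' timestamp precision ...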

Contributor

In general, I think it is okay to be able to read more broadly, but we do need tests to ensure that it works correctly. Looking at Arrow, there are already some physical types that we don't support (date64, etc.). In Java we do support reading timestamps encoded as INT96, but we should not produce them.


statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=parquet_metadata,
stats_columns=compute_statistics_plan(schema, table_metadata.properties),
62 changes: 10 additions & 52 deletions pyiceberg/table/__init__.py
@@ -73,7 +73,7 @@
manifest_evaluator,
)
from pyiceberg.io import FileIO, OutputFile, load_file_io
from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table
from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files, expression_to_pyarrow, project_table
from pyiceberg.manifest import (
POSITIONAL_DELETE_SCHEMA,
DataFile,
@@ -166,54 +166,8 @@

ALWAYS_TRUE = AlwaysTrue()
TABLE_ROOT_ID = -1
DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write"
_JAVA_LONG_MAX = 9223372036854775807


def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") -> None:
"""
Check if the `table_schema` is compatible with `other_schema`.

Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.

Raises:
ValueError: If the schemas are not compatible.
"""
from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
name_mapping = table_schema.name_mapping
try:
task_schema = pyarrow_to_schema(
other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
except ValueError as e:
other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
additional_names = set(other_schema.column_names) - set(table_schema.column_names)
raise ValueError(
f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
) from e

if table_schema.as_struct() != task_schema.as_struct():
from rich.console import Console
from rich.table import Table as RichTable

console = Console(record=True)

rich_table = RichTable(show_header=True, header_style="bold")
rich_table.add_column("")
rich_table.add_column("Table field")
rich_table.add_column("Dataframe field")

for lhs in table_schema.fields:
try:
rhs = task_schema.find_field(lhs.field_id)
rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
except ValueError:
rich_table.add_row("❌", str(lhs), "Missing")

console.print(rich_table)
raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write"


class TableProperties:
@@ -526,8 +480,10 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
raise ValueError(
f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
)

_check_schema_compatible(self._table.schema(), other_schema=df.schema)
downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
_check_schema_compatible(
self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
# cast if the two schemas are compatible but not equal
table_arrow_schema = self._table.schema().as_arrow()
if table_arrow_schema != df.schema:
@@ -585,8 +541,10 @@ def overwrite(
raise ValueError(
f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
)

_check_schema_compatible(self._table.schema(), other_schema=df.schema)
downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
_check_schema_compatible(
self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
# cast if the two schemas are compatible but not equal
table_arrow_schema = self._table.schema().as_arrow()
if table_arrow_schema != df.schema:
85 changes: 65 additions & 20 deletions tests/integration/test_add_files.py
@@ -17,6 +17,7 @@
# pylint:disable=redefined-outer-name

import os
import re
from datetime import date
from typing import Iterator

@@ -463,6 +464,57 @@ def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Cat
assert summary["snapshot_prop_a"] == "test_prop_a"


@pytest.mark.integration
def test_add_files_fails_on_schema_mismatch(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
identifier = f"default.table_schema_mismatch_fails_v{format_version}"

tbl = _create_table(session_catalog, identifier, format_version)
WRONG_SCHEMA = pa.schema([
("foo", pa.bool_()),
("bar", pa.string()),
("baz", pa.string()), # should be integer
("qux", pa.date32()),
])
file_path = f"s3://warehouse/default/table_schema_mismatch_fails/v{format_version}/test.parquet"
# write parquet files
fo = tbl.io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=WRONG_SCHEMA) as writer:
writer.write_table(
pa.Table.from_pylist(
[
{
"foo": True,
"bar": "bar_string",
"baz": "123",
"qux": date(2024, 3, 7),
},
{
"foo": True,
"bar": "bar_string",
"baz": "124",
"qux": date(2024, 3, 7),
},
],
schema=WRONG_SCHEMA,
)
)

expected = """Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ┃ Table field ┃ Dataframe field ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional boolean │ 1: foo: optional boolean │
│ ✅ │ 2: bar: optional string │ 2: bar: optional string │
│ ❌ │ 3: baz: optional int │ 3: baz: optional string │
│ ✅ │ 4: qux: optional date │ 4: qux: optional date │
└────┴──────────────────────────┴──────────────────────────┘
"""

with pytest.raises(ValueError, match=expected):
tbl.add_files(file_paths=[file_path])


@pytest.mark.integration
def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
identifier = f"default.unpartitioned_with_large_types{format_version}"
@@ -518,7 +570,7 @@ def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_ca
assert table_schema == arrow_schema_large


def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType()))

nanoseconds_schema = pa.schema([
@@ -549,25 +601,18 @@ def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_versi
partition_spec=PartitionSpec(),
)

file_paths = [f"s3://warehouse/default/test_timestamp_tz/v{format_version}/test-{i}.parquet" for i in range(5)]
file_path = f"s3://warehouse/default/test_timestamp_tz/v{format_version}/test.parquet"
# write parquet files
for file_path in file_paths:
fo = tbl.io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=nanoseconds_schema) as writer:
writer.write_table(arrow_table)
fo = tbl.io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=nanoseconds_schema) as writer:
writer.write_table(arrow_table)

# add the parquet files as data files
tbl.add_files(file_paths=file_paths)

assert tbl.scan().to_arrow() == pa.concat_tables(
[
arrow_table.cast(
pa.schema([
("quux", pa.timestamp("us", tz="UTC")),
]),
safe=False,
)
]
* 5
)
with pytest.raises(
TypeError,
match=re.escape(
"Iceberg does not yet support 'ns' timestamp precision. Use 'downcast-ns-timestamp-to-us-on-write' configuration property to automatically downcast 'ns' to 'us' on write."
),
):
tbl.add_files(file_paths=[file_path])
91 changes: 91 additions & 0 deletions tests/io/test_pyarrow.py
@@ -60,6 +60,7 @@
PyArrowFile,
PyArrowFileIO,
StatsAggregator,
_check_schema_compatible,
_ConvertToArrowSchema,
_determine_partitions,
_primitive_to_physical,
@@ -1722,6 +1723,96 @@ def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None:
assert len(list(bin_packed)) == 5


def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
other_schema = pa.schema((
pa.field("foo", pa.string(), nullable=True),
pa.field("bar", pa.decimal128(18, 6), nullable=False),
pa.field("baz", pa.bool_(), nullable=True),
))

expected = r"""Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ┃ Table field ┃ Dataframe field ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
│ ❌ │ 2: bar: required int │ 2: bar: required decimal\(18, 6\) │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
└────┴──────────────────────────┴─────────────────────────────────┘
"""

with pytest.raises(ValueError, match=expected):
_check_schema_compatible(table_schema_simple, other_schema)


def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
other_schema = pa.schema((
pa.field("foo", pa.string(), nullable=True),
pa.field("bar", pa.int32(), nullable=True),
pa.field("baz", pa.bool_(), nullable=True),
))

expected = """Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ┃ Table field ┃ Dataframe field ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
│ ❌ │ 2: bar: required int │ 2: bar: optional int │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
└────┴──────────────────────────┴──────────────────────────┘
"""

with pytest.raises(ValueError, match=expected):
_check_schema_compatible(table_schema_simple, other_schema)


def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
other_schema = pa.schema((
pa.field("foo", pa.string(), nullable=True),
pa.field("baz", pa.bool_(), nullable=True),
))

expected = """Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ┃ Table field ┃ Dataframe field ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
│ ❌ │ 2: bar: required int │ Missing │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
└────┴──────────────────────────┴──────────────────────────┘
"""

with pytest.raises(ValueError, match=expected):
_check_schema_compatible(table_schema_simple, other_schema)


def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
other_schema = pa.schema((
pa.field("foo", pa.string(), nullable=True),
pa.field("bar", pa.int32(), nullable=True),
pa.field("baz", pa.bool_(), nullable=True),
pa.field("new_field", pa.date32(), nullable=True),
))

expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."

with pytest.raises(ValueError, match=expected):
_check_schema_compatible(table_schema_simple, other_schema)


def test_schema_downcast(table_schema_simple: Schema) -> None:
# large_string type is compatible with string type
other_schema = pa.schema((
pa.field("foo", pa.large_string(), nullable=True),
pa.field("bar", pa.int32(), nullable=False),
pa.field("baz", pa.bool_(), nullable=True),
))

try:
_check_schema_compatible(table_schema_simple, other_schema)
except Exception:
pytest.fail("Unexpected Exception raised when calling `_check_schema`")


def test_partition_for_demo() -> None:
test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
test_schema = Schema(