diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index 5897881fcc..5eec487c67 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -326,9 +326,8 @@ tbl.add_files(file_paths=file_paths) !!! note "Name Mapping" Because `add_files` uses existing files without writing new parquet files that are aware of the Iceberg's schema, it requires the Iceberg's table to have a [Name Mapping](https://iceberg.apache.org/spec/?h=name+mapping#name-mapping-serialization) (The Name mapping maps the field names within the parquet files to the Iceberg field IDs). Hence, `add_files` requires that there are no field IDs in the parquet file's metadata, and creates a new Name Mapping based on the table's current schema if the table doesn't already have one. - - - +!!! note "Partitions" + `add_files` only requires the client to read the existing parquet files' metadata footer to infer the partition value of each file. This implementation also supports adding files to Iceberg tables with partition transforms like `MonthTransform`, and `TruncateTransform` which preserve the order of the values after the transformation (Any Transform that has the `preserves_order` property set to True is supported). Please note that if the column statistics of the `PartitionField`'s source column are not present in the parquet metadata, the partition value is inferred as `None`. !!! warning "Maintenance Operations" Because `add_files` commits the existing parquet files to the Iceberg Table as any other data file, destructive maintenance operations like expiring snapshots will remove them. diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 31d846f6f0..72de14880a 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -111,6 +111,7 @@ DataFileContent, FileFormat, ) +from pyiceberg.partitioning import PartitionField, PartitionSpec, partition_record_value from pyiceberg.schema import ( PartnerAccessor, PreOrderSchemaVisitor, @@ -124,7 +125,7 @@ visit, visit_with_partner, ) -from pyiceberg.table import AddFileTask, PropertyUtil, TableProperties, WriteTask +from pyiceberg.table import PropertyUtil, TableProperties, WriteTask from pyiceberg.table.metadata import TableMetadata from pyiceberg.table.name_mapping import NameMapping from pyiceberg.transforms import TruncateTransform @@ -1594,29 +1595,88 @@ def parquet_path_to_id_mapping( return result -def fill_parquet_file_metadata( - data_file: DataFile, +@dataclass(frozen=True) +class DataFileStatistics: + record_count: int + column_sizes: Dict[int, int] + value_counts: Dict[int, int] + null_value_counts: Dict[int, int] + nan_value_counts: Dict[int, int] + column_aggregates: Dict[int, StatsAggregator] + split_offsets: List[int] + + def _partition_value(self, partition_field: PartitionField, schema: Schema) -> Any: + if partition_field.source_id not in self.column_aggregates: + return None + + if not partition_field.transform.preserves_order: + raise ValueError( + f"Cannot infer partition value from parquet metadata for a non-linear Partition Field: {partition_field.name} with transform {partition_field.transform}" + ) + + lower_value = partition_record_value( + partition_field=partition_field, + value=self.column_aggregates[partition_field.source_id].current_min, + schema=schema, + ) + upper_value = partition_record_value( + partition_field=partition_field, + value=self.column_aggregates[partition_field.source_id].current_max, + schema=schema, + ) + if lower_value != upper_value: + raise ValueError( + f"Cannot infer partition value from parquet metadata 
as there are more than one partition values for Partition Field: {partition_field.name}. {lower_value=}, {upper_value=}" + ) + return lower_value + + def partition(self, partition_spec: PartitionSpec, schema: Schema) -> Record: + return Record(**{field.name: self._partition_value(field, schema) for field in partition_spec.fields}) + + def to_serialized_dict(self) -> Dict[str, Any]: + lower_bounds = {} + upper_bounds = {} + + for k, agg in self.column_aggregates.items(): + _min = agg.min_as_bytes() + if _min is not None: + lower_bounds[k] = _min + _max = agg.max_as_bytes() + if _max is not None: + upper_bounds[k] = _max + return { + "record_count": self.record_count, + "column_sizes": self.column_sizes, + "value_counts": self.value_counts, + "null_value_counts": self.null_value_counts, + "nan_value_counts": self.nan_value_counts, + "lower_bounds": lower_bounds, + "upper_bounds": upper_bounds, + "split_offsets": self.split_offsets, + } + + +def data_file_statistics_from_parquet_metadata( parquet_metadata: pq.FileMetaData, stats_columns: Dict[int, StatisticsCollector], parquet_column_mapping: Dict[str, int], -) -> None: +) -> DataFileStatistics: """ - Compute and fill the following fields of the DataFile object. + Compute and return DataFileStatistics that includes the following. - - file_format + - record_count - column_sizes - value_counts - null_value_counts - nan_value_counts - - lower_bounds - - upper_bounds + - column_aggregates - split_offsets Args: - data_file (DataFile): A DataFile object representing the Parquet file for which metadata is to be filled. parquet_metadata (pyarrow.parquet.FileMetaData): A pyarrow metadata object. stats_columns (Dict[int, StatisticsCollector]): The statistics gathering plan. It is required to set the mode for column metrics collection + parquet_column_mapping (Dict[str, int]): The mapping of the parquet file name to the field ID """ if parquet_metadata.num_columns != len(stats_columns): raise ValueError( @@ -1695,30 +1755,19 @@ def fill_parquet_file_metadata( split_offsets.sort() - lower_bounds = {} - upper_bounds = {} - - for k, agg in col_aggs.items(): - _min = agg.min_as_bytes() - if _min is not None: - lower_bounds[k] = _min - _max = agg.max_as_bytes() - if _max is not None: - upper_bounds[k] = _max - for field_id in invalidate_col: - del lower_bounds[field_id] - del upper_bounds[field_id] + del col_aggs[field_id] del null_value_counts[field_id] - data_file.record_count = parquet_metadata.num_rows - data_file.column_sizes = column_sizes - data_file.value_counts = value_counts - data_file.null_value_counts = null_value_counts - data_file.nan_value_counts = nan_value_counts - data_file.lower_bounds = lower_bounds - data_file.upper_bounds = upper_bounds - data_file.split_offsets = split_offsets + return DataFileStatistics( + record_count=parquet_metadata.num_rows, + column_sizes=column_sizes, + value_counts=value_counts, + null_value_counts=null_value_counts, + nan_value_counts=nan_value_counts, + column_aggregates=col_aggs, + split_offsets=split_offsets, + ) def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]: @@ -1747,6 +1796,11 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer: writer.write_table(task.df, row_group_size=row_group_size) + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=writer.writer.metadata, + 
stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), + ) data_file = DataFile( content=DataFileContent.DATA, file_path=file_path, @@ -1761,47 +1815,41 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT spec_id=table_metadata.default_spec_id, equality_ids=None, key_metadata=None, + **statistics.to_serialized_dict(), ) - fill_parquet_file_metadata( - data_file=data_file, - parquet_metadata=writer.writer.metadata, - stats_columns=compute_statistics_plan(schema, table_metadata.properties), - parquet_column_mapping=parquet_path_to_id_mapping(schema), - ) return iter([data_file]) -def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[AddFileTask]) -> Iterator[DataFile]: - for task in tasks: - input_file = io.new_input(task.file_path) +def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]: + for file_path in file_paths: + input_file = io.new_input(file_path) with input_file.open() as input_stream: parquet_metadata = pq.read_metadata(input_stream) if visit_pyarrow(parquet_metadata.schema.to_arrow_schema(), _HasIds()): raise NotImplementedError( - f"Cannot add file {task.file_path} because it has field IDs. `add_files` only supports addition of files without field_ids" + f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids" ) - schema = table_metadata.schema() + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=parquet_metadata, + stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), + ) data_file = DataFile( content=DataFileContent.DATA, - file_path=task.file_path, + file_path=file_path, file_format=FileFormat.PARQUET, - partition=task.partition_field_value, - record_count=parquet_metadata.num_rows, + partition=statistics.partition(table_metadata.spec(), table_metadata.schema()), file_size_in_bytes=len(input_file), sort_order_id=None, spec_id=table_metadata.default_spec_id, equality_ids=None, key_metadata=None, + **statistics.to_serialized_dict(), ) - fill_parquet_file_metadata( - data_file=data_file, - parquet_metadata=parquet_metadata, - stats_columns=compute_statistics_plan(schema, table_metadata.properties), - parquet_column_mapping=parquet_path_to_id_mapping(schema), - ) + yield data_file diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py index a6692b325e..16f158828d 100644 --- a/pyiceberg/partitioning.py +++ b/pyiceberg/partitioning.py @@ -388,16 +388,33 @@ def partition(self) -> Record: # partition key transformed with iceberg interna if len(partition_fields) != 1: raise ValueError("partition_fields must contain exactly one field.") partition_field = partition_fields[0] - iceberg_type = self.schema.find_field(name_or_id=raw_partition_field_value.field.source_id).field_type - iceberg_typed_value = _to_partition_representation(iceberg_type, raw_partition_field_value.value) - transformed_value = partition_field.transform.transform(iceberg_type)(iceberg_typed_value) - iceberg_typed_key_values[partition_field.name] = transformed_value + iceberg_typed_key_values[partition_field.name] = partition_record_value( + partition_field=partition_field, + value=raw_partition_field_value.value, + schema=self.schema, + ) return Record(**iceberg_typed_key_values) def to_path(self) -> str: return 
self.partition_spec.partition_to_path(self.partition, self.schema) +def partition_record_value(partition_field: PartitionField, value: Any, schema: Schema) -> Any: + """ + Return the Partition Record representation of the value. + + The value is first converted to internal partition representation. + For example, UUID is converted to bytes[16], DateType to days since epoch, etc. + + Then the corresponding PartitionField's transform is applied to return + the final partition record value. + """ + iceberg_type = schema.find_field(name_or_id=partition_field.source_id).field_type + iceberg_typed_value = _to_partition_representation(iceberg_type, value) + transformed_value = partition_field.transform.transform(iceberg_type)(iceberg_typed_value) + return transformed_value + + @singledispatch def _to_partition_representation(type: IcebergType, value: Any) -> Any: return TypeError(f"Unsupported partition field type: {type}") diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 517e6c86df..c0db85f854 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -33,7 +33,6 @@ Dict, Generic, Iterable, - Iterator, List, Literal, Optional, @@ -1170,9 +1169,6 @@ def add_files(self, file_paths: List[str]) -> None: Raises: FileNotFoundError: If the file does not exist. """ - if len(self.spec().fields) > 0: - raise ValueError("Cannot add files to partitioned tables") - with self.transaction() as tx: if self.name_mapping() is None: tx.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: self.schema().name_mapping.model_dump_json()}) @@ -2515,17 +2511,6 @@ def _dataframe_to_data_files( yield from write_file(io=io, table_metadata=table_metadata, tasks=iter([WriteTask(write_uuid, next(counter), df)])) -def add_file_tasks_from_file_paths(file_paths: List[str], table_metadata: TableMetadata) -> Iterator[AddFileTask]: - if len([spec for spec in table_metadata.partition_specs if spec.spec_id != 0]) > 0: - raise ValueError("Cannot add files to partitioned tables") - - for file_path in file_paths: - yield AddFileTask( - file_path=file_path, - partition_field_value=Record(), - ) - - def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths: List[str], io: FileIO) -> Iterable[DataFile]: """Convert a list files into DataFiles. 
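The `partition_record_value` helper introduced above first converts a raw value to Iceberg's internal partition representation (for example, a `DateType` value becomes days since epoch) and then applies the partition field's transform. A minimal sketch of that conversion, using a hypothetical single-column schema built only for illustration; the month ordinal 650 for March 2024 matches the value asserted in the partitioned-table integration test later in this patch:

```python
from datetime import date

from pyiceberg.partitioning import PartitionField, partition_record_value
from pyiceberg.schema import Schema
from pyiceberg.transforms import MonthTransform
from pyiceberg.types import DateType, NestedField

# Hypothetical schema and partition field, used only to illustrate the conversion.
schema = Schema(NestedField(field_id=10, name="qux", field_type=DateType(), required=False))
field = PartitionField(source_id=10, field_id=1001, transform=MonthTransform(), name="qux_month")

# DateType is converted to days since epoch, then MonthTransform maps that to
# months since epoch: (2024 - 1970) * 12 + (3 - 1) == 650.
assert partition_record_value(partition_field=field, value=date(2024, 3, 7), schema=schema) == 650
```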
@@ -2534,8 +2519,7 @@ def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths: List """ from pyiceberg.io.pyarrow import parquet_files_to_data_files - tasks = add_file_tasks_from_file_paths(file_paths, table_metadata) - yield from parquet_files_to_data_files(io=io, table_metadata=table_metadata, tasks=tasks) + yield from parquet_files_to_data_files(io=io, table_metadata=table_metadata, file_paths=iter(file_paths)) class _MergingSnapshotProducer(UpdateTableMetadata["_MergingSnapshotProducer"]): diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py index 2066e178cd..7c17618280 100644 --- a/tests/integration/test_add_files.py +++ b/tests/integration/test_add_files.py @@ -26,9 +26,10 @@ from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError -from pyiceberg.partitioning import PartitionSpec +from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.table import Table +from pyiceberg.transforms import BucketTransform, IdentityTransform, MonthTransform from pyiceberg.types import ( BooleanType, DateType, @@ -103,25 +104,31 @@ ) -def _create_table(session_catalog: Catalog, identifier: str, partition_spec: Optional[PartitionSpec] = None) -> Table: +def _create_table( + session_catalog: Catalog, identifier: str, format_version: int, partition_spec: Optional[PartitionSpec] = None +) -> Table: try: session_catalog.drop_table(identifier=identifier) except NoSuchTableError: pass tbl = session_catalog.create_table( - identifier=identifier, schema=TABLE_SCHEMA, partition_spec=partition_spec if partition_spec else PartitionSpec() + identifier=identifier, + schema=TABLE_SCHEMA, + properties={"format-version": str(format_version)}, + partition_spec=partition_spec if partition_spec else PartitionSpec(), ) return tbl @pytest.mark.integration -def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog: Catalog) -> None: - identifier = "default.unpartitioned_table" - tbl = _create_table(session_catalog, identifier) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None: + identifier = f"default.unpartitioned_table_v{format_version}" + tbl = _create_table(session_catalog, identifier, format_version) - file_paths = [f"s3://warehouse/default/unpartitioned/test-{i}.parquet" for i in range(5)] + file_paths = [f"s3://warehouse/default/unpartitioned/v{format_version}/test-{i}.parquet" for i in range(5)] # write parquet files for file_path in file_paths: fo = tbl.io.new_output(file_path) @@ -153,11 +160,14 @@ def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog: @pytest.mark.integration -def test_add_files_to_unpartitioned_table_raises_file_not_found(spark: SparkSession, session_catalog: Catalog) -> None: - identifier = "default.unpartitioned_raises_not_found" - tbl = _create_table(session_catalog, identifier) - - file_paths = [f"s3://warehouse/default/unpartitioned_raises_not_found/test-{i}.parquet" for i in range(5)] +@pytest.mark.parametrize("format_version", [1, 2]) +def test_add_files_to_unpartitioned_table_raises_file_not_found( + spark: SparkSession, session_catalog: Catalog, format_version: int +) -> None: + identifier = f"default.unpartitioned_raises_not_found_v{format_version}" + tbl = _create_table(session_catalog, identifier, format_version) + + file_paths = 
[f"s3://warehouse/default/unpartitioned_raises_not_found/v{format_version}/test-{i}.parquet" for i in range(5)] # write parquet files for file_path in file_paths: fo = tbl.io.new_output(file_path) @@ -171,11 +181,14 @@ def test_add_files_to_unpartitioned_table_raises_file_not_found(spark: SparkSess @pytest.mark.integration -def test_add_files_to_unpartitioned_table_raises_has_field_ids(spark: SparkSession, session_catalog: Catalog) -> None: - identifier = "default.unpartitioned_raises_field_ids" - tbl = _create_table(session_catalog, identifier) - - file_paths = [f"s3://warehouse/default/unpartitioned_raises_field_ids/test-{i}.parquet" for i in range(5)] +@pytest.mark.parametrize("format_version", [1, 2]) +def test_add_files_to_unpartitioned_table_raises_has_field_ids( + spark: SparkSession, session_catalog: Catalog, format_version: int +) -> None: + identifier = f"default.unpartitioned_raises_field_ids_v{format_version}" + tbl = _create_table(session_catalog, identifier, format_version) + + file_paths = [f"s3://warehouse/default/unpartitioned_raises_field_ids/v{format_version}/test-{i}.parquet" for i in range(5)] # write parquet files for file_path in file_paths: fo = tbl.io.new_output(file_path) @@ -189,11 +202,14 @@ def test_add_files_to_unpartitioned_table_raises_has_field_ids(spark: SparkSessi @pytest.mark.integration -def test_add_files_to_unpartitioned_table_with_schema_updates(spark: SparkSession, session_catalog: Catalog) -> None: - identifier = "default.unpartitioned_table_2" - tbl = _create_table(session_catalog, identifier) - - file_paths = [f"s3://warehouse/default/unpartitioned_2/test-{i}.parquet" for i in range(5)] +@pytest.mark.parametrize("format_version", [1, 2]) +def test_add_files_to_unpartitioned_table_with_schema_updates( + spark: SparkSession, session_catalog: Catalog, format_version: int +) -> None: + identifier = f"default.unpartitioned_table_schema_updates_v{format_version}" + tbl = _create_table(session_catalog, identifier, format_version) + + file_paths = [f"s3://warehouse/default/unpartitioned_schema_updates/v{format_version}/test-{i}.parquet" for i in range(5)] # write parquet files for file_path in file_paths: fo = tbl.io.new_output(file_path) @@ -211,7 +227,7 @@ def test_add_files_to_unpartitioned_table_with_schema_updates(spark: SparkSessio update.add_column("quux", IntegerType()) update.delete_column("bar") - file_path = "s3://warehouse/default/unpartitioned_2/test-6.parquet" + file_path = f"s3://warehouse/default/unpartitioned_schema_updates/v{format_version}/test-6.parquet" # write parquet files fo = tbl.io.new_output(file_path) with fo.create(overwrite=True) as fos: @@ -238,3 +254,164 @@ def test_add_files_to_unpartitioned_table_with_schema_updates(spark: SparkSessio for col in df.columns: value_count = 1 if col == "quux" else 6 assert df.filter(df[col].isNotNull()).count() == value_count, f"Expected {value_count} rows to be non-null" + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_add_files_to_partitioned_table(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None: + identifier = f"default.partitioned_table_v{format_version}" + + partition_spec = PartitionSpec( + PartitionField(source_id=4, field_id=1000, transform=IdentityTransform(), name="baz"), + PartitionField(source_id=10, field_id=1001, transform=MonthTransform(), name="qux_month"), + spec_id=0, + ) + + tbl = _create_table(session_catalog, identifier, format_version, partition_spec) + + date_iter = iter([date(2024, 3, 7), 
date(2024, 3, 8), date(2024, 3, 16), date(2024, 3, 18), date(2024, 3, 19)]) + + file_paths = [f"s3://warehouse/default/partitioned/v{format_version}/test-{i}.parquet" for i in range(5)] + # write parquet files + for file_path in file_paths: + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer: + writer.write_table( + pa.Table.from_pylist( + [ + { + "foo": True, + "bar": "bar_string", + "baz": 123, + "qux": next(date_iter), + } + ], + schema=ARROW_SCHEMA, + ) + ) + + # add the parquet files as data files + tbl.add_files(file_paths=file_paths) + + # NameMapping must have been set to enable reads + assert tbl.name_mapping() is not None + + rows = spark.sql( + f""" + SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count + FROM {identifier}.all_manifests + """ + ).collect() + + assert [row.added_data_files_count for row in rows] == [5] + assert [row.existing_data_files_count for row in rows] == [0] + assert [row.deleted_data_files_count for row in rows] == [0] + + df = spark.table(identifier) + assert df.count() == 5, "Expected 5 rows" + for col in df.columns: + assert df.filter(df[col].isNotNull()).count() == 5, "Expected all 5 rows to be non-null" + + partition_rows = spark.sql( + f""" + SELECT partition, record_count, file_count + FROM {identifier}.partitions + """ + ).collect() + + assert [row.record_count for row in partition_rows] == [5] + assert [row.file_count for row in partition_rows] == [5] + assert [(row.partition.baz, row.partition.qux_month) for row in partition_rows] == [(123, 650)] + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_add_files_to_bucket_partitioned_table_fails(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None: + identifier = f"default.partitioned_table_bucket_fails_v{format_version}" + + partition_spec = PartitionSpec( + PartitionField(source_id=4, field_id=1000, transform=BucketTransform(num_buckets=3), name="baz_bucket_3"), + spec_id=0, + ) + + tbl = _create_table(session_catalog, identifier, format_version, partition_spec) + + int_iter = iter(range(5)) + + file_paths = [f"s3://warehouse/default/partitioned_table_bucket_fails/v{format_version}/test-{i}.parquet" for i in range(5)] + # write parquet files + for file_path in file_paths: + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer: + writer.write_table( + pa.Table.from_pylist( + [ + { + "foo": True, + "bar": "bar_string", + "baz": next(int_iter), + "qux": date(2024, 3, 7), + } + ], + schema=ARROW_SCHEMA, + ) + ) + + # add the parquet files as data files + with pytest.raises(ValueError) as exc_info: + tbl.add_files(file_paths=file_paths) + assert ( + "Cannot infer partition value from parquet metadata for a non-linear Partition Field: baz_bucket_3 with transform bucket[3]" + in str(exc_info.value) + ) + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_add_files_to_partitioned_table_fails_with_lower_and_upper_mismatch( + spark: SparkSession, session_catalog: Catalog, format_version: int +) -> None: + identifier = f"default.partitioned_table_mismatch_fails_v{format_version}" + + partition_spec = PartitionSpec( + PartitionField(source_id=4, field_id=1000, transform=IdentityTransform(), name="baz"), + spec_id=0, + ) + + tbl = _create_table(session_catalog, identifier, format_version, partition_spec) + + 
file_paths = [f"s3://warehouse/default/partitioned_table_mismatch_fails/v{format_version}/test-{i}.parquet" for i in range(5)] + # write parquet files + for file_path in file_paths: + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer: + writer.write_table( + pa.Table.from_pylist( + [ + { + "foo": True, + "bar": "bar_string", + "baz": 123, + "qux": date(2024, 3, 7), + }, + { + "foo": True, + "bar": "bar_string", + "baz": 124, + "qux": date(2024, 3, 7), + }, + ], + schema=ARROW_SCHEMA, + ) + ) + + # add the parquet files as data files + with pytest.raises(ValueError) as exc_info: + tbl.add_files(file_paths=file_paths) + assert ( + "Cannot infer partition value from parquet metadata as there are more than one partition values for Partition Field: baz. lower_value=123, upper_value=124" + in str(exc_info.value) + ) diff --git a/tests/io/test_pyarrow_stats.py b/tests/io/test_pyarrow_stats.py index 01b844a43e..41f1432dbf 100644 --- a/tests/io/test_pyarrow_stats.py +++ b/tests/io/test_pyarrow_stats.py @@ -52,7 +52,7 @@ MetricsMode, PyArrowStatisticsCollector, compute_statistics_plan, - fill_parquet_file_metadata, + data_file_statistics_from_parquet_metadata, match_metrics_mode, parquet_path_to_id_mapping, schema_to_pyarrow, @@ -185,13 +185,12 @@ def test_record_count() -> None: metadata, table_metadata = construct_test_table() schema = get_current_schema(table_metadata) - datafile = DataFile() - fill_parquet_file_metadata( - datafile, - metadata, - compute_statistics_plan(schema, table_metadata.properties), - parquet_path_to_id_mapping(schema), + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=metadata, + stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), ) + datafile = DataFile(**statistics.to_serialized_dict()) assert datafile.record_count == 4 @@ -199,13 +198,12 @@ def test_value_counts() -> None: metadata, table_metadata = construct_test_table() schema = get_current_schema(table_metadata) - datafile = DataFile() - fill_parquet_file_metadata( - datafile, - metadata, - compute_statistics_plan(schema, table_metadata.properties), - parquet_path_to_id_mapping(schema), + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=metadata, + stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), ) + datafile = DataFile(**statistics.to_serialized_dict()) assert len(datafile.value_counts) == 7 assert datafile.value_counts[1] == 4 @@ -221,13 +219,12 @@ def test_column_sizes() -> None: metadata, table_metadata = construct_test_table() schema = get_current_schema(table_metadata) - datafile = DataFile() - fill_parquet_file_metadata( - datafile, - metadata, - compute_statistics_plan(schema, table_metadata.properties), - parquet_path_to_id_mapping(schema), + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=metadata, + stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), ) + datafile = DataFile(**statistics.to_serialized_dict()) assert len(datafile.column_sizes) == 7 # these values are an artifact of how the write_table encodes the columns @@ -242,13 +239,12 @@ def test_null_and_nan_counts() -> None: metadata, table_metadata = construct_test_table() schema = get_current_schema(table_metadata) - datafile = 
DataFile() - fill_parquet_file_metadata( - datafile, - metadata, - compute_statistics_plan(schema, table_metadata.properties), - parquet_path_to_id_mapping(schema), + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=metadata, + stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), ) + datafile = DataFile(**statistics.to_serialized_dict()) assert len(datafile.null_value_counts) == 7 assert datafile.null_value_counts[1] == 1 @@ -270,13 +266,12 @@ def test_bounds() -> None: metadata, table_metadata = construct_test_table() schema = get_current_schema(table_metadata) - datafile = DataFile() - fill_parquet_file_metadata( - datafile, - metadata, - compute_statistics_plan(schema, table_metadata.properties), - parquet_path_to_id_mapping(schema), + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=metadata, + stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), ) + datafile = DataFile(**statistics.to_serialized_dict()) assert len(datafile.lower_bounds) == 2 assert datafile.lower_bounds[1].decode() == "aaaaaaaaaaaaaaaa" @@ -314,14 +309,13 @@ def test_metrics_mode_none() -> None: metadata, table_metadata = construct_test_table() schema = get_current_schema(table_metadata) - datafile = DataFile() table_metadata.properties["write.metadata.metrics.default"] = "none" - fill_parquet_file_metadata( - datafile, - metadata, - compute_statistics_plan(schema, table_metadata.properties), - parquet_path_to_id_mapping(schema), + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=metadata, + stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), ) + datafile = DataFile(**statistics.to_serialized_dict()) assert len(datafile.value_counts) == 0 assert len(datafile.null_value_counts) == 0 @@ -334,14 +328,13 @@ def test_metrics_mode_counts() -> None: metadata, table_metadata = construct_test_table() schema = get_current_schema(table_metadata) - datafile = DataFile() table_metadata.properties["write.metadata.metrics.default"] = "counts" - fill_parquet_file_metadata( - datafile, - metadata, - compute_statistics_plan(schema, table_metadata.properties), - parquet_path_to_id_mapping(schema), + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=metadata, + stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), ) + datafile = DataFile(**statistics.to_serialized_dict()) assert len(datafile.value_counts) == 7 assert len(datafile.null_value_counts) == 7 @@ -354,14 +347,13 @@ def test_metrics_mode_full() -> None: metadata, table_metadata = construct_test_table() schema = get_current_schema(table_metadata) - datafile = DataFile() table_metadata.properties["write.metadata.metrics.default"] = "full" - fill_parquet_file_metadata( - datafile, - metadata, - compute_statistics_plan(schema, table_metadata.properties), - parquet_path_to_id_mapping(schema), + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=metadata, + stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), ) + datafile = DataFile(**statistics.to_serialized_dict()) assert len(datafile.value_counts) == 7 assert len(datafile.null_value_counts) == 7 @@ 
-380,14 +372,13 @@ def test_metrics_mode_non_default_trunc() -> None: metadata, table_metadata = construct_test_table() schema = get_current_schema(table_metadata) - datafile = DataFile() table_metadata.properties["write.metadata.metrics.default"] = "truncate(2)" - fill_parquet_file_metadata( - datafile, - metadata, - compute_statistics_plan(schema, table_metadata.properties), - parquet_path_to_id_mapping(schema), + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=metadata, + stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), ) + datafile = DataFile(**statistics.to_serialized_dict()) assert len(datafile.value_counts) == 7 assert len(datafile.null_value_counts) == 7 @@ -406,15 +397,14 @@ def test_column_metrics_mode() -> None: metadata, table_metadata = construct_test_table() schema = get_current_schema(table_metadata) - datafile = DataFile() table_metadata.properties["write.metadata.metrics.default"] = "truncate(2)" table_metadata.properties["write.metadata.metrics.column.strings"] = "none" - fill_parquet_file_metadata( - datafile, - metadata, - compute_statistics_plan(schema, table_metadata.properties), - parquet_path_to_id_mapping(schema), + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=metadata, + stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), ) + datafile = DataFile(**statistics.to_serialized_dict()) assert len(datafile.value_counts) == 6 assert len(datafile.null_value_counts) == 6 @@ -508,14 +498,13 @@ def test_metrics_primitive_types() -> None: metadata, table_metadata = construct_test_table_primitive_types() schema = get_current_schema(table_metadata) - datafile = DataFile() table_metadata.properties["write.metadata.metrics.default"] = "truncate(2)" - fill_parquet_file_metadata( - datafile, - metadata, - compute_statistics_plan(schema, table_metadata.properties), - parquet_path_to_id_mapping(schema), + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=metadata, + stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), ) + datafile = DataFile(**statistics.to_serialized_dict()) assert len(datafile.value_counts) == 12 assert len(datafile.null_value_counts) == 12 @@ -607,14 +596,13 @@ def test_metrics_invalid_upper_bound() -> None: metadata, table_metadata = construct_test_table_invalid_upper_bound() schema = get_current_schema(table_metadata) - datafile = DataFile() table_metadata.properties["write.metadata.metrics.default"] = "truncate(2)" - fill_parquet_file_metadata( - datafile, - metadata, - compute_statistics_plan(schema, table_metadata.properties), - parquet_path_to_id_mapping(schema), + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=metadata, + stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), ) + datafile = DataFile(**statistics.to_serialized_dict()) assert len(datafile.value_counts) == 4 assert len(datafile.null_value_counts) == 4 @@ -635,13 +623,12 @@ def test_offsets() -> None: metadata, table_metadata = construct_test_table() schema = get_current_schema(table_metadata) - datafile = DataFile() - fill_parquet_file_metadata( - datafile, - metadata, - compute_statistics_plan(schema, table_metadata.properties), - 
parquet_path_to_id_mapping(schema), + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=metadata, + stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), ) + datafile = DataFile(**statistics.to_serialized_dict()) assert datafile.split_offsets is not None assert len(datafile.split_offsets) == 1
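Taken together, these changes let `add_files` commit existing parquet files to a partitioned table, inferring each file's partition tuple from the column statistics in its footer. A minimal usage sketch, assuming a configured catalog named "default" and an existing partitioned table; the table identifier and file path below are illustrative only:

```python
from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")
tbl = catalog.load_table("default.partitioned_table")

# The parquet files must not carry field IDs; each file's partition value is
# inferred from the min/max column statistics in its footer, so every file must
# contain exactly one partition value per order-preserving partition field.
tbl.add_files(file_paths=["s3://warehouse/default/partitioned/test-0.parquet"])
```

Files whose source column spans more than one partition value, or tables partitioned by a non-order-preserving transform such as `BucketTransform`, are rejected with a `ValueError`, as exercised by the integration tests above.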