From 9d996888741dd06f4138d20dd8f563fdbdd59401 Mon Sep 17 00:00:00 2001
From: Shiyan Xu
Date: Wed, 27 Mar 2024 20:46:11 -0500
Subject: [PATCH] support passing colstats min max

---
 daft/hudi/hudi_scan.py        | 40 ++++++++++++++++++-------
 daft/hudi/pyhudi/filegroup.py | 55 ++++++++++++++++++++++++++---------
 daft/hudi/pyhudi/table.py     | 46 +++++++++++++++++++++--------
 daft/hudi/pyhudi/utils.py     | 32 ++++++++++++++++----
 4 files changed, 130 insertions(+), 43 deletions(-)

diff --git a/daft/hudi/hudi_scan.py b/daft/hudi/hudi_scan.py
index 4f9cb5156e..57e21d3879 100644
--- a/daft/hudi/hudi_scan.py
+++ b/daft/hudi/hudi_scan.py
@@ -18,7 +18,7 @@
     ScanTask,
     StorageConfig,
 )
-from daft.hudi.pyhudi.table import HudiTable
+from daft.hudi.pyhudi.table import HudiTable, HudiTableMetadata
 from daft.io.scan import PartitionField, ScanOperator
 from daft.logical.schema import Schema
 
@@ -57,8 +57,8 @@ def multiline_display(self) -> list[str]:
     def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
         import pyarrow as pa
 
-        # TODO(Shiyan) integrate with metadata table to prune the files returned.
-        latest_files_metadata: pa.RecordBatch = self._table.latest_files_metadata()
+        hudi_table_metadata: HudiTableMetadata = self._table.latest_table_metadata()
+        files_metadata = hudi_table_metadata.files_metadata
 
         if len(self.partitioning_keys()) > 0 and pushdowns.partition_filters is None:
             logging.warning(
@@ -68,21 +68,21 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
         limit_files = pushdowns.limit is not None and pushdowns.filters is None and pushdowns.partition_filters is None
         rows_left = pushdowns.limit if pushdowns.limit is not None else 0
         scan_tasks = []
-        for task_idx in range(latest_files_metadata.num_rows):
+        for task_idx in range(files_metadata.num_rows):
             if limit_files and rows_left <= 0:
                 break
 
-            path = latest_files_metadata["path"][task_idx].as_py()
-            record_count = latest_files_metadata["num_records"][task_idx].as_py()
-            try:
-                size_bytes = latest_files_metadata["size_bytes"][task_idx].as_py()
+            path = files_metadata["path"][task_idx].as_py()
+            record_count = files_metadata["num_records"][task_idx].as_py()
+            try:
+                size_bytes = files_metadata["size_bytes"][task_idx].as_py()
             except KeyError:
                 size_bytes = None
             file_format_config = FileFormatConfig.from_parquet_config(ParquetSourceConfig())
 
             if self._table.is_partitioned:
-                dtype = latest_files_metadata.schema.field("_hoodie_partition_path").type
-                part_values = latest_files_metadata["partition_values"][task_idx]
+                dtype = files_metadata.schema.field("partition_values").type
+                part_values = files_metadata["partition_values"][task_idx]
                 arrays = {}
                 for field_idx in range(dtype.num_fields):
                     field_name = dtype.field(field_idx).name
@@ -97,7 +97,25 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
                 partition_values = None
 
             # Populate scan task with column-wise stats.
-            # TODO(Shiyan): Add support for column stats
+            schema = self._table.schema
+            min_values = hudi_table_metadata.colstats_min_values
+            max_values = hudi_table_metadata.colstats_max_values
+            arrays = {}
+            for field_idx in range(len(schema)):
+                field_name = schema.field(field_idx).name
+                field_type = schema.field(field_idx).type
+                try:
+                    arrow_arr = pa.array(
+                        [min_values[field_name][task_idx], max_values[field_name][task_idx]], type=field_type
+                    )
+                except (pa.ArrowInvalid, pa.ArrowTypeError, pa.ArrowNotImplementedError):
+                    # pyarrow < 13.0.0 doesn't accept pyarrow scalars in the array constructor.
+                    arrow_arr = pa.array(
+                        [min_values[field_name][task_idx].as_py(), max_values[field_name][task_idx].as_py()],
+                        type=field_type,
+                    )
+                arrays[field_name] = daft.Series.from_arrow(arrow_arr, field_name)
+            stats = daft.table.Table.from_pydict(arrays)._table
 
             st = ScanTask.catalog_scan_task(
                 file=path,
@@ -108,7 +126,7 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
                 size_bytes=size_bytes,
                 pushdowns=pushdowns,
                 partition_values=partition_values,
-                stats=None,
+                stats=stats,
             )
             if st is None:
                 continue
diff --git a/daft/hudi/pyhudi/filegroup.py b/daft/hudi/pyhudi/filegroup.py
index 673d0e2d07..1713dfa65c 100644
--- a/daft/hudi/pyhudi/filegroup.py
+++ b/daft/hudi/pyhudi/filegroup.py
@@ -6,40 +6,67 @@
 from fsspec import AbstractFileSystem
 from sortedcontainers import SortedDict
 
+from daft.hudi.pyhudi.utils import FsFileMetadata
+
 
 @dataclass(init=False)
 class BaseFile:
-    def __init__(self, path: str, size: int, num_records: int, fs: AbstractFileSystem):
-        self.path = path
-        self.size = size
-        self.num_records = num_records
-        file_name = path.rsplit(fs.sep, 1)[-1]
+    def __init__(self, metadata: FsFileMetadata, fs: AbstractFileSystem):
+        self.metadata = metadata
+        file_name = metadata.path.rsplit(fs.sep, 1)[-1]
         self.file_name = file_name
         file_group_id, _, commit_time_ext = file_name.split("_")
         self.file_group_id = file_group_id
         self.commit_time = commit_time_ext.split(".")[0]
 
+    @property
+    def path(self) -> str:
+        return self.metadata.path
+
+    @property
+    def size(self) -> int:
+        return self.metadata.size
+
+    @property
+    def num_records(self) -> int:
+        return self.metadata.num_records
+
+    @property
+    def schema(self) -> pa.Schema:
+        return self.metadata.schema
+
+    @property
+    def min_values(self):
+        return self.metadata.min_values
+
+    @property
+    def max_values(self):
+        return self.metadata.max_values
+
 
 @dataclass
 class FileSlice:
-    METADATA_SCHEMA = pa.schema(
+    FILES_METADATA_SCHEMA = pa.schema(
         [
             ("path", pa.string()),
             ("size", pa.uint32()),
             ("num_records", pa.uint32()),
             ("partition_path", pa.string()),
-            # TODO(Shiyan): support column stats
         ]
     )
 
     file_group_id: str
     partition_path: str
     base_instant_time: str
-    base_file: BaseFile | None
+    base_file: BaseFile
+
+    @property
+    def files_metadata(self):
+        return self.base_file.path, self.base_file.size, self.base_file.num_records, self.partition_path
 
     @property
-    def metadata(self):
-        return (self.base_file.path, self.base_file.size, self.base_file.num_records, self.partition_path)
+    def colstats_min_max(self):
+        return self.base_file.min_values, self.base_file.max_values
 
 
 @dataclass
@@ -50,10 +77,10 @@ class FileGroup:
 
     def add_base_file(self, base_file: BaseFile):
         ct = base_file.commit_time
-        if ct not in self.file_slices:
-            self.file_slices[ct] = FileSlice(self.file_group_id, self.partition_path, ct, base_file=None)
-
-        self.file_slices.get(ct).base_file = base_file
+        if ct in self.file_slices:
+            self.file_slices.get(ct).base_file = base_file
+        else:
+            self.file_slices[ct] = FileSlice(self.file_group_id, self.partition_path, ct, base_file)
 
     def get_latest_file_slice(self) -> FileSlice | None:
         if not self.file_slices:
diff --git a/daft/hudi/pyhudi/table.py b/daft/hudi/pyhudi/table.py
index fd5a3b9098..1b32299b72 100644
--- a/daft/hudi/pyhudi/table.py
+++ b/daft/hudi/pyhudi/table.py
@@ -10,7 +10,11 @@
 
 from daft.hudi.pyhudi.filegroup import BaseFile, FileGroup, FileSlice
 from daft.hudi.pyhudi.timeline import Timeline
-from daft.hudi.pyhudi.utils import get_full_file_paths, get_full_sub_dirs, get_leaf_dirs
+from daft.hudi.pyhudi.utils import (
+    list_full_file_paths,
+    list_full_sub_dirs,
+    list_leaf_dirs,
+)
 
 # TODO(Shiyan): support base file in .orc
 BASE_FILE_EXTENSIONS = [".parquet"]
@@ -28,11 +32,11 @@ def get_active_timeline(self) -> Timeline:
         return self.timeline
 
     def get_partition_paths(self, relative=True) -> list[str]:
-        first_level_full_partition_paths = get_full_sub_dirs(self.base_path, self.fs, excludes=[".hoodie"])
+        first_level_full_partition_paths = list_full_sub_dirs(self.base_path, self.fs, excludes=[".hoodie"])
         partition_paths = []
         common_prefix_len = len(self.base_path) + 1 if relative else 0
         for p in first_level_full_partition_paths:
-            partition_paths.extend(get_leaf_dirs(p, self.fs, common_prefix_len))
+            partition_paths.extend(list_leaf_dirs(p, self.fs, common_prefix_len))
         return partition_paths
 
     def get_full_partition_path(self, partition_path: str) -> str:
@@ -40,10 +44,10 @@ def get_full_partition_path(self, partition_path: str) -> str:
 
     def get_file_groups(self, partition_path: str) -> list[FileGroup]:
         full_partition_path = self.get_full_partition_path(partition_path)
-        base_file_metadata = get_full_file_paths(full_partition_path, self.fs, includes=BASE_FILE_EXTENSIONS)
+        base_file_metadata = list_full_file_paths(full_partition_path, self.fs, includes=BASE_FILE_EXTENSIONS)
         fg_id_to_base_files = defaultdict(list)
         for metadata in base_file_metadata:
-            base_file = BaseFile(metadata.path, metadata.size, metadata.num_records, self.fs)
+            base_file = BaseFile(metadata, self.fs)
             fg_id_to_base_files[base_file.file_group_id].append(base_file)
         file_groups = []
         for fg_id, base_files in fg_id_to_base_files.items():
@@ -103,6 +107,14 @@ def partition_fields(self) -> list[str]:
         return self._props["hoodie.table.partition.fields"]
 
 
+@dataclass
+class HudiTableMetadata:
+
+    files_metadata: pa.RecordBatch
+    colstats_min_values: pa.RecordBatch
+    colstats_max_values: pa.RecordBatch
+
+
 @dataclass(init=False)
 class HudiTable:
     def __init__(self, table_uri: str, storage_options: dict[str, str] | None = None):
@@ -110,14 +122,24 @@ def __init__(self, table_uri: str, storage_options: dict[str, str] | None = None):
         self._meta_client = MetaClient(fs, table_uri, timeline=None)
         self._props = HudiTableProps(fs, table_uri)
 
-    def latest_files_metadata(self) -> pa.RecordBatch:
-        fs_view = FileSystemView(self._meta_client)
-        file_slices = fs_view.get_latest_file_slices()
-        metadata = []
+    def latest_table_metadata(self) -> HudiTableMetadata:
+        file_slices = FileSystemView(self._meta_client).get_latest_file_slices()
+        files_metadata = []
+        min_vals_arr = []
+        max_vals_arr = []
         for file_slice in file_slices:
-            metadata.append(file_slice.metadata)
-        metadata_arrays = [pa.array(column) for column in list(zip(*metadata))]
-        return pa.RecordBatch.from_arrays(metadata_arrays, schema=FileSlice.METADATA_SCHEMA)
+            files_metadata.append(file_slice.files_metadata)
+            min_vals, max_vals = file_slice.colstats_min_max
+            min_vals_arr.append(min_vals)
+            max_vals_arr.append(max_vals)
+        metadata_arrays = [pa.array(column) for column in list(zip(*files_metadata))]
+        min_value_arrays = [pa.array(column) for column in list(zip(*min_vals_arr))]
+        max_value_arrays = [pa.array(column) for column in list(zip(*max_vals_arr))]
+        return HudiTableMetadata(
+            pa.RecordBatch.from_arrays(metadata_arrays, schema=FileSlice.FILES_METADATA_SCHEMA),
+            pa.RecordBatch.from_arrays(min_value_arrays, schema=self.schema),
+            pa.RecordBatch.from_arrays(max_value_arrays, schema=self.schema),
+        )
 
     @property
     def table_uri(self) -> str:
diff --git a/daft/hudi/pyhudi/utils.py b/daft/hudi/pyhudi/utils.py
index 862f8635bc..96f684f305 100644
--- a/daft/hudi/pyhudi/utils.py
+++ b/daft/hudi/pyhudi/utils.py
@@ -3,6 +3,7 @@
 import os
 from dataclasses import dataclass
 
+import pyarrow as pa
 import pyarrow.parquet as pq
 from fsspec import AbstractFileSystem
 
@@ -18,9 +19,28 @@ def __init__(self, path: str):
         metadata = pq.read_metadata(path)
         self.size = metadata.serialized_size
         self.num_records = metadata.num_rows
-
-
-def get_full_file_paths(path: str, fs: AbstractFileSystem, includes: list[str] | None) -> list[FsFileMetadata]:
+        self.schema, self.min_values, self.max_values = FsFileMetadata._extract_min_max(metadata)
+
+    @staticmethod
+    def _extract_min_max(metadata: pq.FileMetaData):
+        arrow_schema = pa.schema(metadata.schema.to_arrow_schema())
+        n_columns = len(arrow_schema)
+        min_vals = [None] * n_columns
+        max_vals = [None] * n_columns
+        num_rg = metadata.num_row_groups
+        for rg in range(num_rg):
+            row_group = metadata.row_group(rg)
+            for col in range(n_columns):
+                column = row_group.column(col)
+                if column.is_stats_set and column.statistics.has_min_max:
+                    if min_vals[col] is None or column.statistics.min < min_vals[col]:
+                        min_vals[col] = column.statistics.min
+                    if max_vals[col] is None or column.statistics.max > max_vals[col]:
+                        max_vals[col] = column.statistics.max
+        return arrow_schema, min_vals, max_vals
+
+
+def list_full_file_paths(path: str, fs: AbstractFileSystem, includes: list[str] | None) -> list[FsFileMetadata]:
     sub_paths = fs.ls(path, detail=True)
     file_paths = []
     for sub_path in sub_paths:
@@ -31,7 +51,7 @@ def get_full_file_paths(path: str, fs: AbstractFileSystem, includes: list[str] | None) -> list[FsFileMetadata]:
     return file_paths
 
 
-def get_full_sub_dirs(path: str, fs: AbstractFileSystem, excludes: list[str] | None) -> list[str]:
+def list_full_sub_dirs(path: str, fs: AbstractFileSystem, excludes: list[str] | None) -> list[str]:
     sub_paths = fs.ls(path, detail=True)
     sub_dirs = []
     for sub_path in sub_paths:
@@ -42,13 +62,13 @@ def get_full_sub_dirs(path: str, fs: AbstractFileSystem, excludes: list[str] | None) -> list[str]:
     return sub_dirs
 
 
-def get_leaf_dirs(path: str, fs: AbstractFileSystem, common_prefix_len=0) -> list[str]:
+def list_leaf_dirs(path: str, fs: AbstractFileSystem, common_prefix_len=0) -> list[str]:
    sub_paths = fs.ls(path, detail=True)
     leaf_dirs = []
     for sub_path in sub_paths:
         if sub_path["type"] == "directory":
-            leaf_dirs.extend(get_leaf_dirs(sub_path["name"], fs, common_prefix_len))
+            leaf_dirs.extend(list_leaf_dirs(sub_path["name"], fs, common_prefix_len))
 
     # leaf directory
     if len(leaf_dirs) == 0:
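
Note (illustrative, not part of the commit): the column-statistics flow above relies entirely on the parquet footer. A minimal standalone sketch of the same row-group min/max walk that FsFileMetadata._extract_min_max performs, using only pyarrow and a placeholder file name:

    import pyarrow.parquet as pq

    metadata = pq.read_metadata("some_base_file.parquet")  # placeholder path, not from the patch
    schema = metadata.schema.to_arrow_schema()
    min_vals = [None] * len(schema)
    max_vals = [None] * len(schema)
    for rg in range(metadata.num_row_groups):
        row_group = metadata.row_group(rg)
        for col in range(len(schema)):
            stats = row_group.column(col).statistics
            # Statistics can be absent (e.g. disabled by the writer); skip those column chunks.
            if stats is None or not stats.has_min_max:
                continue
            if min_vals[col] is None or stats.min < min_vals[col]:
                min_vals[col] = stats.min
            if max_vals[col] is None or stats.max > max_vals[col]:
                max_vals[col] = stats.max
    for field, lo, hi in zip(schema, min_vals, max_vals):
        print(field.name, lo, hi)

These per-file min/max lists are what latest_table_metadata stacks into the colstats_min_values / colstats_max_values record batches consumed by the scan tasks.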