[Data] Make Parquet tests more robust and expose Parquet logic #46944

Merged 2 commits (Aug 5, 2024)
185 changes: 103 additions & 82 deletions python/ray/data/_internal/datasource/parquet_datasource.py
@@ -95,7 +95,7 @@ class _SampleInfo:

# TODO(ekl) this is a workaround for a pyarrow serialization bug, where serializing a
# raw pyarrow file fragment causes S3 network calls.
class _SerializedFragment:
class SerializedFragment:
def __init__(self, frag: "ParquetFileFragment"):
self._data = cloudpickle.dumps(
(frag.format, frag.path, frag.filesystem, frag.partition_expression)
@@ -114,12 +114,12 @@ def deserialize(self) -> "ParquetFileFragment":

# Visible for test mocking.
def _deserialize_fragments(
serialized_fragments: List[_SerializedFragment],
serialized_fragments: List[SerializedFragment],
) -> List["pyarrow._dataset.ParquetFileFragment"]:
return [p.deserialize() for p in serialized_fragments]


def _check_for_legacy_tensor_type(schema):
def check_for_legacy_tensor_type(schema):
"""Check for the legacy tensor extension type and raise an error if found.

Ray Data uses an extension type to represent tensors in Arrow tables. Previously,
@@ -171,7 +171,6 @@ def __init__(
_check_pyarrow_version()

import pyarrow as pa
import pyarrow.parquet as pq

self._supports_distributed_reads = not _is_local_scheme(paths)
if not self._supports_distributed_reads and ray.util.client.ray.is_connected():
@@ -211,38 +210,20 @@ def __init__(
filtered_paths = set(expanded_paths) - set(paths)
if filtered_paths:
logger.info(f"Filtered out {len(filtered_paths)} paths")
else:
if len(paths) == 1:
paths = paths[0]

if dataset_kwargs is None:
dataset_kwargs = {}

try:
# The `use_legacy_dataset` parameter is deprecated in Arrow 15.
if parse_version(_get_pyarrow_version()) >= parse_version("15.0.0"):
pq_ds = pq.ParquetDataset(
paths,
**dataset_kwargs,
filesystem=filesystem,
)
else:
pq_ds = pq.ParquetDataset(
paths,
**dataset_kwargs,
filesystem=filesystem,
use_legacy_dataset=False,
)
except OSError as e:
_handle_read_os_error(e, paths)
pq_ds = get_parquet_dataset(paths, filesystem, dataset_kwargs)

if schema is None:
schema = pq_ds.schema
if columns:
schema = pa.schema(
[schema.field(column) for column in columns], schema.metadata
)

_check_for_legacy_tensor_type(schema)
check_for_legacy_tensor_type(schema)

if _block_udf is not None:
# Try to infer dataset schema by passing dummy table through UDF.
@@ -289,7 +270,7 @@ def __init__(
# NOTE: Store the custom serialized `ParquetFileFragment` to avoid unexpected
# network calls when `_ParquetDatasourceReader` is serialized. See
# `_SerializedFragment()` implementation for more details.
self._pq_fragments = [_SerializedFragment(p) for p in pq_ds.fragments]
self._pq_fragments = [SerializedFragment(p) for p in pq_ds.fragments]
self._pq_paths = [p.path for p in pq_ds.fragments]
self._meta_provider = meta_provider
self._inferred_schema = inferred_schema
@@ -302,9 +283,15 @@ def __init__(
if shuffle == "files":
self._file_metadata_shuffler = np.random.default_rng()

sample_infos = self._sample_fragments()
self._encoding_ratio = _estimate_files_encoding_ratio(sample_infos)
self._default_read_batch_size_rows = _estimate_default_read_batch_size_rows(
sample_infos = sample_fragments(
self._pq_fragments,
to_batches_kwargs=to_batch_kwargs,
columns=columns,
schema=schema,
local_scheduling=self._local_scheduling,
)
self._encoding_ratio = estimate_files_encoding_ratio(sample_infos)
self._default_read_batch_size_rows = estimate_default_read_batch_size_rows(
sample_infos
)

@@ -381,7 +368,7 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
)
read_tasks.append(
ReadTask(
lambda f=fragments: _read_fragments(
lambda f=fragments: read_fragments(
block_udf,
to_batches_kwargs,
default_read_batch_size_rows,
@@ -396,53 +383,6 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:

return read_tasks

def _sample_fragments(self) -> List[_SampleInfo]:
# Sample a few rows from Parquet files to estimate the encoding ratio.
# Launch tasks to sample multiple files remotely in parallel.
# Evenly distributed to sample N rows in i-th row group in i-th file.
# TODO(ekl/cheng) take into account column pruning.
num_files = len(self._pq_fragments)
num_samples = int(num_files * PARQUET_ENCODING_RATIO_ESTIMATE_SAMPLING_RATIO)
min_num_samples = min(
PARQUET_ENCODING_RATIO_ESTIMATE_MIN_NUM_SAMPLES, num_files
)
max_num_samples = min(
PARQUET_ENCODING_RATIO_ESTIMATE_MAX_NUM_SAMPLES, num_files
)
num_samples = max(min(num_samples, max_num_samples), min_num_samples)

# Evenly distributed to choose which file to sample, to avoid biased prediction
# if data is skewed.
file_samples = [
self._pq_fragments[idx]
for idx in np.linspace(0, num_files - 1, num_samples).astype(int).tolist()
]

sample_fragment = cached_remote_fn(_sample_fragment)
futures = []
scheduling = self._local_scheduling or "SPREAD"
for sample in file_samples:
# Sample the first rows batch in i-th file.
# Use SPREAD scheduling strategy to avoid packing many sampling tasks on
# same machine to cause OOM issue, as sampling can be memory-intensive.
futures.append(
sample_fragment.options(
scheduling_strategy=scheduling,
# Retry in case of transient errors during sampling.
retry_exceptions=[OSError],
).remote(
self._to_batches_kwargs,
self._columns,
self._schema,
sample,
)
)
sample_bar = ProgressBar("Parquet Files Sample", len(futures), unit="file")
sample_infos = sample_bar.fetch_until_complete(futures)
sample_bar.close()

return sample_infos

def get_name(self):
"""Return a human-readable name for this datasource.

@@ -455,13 +395,13 @@ def supports_distributed_reads(self) -> bool:
return self._supports_distributed_reads


def _read_fragments(
def read_fragments(
block_udf,
to_batches_kwargs,
default_read_batch_size_rows,
columns,
schema,
serialized_fragments: List[_SerializedFragment],
serialized_fragments: List[SerializedFragment],
include_paths: bool,
) -> Iterator["pyarrow.Table"]:
# This import is necessary to load the tensor extension type.
@@ -527,7 +467,7 @@ def _sample_fragment(
to_batches_kwargs,
columns,
schema,
file_fragment: _SerializedFragment,
file_fragment: SerializedFragment,
) -> _SampleInfo:
# Sample the first rows batch from file fragment `serialized_fragment`.
fragment = _deserialize_fragments_with_retry([file_fragment])[0]
@@ -570,7 +510,7 @@ def _sample_fragment(
return sample_data


def _estimate_files_encoding_ratio(sample_infos: List[_SampleInfo]) -> float:
def estimate_files_encoding_ratio(sample_infos: List[_SampleInfo]) -> float:
"""Return an estimate of the Parquet files encoding ratio.

To avoid OOMs, it is safer to return an over-estimate than an underestimate.
@@ -594,7 +534,7 @@ def compute_encoding_ratio(sample_info: _SampleInfo) -> float:
return max(ratio, PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND)


def _estimate_default_read_batch_size_rows(sample_infos: List[_SampleInfo]) -> int:
def estimate_default_read_batch_size_rows(sample_infos: List[_SampleInfo]) -> int:
def compute_batch_size_rows(sample_info: _SampleInfo) -> int:
if sample_info.actual_bytes_per_row is None:
return PARQUET_READER_ROW_BATCH_SIZE
@@ -612,3 +552,84 @@ def compute_batch_size_rows(sample_info: _SampleInfo) -> int:
)

return np.mean(list(map(compute_batch_size_rows, sample_infos)))


def get_parquet_dataset(paths, filesystem, dataset_kwargs):
import pyarrow.parquet as pq

# If you pass a list containing a single directory path to `ParquetDataset`, PyArrow
# errors with 'IsADirectoryError: Path ... points to a directory, but only file
# paths are supported'. To avoid this, we pass the directory path directly.
if len(paths) == 1:
paths = paths[0]

try:
# The `use_legacy_dataset` parameter is deprecated in Arrow 15.
if parse_version(_get_pyarrow_version()) >= parse_version("15.0.0"):
dataset = pq.ParquetDataset(
paths,
**dataset_kwargs,
filesystem=filesystem,
)
else:
dataset = pq.ParquetDataset(
paths,
**dataset_kwargs,
filesystem=filesystem,
use_legacy_dataset=False,
)
except OSError as e:
_handle_read_os_error(e, paths)

return dataset


def sample_fragments(
serialized_fragments,
*,
to_batches_kwargs,
columns,
schema,
local_scheduling=None,
) -> List[_SampleInfo]:
# Sample a few rows from Parquet files to estimate the encoding ratio.
# Launch tasks to sample multiple files remotely in parallel.
# Evenly distributed to sample N rows in i-th row group in i-th file.
# TODO(ekl/cheng) take into account column pruning.
num_files = len(serialized_fragments)
num_samples = int(num_files * PARQUET_ENCODING_RATIO_ESTIMATE_SAMPLING_RATIO)
min_num_samples = min(PARQUET_ENCODING_RATIO_ESTIMATE_MIN_NUM_SAMPLES, num_files)
max_num_samples = min(PARQUET_ENCODING_RATIO_ESTIMATE_MAX_NUM_SAMPLES, num_files)
num_samples = max(min(num_samples, max_num_samples), min_num_samples)

# Evenly distributed to choose which file to sample, to avoid biased prediction
# if data is skewed.
file_samples = [
serialized_fragments[idx]
for idx in np.linspace(0, num_files - 1, num_samples).astype(int).tolist()
]

sample_fragment = cached_remote_fn(_sample_fragment)
futures = []
scheduling = local_scheduling or "SPREAD"
for sample in file_samples:
# Sample the first rows batch in i-th file.
# Use SPREAD scheduling strategy to avoid packing many sampling tasks on
# same machine to cause OOM issue, as sampling can be memory-intensive.
futures.append(
sample_fragment.options(
scheduling_strategy=scheduling,
# Retry in case of transient errors during sampling.
retry_exceptions=[OSError],
).remote(
to_batches_kwargs,
columns,
schema,
sample,
)
)
sample_bar = ProgressBar("Parquet Files Sample", len(futures), unit="file")
sample_infos = sample_bar.fetch_until_complete(futures)
sample_bar.close()

return sample_infos
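
Because this PR makes `SerializedFragment`, `get_parquet_dataset`, `sample_fragments`, and the estimate helpers module-level, tests can import and exercise them directly. Below is a minimal sketch of such a test, assuming a pytest-style `tmp_path` fixture; the test name, file name, and assertions are illustrative and not part of this PR:

```python
import pyarrow as pa
import pyarrow.parquet as pq

from ray.data._internal.datasource.parquet_datasource import (
    SerializedFragment,
    get_parquet_dataset,
)


def test_serialized_fragment_roundtrip(tmp_path):
    # Write a small Parquet file so a real fragment can be built from it.
    table = pa.table({"x": [1, 2, 3]})
    pq.write_table(table, tmp_path / "part.parquet")

    # `get_parquet_dataset` takes a list of paths; a single directory path is
    # unwrapped internally before being handed to `pq.ParquetDataset`.
    dataset = get_parquet_dataset([str(tmp_path)], filesystem=None, dataset_kwargs={})
    fragment = dataset.fragments[0]

    # Round-trip through the serialization workaround and check the fragment survives.
    restored = SerializedFragment(fragment).deserialize()
    assert restored.path == fragment.path
```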
8 changes: 4 additions & 4 deletions python/ray/data/datasource/parquet_meta_provider.py
@@ -12,7 +12,7 @@
if TYPE_CHECKING:
import pyarrow

from ray.data._internal.datasource.parquet_datasource import _SerializedFragment
from ray.data._internal.datasource.parquet_datasource import SerializedFragment


FRAGMENTS_PER_META_FETCH = 6
@@ -131,11 +131,11 @@ def prefetch_file_metadata(
must be returned in the same order as all input file fragments, such
that `metadata[i]` always contains the metadata for `fragments[i]`.
"""
from ray.data._internal.datasource.parquet_datasource import _SerializedFragment
from ray.data._internal.datasource.parquet_datasource import SerializedFragment

if len(fragments) > PARALLELIZE_META_FETCH_THRESHOLD:
# Wrap Parquet fragments in serialization workaround.
fragments = [_SerializedFragment(fragment) for fragment in fragments]
fragments = [SerializedFragment(fragment) for fragment in fragments]
# Fetch Parquet metadata in parallel using Ray tasks.

def fetch_func(fragments):
@@ -162,7 +162,7 @@ def fetch_func(fragments):


def _fetch_metadata_serialization_wrapper(
fragments: List["_SerializedFragment"],
fragments: List["SerializedFragment"],
retry_match: Optional[List[str]],
retry_max_attempts: int,
retry_max_interval: int,
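
For context on the `parquet_meta_provider.py` changes: above a size threshold, `prefetch_file_metadata` wraps fragments in `SerializedFragment` and fetches their metadata with parallel Ray tasks. The following is a rough illustration of that fan-out pattern, not the actual implementation; the task body and chunking are simplified, and the retry handling is omitted:

```python
import ray

FRAGMENTS_PER_META_FETCH = 6  # same chunk size constant the provider defines


@ray.remote
def _fetch_metadata(serialized_fragments):
    # Deserialize each fragment and read its Parquet file metadata.
    return [frag.deserialize().metadata for frag in serialized_fragments]


def prefetch_metadata_sketch(serialized_fragments):
    # Fan out fixed-size chunks of fragments to parallel Ray tasks.
    chunks = [
        serialized_fragments[i : i + FRAGMENTS_PER_META_FETCH]
        for i in range(0, len(serialized_fragments), FRAGMENTS_PER_META_FETCH)
    ]
    refs = [_fetch_metadata.remote(chunk) for chunk in chunks]
    # Flatten the per-chunk results back into one ordered list, so that
    # metadata[i] corresponds to fragments[i].
    return [meta for chunk_metas in ray.get(refs) for meta in chunk_metas]
```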
12 changes: 8 additions & 4 deletions python/ray/data/tests/test_formats.py
@@ -133,10 +133,14 @@ def test_fsspec_filesystem(ray_start_regular_shared, tmp_path):
ds._set_uuid("data")
ds.write_parquet(out_path)

ds_df1 = pd.read_parquet(os.path.join(out_path, "data_000000_000000.parquet"))
ds_df2 = pd.read_parquet(os.path.join(out_path, "data_000001_000000.parquet"))
ds_df = pd.concat([ds_df1, ds_df2])
df = pd.concat([df1, df2])
ds_dfs = []
# `write_parquet` writes an unspecified number of files.
for path in os.listdir(out_path):
assert path.startswith("data_") and path.endswith(".parquet")
ds_dfs.append(pd.read_parquet(os.path.join(out_path, path)))

ds_df = pd.concat(ds_dfs).reset_index(drop=True)
df = pd.concat([df1, df2]).reset_index(drop=True)
assert ds_df.equals(df)


Expand Down
Loading