From d86624502b4e1f0acb4d68957fca1188947d971f Mon Sep 17 00:00:00 2001
From: Scott Lee
Date: Fri, 14 Apr 2023 12:19:13 -0700
Subject: [PATCH] [Dataset] Validate sort key in `Sort` LogicalOperator
 (#34282)

As a follow-up to #32133, validate the sort key with
block.py:_validate_key_fn() in generate_sort_fn() before sorting.

Signed-off-by: Scott Lee
---
 python/ray/data/_internal/plan.py          | 27 ++---------
 python/ray/data/_internal/planner/sort.py  |  9 +++-
 python/ray/data/_internal/stage_impl.py    |  5 +-
 python/ray/data/_internal/util.py          | 31 ++++++++++++
 python/ray/data/aggregate.py               |  2 +-
 python/ray/data/block.py                   |  9 ++--
 python/ray/data/dataset.py                 |  2 +-
 .../data/tests/test_execution_optimizer.py | 47 +++++++++++++++++++
 8 files changed, 99 insertions(+), 33 deletions(-)

diff --git a/python/ray/data/_internal/plan.py b/python/ray/data/_internal/plan.py
index 40474b68f891..cdb1bc6ec7d9 100644
--- a/python/ray/data/_internal/plan.py
+++ b/python/ray/data/_internal/plan.py
@@ -16,10 +16,10 @@
 )

 import ray
+from ray.data._internal.util import unify_block_metadata_schema
 from ray.data.block import BlockMetadata
 from ray.data._internal.util import capitalize
 from ray.types import ObjectRef
-from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas
 from ray.data._internal.block_list import BlockList
 from ray.data._internal.compute import (
     UDF,
@@ -419,29 +419,10 @@ def _get_unified_blocks_schema(
         blocks.ensure_metadata_for_first_block()
         metadata = blocks.get_metadata(fetch_if_missing=False)

-        # Some blocks could be empty, in which case we cannot get their schema.
-        # TODO(ekl) validate schema is the same across different blocks.
-        # First check if there are blocks with computed schemas, then unify
-        # valid schemas from all such blocks.
-        schemas_to_unify = []
-        for m in metadata:
-            if m.schema is not None and (m.num_rows is None or m.num_rows > 0):
-                schemas_to_unify.append(m.schema)
-        if schemas_to_unify:
-            # Check valid pyarrow installation before attempting schema unification
-            try:
-                import pyarrow as pa
-            except ImportError:
-                pa = None
-            # If the result contains PyArrow schemas, unify them
-            if pa is not None and any(
-                isinstance(s, pa.Schema) for s in schemas_to_unify
-            ):
-                return unify_schemas(schemas_to_unify)
-            # Otherwise, if the resulting schemas are simple types (e.g. int),
-            # return the first schema.
-            return schemas_to_unify[0]
+        unified_schema = unify_block_metadata_schema(metadata)
+        if unified_schema is not None:
+            return unified_schema

         if not fetch_if_missing:
             return None
         # Synchronously fetch the schema.
diff --git a/python/ray/data/_internal/planner/sort.py b/python/ray/data/_internal/planner/sort.py
index 75bf22372032..955aac05c6d2 100644
--- a/python/ray/data/_internal/planner/sort.py
+++ b/python/ray/data/_internal/planner/sort.py
@@ -14,6 +14,8 @@
 )
 from ray.data._internal.planner.exchange.sort_task_spec import SortKeyT, SortTaskSpec
 from ray.data._internal.stats import StatsDict
+from ray.data._internal.util import unify_block_metadata_schema
+from ray.data.block import _validate_key_fn
 from ray.data.context import DataContext


@@ -22,7 +24,6 @@ def generate_sort_fn(
     descending: bool,
 ) -> AllToAllTransformFn:
     """Generate function to sort blocks by the specified key column or key function."""
-    # TODO: validate key with block._validate_key_fn.

     def fn(
         key: SortKeyT,
@@ -31,11 +32,15 @@
         ctx: TaskContext,
     ) -> Tuple[List[RefBundle], StatsDict]:
         blocks = []
+        metadata = []
         for ref_bundle in refs:
-            for block, _ in ref_bundle.blocks:
+            for block, block_metadata in ref_bundle.blocks:
                 blocks.append(block)
+                metadata.append(block_metadata)
         if len(blocks) == 0:
             return (blocks, {})
+        unified_schema = unify_block_metadata_schema(metadata)
+        _validate_key_fn(unified_schema, key)

         if isinstance(key, str):
             key = [(key, "descending" if descending else "ascending")]
diff --git a/python/ray/data/_internal/stage_impl.py b/python/ray/data/_internal/stage_impl.py
index 01ee6345e177..4a89454846c8 100644
--- a/python/ray/data/_internal/stage_impl.py
+++ b/python/ray/data/_internal/stage_impl.py
@@ -328,13 +328,14 @@ def do_sort(
                 block_list.clear()
             else:
                 blocks = block_list
+            schema = ds.schema(fetch_if_missing=True)
             if isinstance(key, list):
                 if not key:
                     raise ValueError("`key` must be a list of non-zero length")
                 for subkey in key:
-                    _validate_key_fn(ds, subkey)
+                    _validate_key_fn(schema, subkey)
             else:
-                _validate_key_fn(ds, key)
+                _validate_key_fn(schema, key)
             return sort_impl(blocks, clear_input_blocks, key, descending, ctx)

         super().__init__(
diff --git a/python/ray/data/_internal/util.py b/python/ray/data/_internal/util.py
index 0edd3ab43558..0a8a7e091387 100644
--- a/python/ray/data/_internal/util.py
+++ b/python/ray/data/_internal/util.py
@@ -9,6 +9,7 @@

 import ray
 from ray.air.constants import TENSOR_COLUMN_NAME
+from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas
 from ray.data.context import DataContext
 from ray._private.utils import _get_pyarrow_version

@@ -462,3 +463,33 @@ def get_table_block_metadata(
     return BlockAccessor.for_block(table).get_metadata(
         input_files=None, exec_stats=stats.build()
     )
+
+
+def unify_block_metadata_schema(
+    metadata: List["BlockMetadata"],
+) -> Optional[Union[type, "pyarrow.lib.Schema"]]:
+    """For the input list of BlockMetadata, return a unified schema of the
+    corresponding blocks. If the metadata have no valid schema, returns None.
+    """
+    # Some blocks could be empty, in which case we cannot get their schema.
+    # TODO(ekl) validate schema is the same across different blocks.
+
+    # First check if there are blocks with computed schemas, then unify
+    # valid schemas from all such blocks.
+    schemas_to_unify = []
+    for m in metadata:
+        if m.schema is not None and (m.num_rows is None or m.num_rows > 0):
+            schemas_to_unify.append(m.schema)
+    if schemas_to_unify:
+        # Check valid pyarrow installation before attempting schema unification
+        try:
+            import pyarrow as pa
+        except ImportError:
+            pa = None
+        # If the result contains PyArrow schemas, unify them
+        if pa is not None and any(isinstance(s, pa.Schema) for s in schemas_to_unify):
+            return unify_schemas(schemas_to_unify)
+        # Otherwise, if the resulting schemas are simple types (e.g. int),
+        # return the first schema.
+        return schemas_to_unify[0]
+    return None
diff --git a/python/ray/data/aggregate.py b/python/ray/data/aggregate.py
index 0d8ee4117226..d93f086e8545 100644
--- a/python/ray/data/aggregate.py
+++ b/python/ray/data/aggregate.py
@@ -91,7 +91,7 @@ def _set_key_fn(self, on: KeyFn):
         self._key_fn = on

     def _validate(self, ds: "Datastream") -> None:
-        _validate_key_fn(ds, self._key_fn)
+        _validate_key_fn(ds.schema(fetch_if_missing=True), self._key_fn)


 @PublicAPI
diff --git a/python/ray/data/block.py b/python/ray/data/block.py
index 31c898e8b0aa..f4b79bd17700 100644
--- a/python/ray/data/block.py
+++ b/python/ray/data/block.py
@@ -40,7 +40,6 @@
     import pandas
     import pyarrow

-    from ray.data import Datastream
    from ray.data._internal.block_builder import BlockBuilder
     from ray.data.aggregate import AggregateFn

@@ -58,9 +57,11 @@
 KeyFn = Union[None, str, Callable[[T], Any]]


-def _validate_key_fn(ds: "Datastream", key: KeyFn) -> None:
-    """Check the key function is valid on the given datastream."""
-    schema = ds.schema(fetch_if_missing=True)
+def _validate_key_fn(
+    schema: Optional[Union[type, "pyarrow.lib.Schema"]],
+    key: KeyFn,
+) -> None:
+    """Check the key function is valid on the given schema."""
     if schema is None:
         # Datastream is empty/cleared, validation not possible.
         return
diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 0beaad4739b1..04c615008a94 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -1718,7 +1718,7 @@ def groupby(self, key: Optional[KeyFn]) -> "GroupedData[T]":
         # Always allow None since groupby interprets that as grouping all
         # records into a single global group.
         if key is not None:
-            _validate_key_fn(self, key)
+            _validate_key_fn(self.schema(fetch_if_missing=True), key)

         return GroupedData(self, key)

diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py
index 3f728d48e2c4..df2560d5a050 100644
--- a/python/ray/data/tests/test_execution_optimizer.py
+++ b/python/ray/data/tests/test_execution_optimizer.py
@@ -643,6 +643,53 @@ def test_sort_e2e(
     # assert [d["one"] for d in r2] == list(reversed(range(100)))


+def test_sort_validate_keys(
+    ray_start_regular_shared,
+    enable_optimizer,
+):
+    ds = ray.data.range(10)
+    assert ds.sort().take_all() == list(range(10))
+
+    invalid_col_name = "invalid_column"
+    with pytest.raises(
+        ValueError,
+        match=f"String key '{invalid_col_name}' requires datastream format to be "
+        "'arrow' or 'pandas', was 'simple'",
+    ):
+        ds.sort(invalid_col_name).take_all()
+
+    ds_named = ray.data.from_items(
+        [
+            {"col1": 1, "col2": 2},
+            {"col1": 3, "col2": 4},
+            {"col1": 5, "col2": 6},
+            {"col1": 7, "col2": 8},
+        ]
+    )
+
+    ds_sorted_col1 = ds_named.sort("col1", descending=True)
+    r1 = ds_sorted_col1.select_columns(["col1"]).take_all()
+    r2 = ds_sorted_col1.select_columns(["col2"]).take_all()
+    assert [d["col1"] for d in r1] == [7, 5, 3, 1]
+    assert [d["col2"] for d in r2] == [8, 6, 4, 2]
+
+    with pytest.raises(
+        ValueError,
+        match=f"The column '{invalid_col_name}' does not exist in the schema",
+    ):
+        ds_named.sort(invalid_col_name).take_all()
+
+    def dummy_sort_fn(x):
+        return x
+
+    with pytest.raises(
+        ValueError,
+        match=f"Callable key '{dummy_sort_fn}' requires datastream format to be "
+        "'simple'",
+    ):
+        ds_named.sort(dummy_sort_fn).take_all()
+
+
 def test_aggregate_operator(ray_start_regular_shared, enable_optimizer):
     planner = Planner()
     read_op = Read(ParquetDatasource())
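
For readers skimming the patch, the user-visible effect is the early ValueError exercised by test_sort_validate_keys above: an invalid sort key now fails during sort planning/execution instead of deep inside the shuffle. A minimal repro sketch, not part of the patch itself; it assumes Ray at this commit, the error messages are quoted from the test's match strings, and the column names are illustrative:

    import ray

    # Simple-format datastream: string sort keys are rejected because the
    # unified block schema is a plain Python type, not an Arrow/pandas schema.
    ds = ray.data.range(10)
    ds.sort("some_column").take_all()
    # ValueError: String key 'some_column' requires datastream format to be
    # 'arrow' or 'pandas', was 'simple'

    # Tabular datastream: a key missing from the unified schema is rejected.
    ds_named = ray.data.from_items([{"col1": 1, "col2": 2}])
    ds_named.sort("no_such_column").take_all()
    # ValueError: The column 'no_such_column' does not exist in the schema

Note that .take_all() (or another execution trigger) is needed to surface the error, since the validation in generate_sort_fn() runs when the lazily planned sort executes.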