def random_sample(
    self, fraction: float, *, seed: Optional[int] = None
) -> "Dataset[T]":
    """Randomly sample a fraction of the elements of this dataset.

    Every row is kept or dropped by its own random draw, so the exact
    number of elements returned is not guaranteed; it is roughly
    ``fraction * total_rows``. Rows that are kept stay in their original
    order within each batch.

    Examples:
        >>> import ray
        >>> ds = ray.data.range(100) # doctest: +SKIP
        >>> ds.random_sample(0.1) # doctest: +SKIP
        >>> ds.random_sample(0.2, seed=12345) # doctest: +SKIP

    Args:
        fraction: The fraction of elements to sample, between 0 and 1
            inclusive.
        seed: If not None (0 is a valid seed), seeds Python's global
            ``random`` generator in the calling process before sampling.
            NOTE(review): the seed is applied where this method is called;
            if ``map_batches`` runs the batch function in other processes,
            confirm that the seed actually reaches them.

    Returns:
        A Dataset containing the sampled elements.

    Raises:
        ValueError: If the dataset has no blocks, if ``fraction`` is not
            within [0, 1] (NaN included), or if a batch has an unsupported
            type.
    """
    import random

    if self.num_blocks() == 0:
        raise ValueError("Cannot sample from an empty dataset.")

    # Written as a chained comparison so that NaN is rejected as well; the
    # form `fraction < 0 or fraction > 1` is False for NaN.
    if not 0 <= fraction <= 1:
        raise ValueError("Fraction must be between 0 and 1.")

    # `is not None` rather than truthiness: a seed of 0 must be honored.
    if seed is not None:
        random.seed(seed)

    def keep_row() -> bool:
        # random.random() lies in [0, 1), so a strict comparison never keeps
        # a row for fraction == 0 and always keeps it for fraction == 1.
        return random.random() < fraction

    def process_batch(batch):
        if isinstance(batch, list):
            return [row for row in batch if keep_row()]

        # Format libraries are imported only once a batch is not a plain
        # list, so simple batches do not need them.
        import pandas as pd

        if isinstance(batch, pd.DataFrame):
            # Same per-row draw as the other formats (and the same
            # generator), selected by position so any index works.
            kept = [i for i in range(len(batch)) if keep_row()]
            return batch.iloc[kept]

        import pyarrow as pa

        if isinstance(batch, pa.Table):
            mask = [keep_row() for _ in range(len(batch))]
            # The explicit type keeps the mask boolean for an empty table.
            return batch.filter(pa.array(mask, type=pa.bool_()))

        raise ValueError(f"Unsupported batch type: {type(batch)}")

    return self.map_batches(process_batch)
def test_random_sample(ray_start_regular_shared):
    """Run ``random_sample`` over several block formats and block layouts."""

    def check_sample_size(dataset, sample_percent=0.5):
        # Work on the dataset that was passed in, not on a name from the
        # enclosing scope, so every call checks its own argument.
        total = dataset.count()
        sampled = dataset.random_sample(sample_percent).count()
        # Sampling is random and these datasets are tiny, so only bounds
        # that hold for every possible draw are asserted here.
        assert 0 <= sampled <= total
        # A fraction of 1 keeps every row regardless of the random draws.
        assert dataset.random_sample(1.0).count() == total

    ds = ray.data.range(10, parallelism=2)
    check_sample_size(ds)

    ds = ray.data.range_arrow(10, parallelism=2)
    check_sample_size(ds)

    ds = ray.data.range_tensor(5, parallelism=2, shape=(2, 2))
    check_sample_size(ds)

    # Imbalanced datasets: blocks of 1, 2 and 3 rows.
    ds1 = ray.data.range(1, parallelism=1)
    ds2 = ray.data.range(2, parallelism=1)
    ds3 = ray.data.range(3, parallelism=1)
    # noinspection PyTypeChecker
    ds = ds1.union(ds2).union(ds3)
    check_sample_size(ds)

    # Small datasets: five rows requested across five blocks.
    ds1 = ray.data.range(5, parallelism=5)
    check_sample_size(ds1)


def test_random_sample_checks(ray_start_regular_shared):
    """Invalid fractions and empty datasets are rejected with ValueError."""
    with pytest.raises(ValueError):
        # Cannot sample a negative fraction.
        ray.data.range(1).random_sample(-1)
    with pytest.raises(ValueError):
        # Cannot sample from an empty dataset.
        ray.data.range(0).random_sample(0.2)
    with pytest.raises(ValueError):
        # Cannot sample a fraction greater than 1.
        ray.data.range(1).random_sample(10)