From f0e40b86e77b31f82eacb3a55d8d3bcc912b1feb Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Thu, 5 May 2022 11:10:54 +0800 Subject: [PATCH 01/32] Build random_sample feature (#24449) --- python/ray/data/dataset.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 9307ef0f54ab..534f24c7cc9d 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -613,6 +613,35 @@ def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args): ) return Dataset(plan, self._epoch, self._lazy) + def random_sample(self, number: int, + *, + seed: Optional[int] = None) -> List[Any]: + """Randomly samples N elements from the dataset. + + Examples: + >>> import ray + >>> ds = ray.data.range(100) # doctest: +SKIP + >>> ds.random_sample(5) # doctest: +SKIP + >>> # Sample this dataset with a fixed random seed. + >>> ds.random_sample(5, seed=12345) # doctest: +SKIP + + + Args: + number: The number of elements to sample from the dataset. + + seed: Seeds the python random pRNG generator. + + Returns: + N elements from the shuffled dataset. + """ + import random + idx = random.randint(0, self.num_blocks()) + if idx + number >= self.num_blocks(): + idx = self.num_blocks() - number + spliced = self.split_at_indices([idx, idx + number]) + sample_choice = spliced[1] + return sample_choice.take(number) + def split( self, n: int, *, equal: bool = False, locality_hints: Optional[List[Any]] = None ) -> List["Dataset[T]"]: @@ -889,8 +918,8 @@ def build_node_id_by_actor(actors: List[Any]) -> Dict[Any, str]: actors_state = ray.state.actors() return { actor: actors_state.get(actor._actor_id.hex(), {}) - .get("Address", {}) - .get("NodeID") + .get("Address", {}) + .get("NodeID") for actor in actors } From ade595990fcf31edbe0ca83b4f5f4435aa1c5143 Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Thu, 5 May 2022 11:14:14 +0800 Subject: [PATCH 02/32] Ran scripts/format.sh --- python/ray/data/dataset.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 534f24c7cc9d..4511c8b8f857 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -613,9 +613,7 @@ def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args): ) return Dataset(plan, self._epoch, self._lazy) - def random_sample(self, number: int, - *, - seed: Optional[int] = None) -> List[Any]: + def random_sample(self, number: int, *, seed: Optional[int] = None) -> List[Any]: """Randomly samples N elements from the dataset. Examples: @@ -635,6 +633,7 @@ def random_sample(self, number: int, N elements from the shuffled dataset. 
""" import random + idx = random.randint(0, self.num_blocks()) if idx + number >= self.num_blocks(): idx = self.num_blocks() - number @@ -918,8 +917,8 @@ def build_node_id_by_actor(actors: List[Any]) -> Dict[Any, str]: actors_state = ray.state.actors() return { actor: actors_state.get(actor._actor_id.hex(), {}) - .get("Address", {}) - .get("NodeID") + .get("Address", {}) + .get("NodeID") for actor in actors } From d54f8ff6c1dcb6017ad50145c45d5b354adedbbb Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Thu, 5 May 2022 11:56:49 +0800 Subject: [PATCH 03/32] Make random_sample more random --- python/ray/data/dataset.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 4511c8b8f857..0c165f220687 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -630,16 +630,20 @@ def random_sample(self, number: int, *, seed: Optional[int] = None) -> List[Any] seed: Seeds the python random pRNG generator. Returns: - N elements from the shuffled dataset. + N elements, randomly sampled from the dataset. """ import random - idx = random.randint(0, self.num_blocks()) - if idx + number >= self.num_blocks(): - idx = self.num_blocks() - number - spliced = self.split_at_indices([idx, idx + number]) - sample_choice = spliced[1] - return sample_choice.take(number) + if seed: + random.seed(seed) + + rows = self._meta_count() + + n_required = rows // self.num_blocks() + + sample_population = self.map_batches(lambda batch: random.sample(batch, n_required)) + + return sample_population.take(number) def split( self, n: int, *, equal: bool = False, locality_hints: Optional[List[Any]] = None From 93f9dfb246a63dbb74d4175afde027e137d0ac3b Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Thu, 5 May 2022 12:03:18 +0800 Subject: [PATCH 04/32] Add some dataset validity checks --- python/ray/data/dataset.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 0c165f220687..0716e9f5de2a 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -634,14 +634,21 @@ def random_sample(self, number: int, *, seed: Optional[int] = None) -> List[Any] """ import random + if self.num_blocks() == 0: + raise ValueError("Cannot from an empty dataset") + + if number < 1: + raise ValueError("Cannot sample less than 1 element.") + if seed: random.seed(seed) - rows = self._meta_count() - - n_required = rows // self.num_blocks() + def process_batch(batch): + rows = self._meta_count() + n_required = rows // self.num_blocks() + return random.sample(batch, n_required) - sample_population = self.map_batches(lambda batch: random.sample(batch, n_required)) + sample_population = self.map_batches(process_batch) return sample_population.take(number) From 098426c0e6f9dcf1f1460e46debdd176891d3b5a Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Thu, 5 May 2022 12:19:06 +0800 Subject: [PATCH 05/32] Fix random_sample() for non-list types --- python/ray/data/dataset.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 0716e9f5de2a..12808851808e 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -646,6 +646,10 @@ def random_sample(self, number: int, *, seed: Optional[int] = None) -> List[Any] def process_batch(batch): rows = self._meta_count() n_required = rows // self.num_blocks() + + if not isinstance(batch, list): + # Should handle dataframes and 
tensors + return batch.sample(n_required) return random.sample(batch, n_required) sample_population = self.map_batches(process_batch) From 72da35d110d59dc60e3faab17117275281e98690 Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Thu, 5 May 2022 12:42:01 +0800 Subject: [PATCH 06/32] Account for possibly attempting to sample n from a batch for x elements, n > x --- python/ray/data/dataset.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 12808851808e..c049641e2f7f 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -640,17 +640,23 @@ def random_sample(self, number: int, *, seed: Optional[int] = None) -> List[Any] if number < 1: raise ValueError("Cannot sample less than 1 element.") + count = self._meta_count() + + if number > count: + raise ValueError(f"Cannot sample more elements than there are in the dataset") + if seed: random.seed(seed) - def process_batch(batch): - rows = self._meta_count() - n_required = rows // self.num_blocks() + n_required = number // self.num_blocks() + n_required += 1 if number % self.num_blocks() else 0 + def process_batch(batch): if not isinstance(batch, list): # Should handle dataframes and tensors return batch.sample(n_required) - return random.sample(batch, n_required) + # Prevent sampling more than the batch can handle + return random.sample(batch, min(len(batch), n_required)) sample_population = self.map_batches(process_batch) From d76fb07e74ac073232664e0c17a7f0c0aac4f29c Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Thu, 5 May 2022 12:43:59 +0800 Subject: [PATCH 07/32] Run scripts/format.sh --- python/ray/data/dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index c049641e2f7f..e92aa0654ed2 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -643,7 +643,9 @@ def random_sample(self, number: int, *, seed: Optional[int] = None) -> List[Any] count = self._meta_count() if number > count: - raise ValueError(f"Cannot sample more elements than there are in the dataset") + raise ValueError( + "Cannot sample more elements than there are in the dataset" + ) if seed: random.seed(seed) From 61791dd43d35c7f47bfd2e35b3d8b0344e1f2997 Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Fri, 6 May 2022 11:15:52 +0800 Subject: [PATCH 08/32] Update sampling algorithm, updated documentation to explain sampling method --- python/ray/data/dataset.py | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index e92aa0654ed2..e2e5903e5c53 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -616,6 +616,10 @@ def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args): def random_sample(self, number: int, *, seed: Optional[int] = None) -> List[Any]: """Randomly samples N elements from the dataset. + This uniformly samples elements from the dataset. + From each block, n elements are taken uniformly where n is proportionate to the fraction of the rows in that block + as compared to the total number of rows in the dataset + Examples: >>> import ray >>> ds = ray.data.range(100) # doctest: +SKIP @@ -633,6 +637,7 @@ def random_sample(self, number: int, *, seed: Optional[int] = None) -> List[Any] N elements, randomly sampled from the dataset. 
""" import random + import math if self.num_blocks() == 0: raise ValueError("Cannot from an empty dataset") @@ -650,15 +655,32 @@ def random_sample(self, number: int, *, seed: Optional[int] = None) -> List[Any] if seed: random.seed(seed) - n_required = number // self.num_blocks() - n_required += 1 if number % self.num_blocks() else 0 - def process_batch(batch): + """ + Processes a batch of inputs + Args: + batch: The batch to process + + Returns: + Randomly sampled elements from the batch + This algorithm uniformly samples elements from the batch based on how many rows that + batch contains with respect to the total number of rows + """ + + # Sample size algorithm: + # sample_size_for_this_batch = ceiling ( + # (rows_in_this_batch / + # total_rows) * samples_wanted + # ) + + sample_size = (len(batch) / count) * number + sample_size = math.ceil(sample_size) + if not isinstance(batch, list): - # Should handle dataframes and tensors - return batch.sample(n_required) + # Provides handling for dataframes and tensors + return batch.sample(sample_size) # Prevent sampling more than the batch can handle - return random.sample(batch, min(len(batch), n_required)) + return random.sample(batch, min(len(batch), sample_size)) sample_population = self.map_batches(process_batch) @@ -940,8 +962,8 @@ def build_node_id_by_actor(actors: List[Any]) -> Dict[Any, str]: actors_state = ray.state.actors() return { actor: actors_state.get(actor._actor_id.hex(), {}) - .get("Address", {}) - .get("NodeID") + .get("Address", {}) + .get("NodeID") for actor in actors } From c7b9f55d295c8b075f294b3a0e8171817804cd50 Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Fri, 6 May 2022 11:16:53 +0800 Subject: [PATCH 09/32] Run format script --- python/ray/data/dataset.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index e2e5903e5c53..641f2f5d45c9 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -617,8 +617,10 @@ def random_sample(self, number: int, *, seed: Optional[int] = None) -> List[Any] """Randomly samples N elements from the dataset. This uniformly samples elements from the dataset. - From each block, n elements are taken uniformly where n is proportionate to the fraction of the rows in that block - as compared to the total number of rows in the dataset + + From each block, n elements are taken uniformly where n is proportionate + to the fraction of the rows in that block as compared to the total number + of rows in the dataset. The result is truncated to *number* elements. 
Examples: >>> import ray @@ -962,8 +964,8 @@ def build_node_id_by_actor(actors: List[Any]) -> Dict[Any, str]: actors_state = ray.state.actors() return { actor: actors_state.get(actor._actor_id.hex(), {}) - .get("Address", {}) - .get("NodeID") + .get("Address", {}) + .get("NodeID") for actor in actors } From cfa92911d88f2581f980558918d1f38d2a59304e Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Fri, 6 May 2022 11:29:20 +0800 Subject: [PATCH 10/32] Add a random_shuffle on the sample_population since .take() will leave it less random --- python/ray/data/dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 641f2f5d45c9..3eca6e979bc8 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -685,6 +685,7 @@ def process_batch(batch): return random.sample(batch, min(len(batch), sample_size)) sample_population = self.map_batches(process_batch) + sample_population.random_shuffle(seed=seed, num_blocks=None) return sample_population.take(number) From 52d306364dede505f286e0ab9bcbfebd1e636f7f Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Fri, 6 May 2022 11:44:53 +0800 Subject: [PATCH 11/32] run format script --- python/ray/data/dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 3eca6e979bc8..ba2df407c2aa 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -665,8 +665,9 @@ def process_batch(batch): Returns: Randomly sampled elements from the batch - This algorithm uniformly samples elements from the batch based on how many rows that - batch contains with respect to the total number of rows + This algorithm uniformly samples elements from the batch based + on how many rows that batch contains with respect to the total + number of rows """ # Sample size algorithm: From b970866ab48b8bc3f0dd95daed837c15c1066f83 Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Fri, 6 May 2022 11:45:35 +0800 Subject: [PATCH 12/32] Add tests for random sampling --- python/ray/data/tests/test_dataset.py | 32 +++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index a2e45ced3718..fa0b3a45bdb2 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -3566,6 +3566,38 @@ def get_node_id(): assert set(locations) == {node1_id, node2_id} +def test_random_sample(): + ds = ray.data.range(10, parallelism=2) + r1 = ds.random_sample(4) + assert len(r1) == 4 + + # "weird" datasets + ds1 = ray.data.range(1, parallelism=1) + ds2 = ray.data.range(2, parallelism=1) + ds3 = ray.data.range(3, parallelism=1) + ds = ds1.union(ds2).union(ds3) + r2 = ds.random_sample(3) + assert len(r2) == 3 + + +def test_random_sample_spread(): + # TODO: Check for non-contiguity + pass + + +def test_random_sample_checks(): + with pytest.raises(ValueError) as e_info: + # Obviously, you cannot sample -1 elements + ray.data.range(1).random_sample(-1) + + # Neither should you be able to sample an empty dataset + ray.data.range(0).random_sample(1) + + # No sampling more elements than the dataset contains + + ray.data.range(2).random_sample(3) + + def test_parquet_read_spread(ray_start_cluster, tmp_path): cluster = ray_start_cluster cluster.add_node( From c31f88e32630e66273f3b8e6bf64fde7494cd47f Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Fri, 6 May 2022 12:29:50 +0800 Subject: [PATCH 13/32] Add another sampling strategy --- 
python/ray/data/dataset.py | 93 ++++++++++++++++----------- python/ray/data/tests/test_dataset.py | 18 +++++- 2 files changed, 74 insertions(+), 37 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index ba2df407c2aa..7761ff4df3e9 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -613,14 +613,14 @@ def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args): ) return Dataset(plan, self._epoch, self._lazy) - def random_sample(self, number: int, *, seed: Optional[int] = None) -> List[Any]: + def random_sample( + self, number: int, *, seed: Optional[int] = None, sampling_strategy=0 + ) -> List[Any]: """Randomly samples N elements from the dataset. This uniformly samples elements from the dataset. - From each block, n elements are taken uniformly where n is proportionate - to the fraction of the rows in that block as compared to the total number - of rows in the dataset. The result is truncated to *number* elements. + Examples: >>> import ray @@ -635,6 +635,12 @@ def random_sample(self, number: int, *, seed: Optional[int] = None) -> List[Any] seed: Seeds the python random pRNG generator. + sampling_strategy: The sampling strategy to use + 0 is the default. From each block, n elements are taken uniformly where n is proportionate + to the fraction of the rows in that block as compared to the total number + of rows in the dataset. The result is truncated to *number* elements. + 1 generates N indices to sample the data from by uniformly sampling the indices from range [0, num_rows-1] + Returns: N elements, randomly sampled from the dataset. """ @@ -657,38 +663,53 @@ def random_sample(self, number: int, *, seed: Optional[int] = None) -> List[Any] if seed: random.seed(seed) - def process_batch(batch): - """ - Processes a batch of inputs - Args: - batch: The batch to process - - Returns: - Randomly sampled elements from the batch - This algorithm uniformly samples elements from the batch based - on how many rows that batch contains with respect to the total - number of rows - """ - - # Sample size algorithm: - # sample_size_for_this_batch = ceiling ( - # (rows_in_this_batch / - # total_rows) * samples_wanted - # ) - - sample_size = (len(batch) / count) * number - sample_size = math.ceil(sample_size) - - if not isinstance(batch, list): - # Provides handling for dataframes and tensors - return batch.sample(sample_size) - # Prevent sampling more than the batch can handle - return random.sample(batch, min(len(batch), sample_size)) - - sample_population = self.map_batches(process_batch) - sample_population.random_shuffle(seed=seed, num_blocks=None) - - return sample_population.take(number) + if sampling_strategy not in [0, 1]: + raise ValueError("Sampling strategy must be 0 or 1") + + if sampling_strategy == 0: + # Uniform sampling strategy + def process_batch(batch): + """ + Processes a batch of inputs + Args: + batch: The batch to process + + Returns: + Randomly sampled elements from the batch + This algorithm uniformly samples elements from the batch based + on how many rows that batch contains with respect to the total + number of rows + """ + + # Sample size algorithm: + # sample_size_for_this_batch = ceiling ( + # (rows_in_this_batch / + # total_rows) * samples_wanted + # ) + + sample_size = (len(batch) / count) * number + sample_size = math.ceil(sample_size) + + if not isinstance(batch, list): + # Provides handling for dataframes and tensors + return batch.sample(sample_size) + # Prevent sampling more than the batch can 
handle + return random.sample(batch, min(len(batch), sample_size)) + + sample_population = self.map_batches(process_batch) + sample_population.random_shuffle(seed=seed, num_blocks=None) + + return sample_population.take(number) + elif sampling_strategy == 1: + # Indices generating strategy + # TODO: This strategy may fail if the block size is 0 + indices = random.sample(range(0, count), number) + indices.sort() + spliced = self.split_at_indices(indices)[1:] + output = [] + for ds in spliced: + output.append(ds.take(1)[0]) + return output def split( self, n: int, *, equal: bool = False, locality_hints: Optional[List[Any]] = None diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index fa0b3a45bdb2..c1fe3c5c9828 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -3569,18 +3569,32 @@ def get_node_id(): def test_random_sample(): ds = ray.data.range(10, parallelism=2) r1 = ds.random_sample(4) + r2 = ds.random_sample(4, sampling_strategy=1) assert len(r1) == 4 + assert len(r2) == 4 # "weird" datasets ds1 = ray.data.range(1, parallelism=1) ds2 = ray.data.range(2, parallelism=1) ds3 = ray.data.range(3, parallelism=1) + # noinspection PyTypeChecker ds = ds1.union(ds2).union(ds3) + r1 = ds.random_sample(4, sampling_strategy=0) r2 = ds.random_sample(3) + assert len(r1) == 4 assert len(r2) == 3 def test_random_sample_spread(): + def is_continuous(x): + prev = x[0] + for e in x: + if e == prev + 1: + return True + prev = e + return False + + ds = ray.data.range(50) # TODO: Check for non-contiguity pass @@ -3594,9 +3608,11 @@ def test_random_sample_checks(): ray.data.range(0).random_sample(1) # No sampling more elements than the dataset contains - ray.data.range(2).random_sample(3) + # Invalid sampling strategy + ray.data.range(1).random_sample(1, sampling_strategy=42) + def test_parquet_read_spread(ray_start_cluster, tmp_path): cluster = ray_start_cluster From e5b4bcc1372def6c39a905cbcb0a7cf3f5daee21 Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Fri, 6 May 2022 12:33:56 +0800 Subject: [PATCH 14/32] Add tests for other dataset types --- python/ray/data/tests/test_dataset.py | 137 ++++++++++++++------------ 1 file changed, 72 insertions(+), 65 deletions(-) diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index c1fe3c5c9828..ff7f0285821b 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -442,8 +442,8 @@ def test_tensors(ray_start_regular_shared): # Pandas conversion. res = ( ray.data.range_tensor(10) - .map_batches(lambda t: t + 2, batch_format="pandas") - .take(2) + .map_batches(lambda t: t + 2, batch_format="pandas") + .take(2) ) assert str(res) == "[{'value': array([2])}, {'value': array([3])}]" @@ -1287,7 +1287,7 @@ def test_sliding_window(): assert len(windows) == len(arr) - window_size + 1 assert all(len(window) == window_size for window in windows) assert all( - list(window) == arr[i : i + window_size] for i, window in enumerate(windows) + list(window) == arr[i: i + window_size] for i, window in enumerate(windows) ) # Test window size larger than iterable length. @@ -2021,9 +2021,9 @@ def test_groupby_arrow(ray_start_regular_shared): # Test empty dataset. 
agg_ds = ( ray.data.range_arrow(10) - .filter(lambda r: r["value"] > 10) - .groupby("value") - .count() + .filter(lambda r: r["value"] > 10) + .groupby("value") + .count() ) assert agg_ds.count() == 0 @@ -2067,8 +2067,8 @@ def test_groupby_agg_name_conflict(ray_start_regular_shared, num_parts): xs = list(range(100)) grouped_ds = ( ray.data.from_items([{"A": (x % 3), "B": x} for x in xs]) - .repartition(num_parts) - .groupby("A") + .repartition(num_parts) + .groupby("A") ) agg_ds = grouped_ds.aggregate( AggregateFn( @@ -2737,9 +2737,9 @@ def test_groupby_arrow_multi_agg(ray_start_regular_shared, num_parts): df = pd.DataFrame({"A": [x % 3 for x in xs], "B": xs}) agg_ds = ( ray.data.from_pandas(df) - .repartition(num_parts) - .groupby("A") - .aggregate( + .repartition(num_parts) + .groupby("A") + .aggregate( Count(), Sum("B"), Min("B"), @@ -2764,8 +2764,8 @@ def test_groupby_arrow_multi_agg(ray_start_regular_shared, num_parts): result_row = ( ray.data.from_pandas(df) - .repartition(num_parts) - .aggregate( + .repartition(num_parts) + .aggregate( Sum("A"), Min("A"), Max("A"), @@ -2881,8 +2881,8 @@ def test_groupby_simple_sum(ray_start_regular_shared, num_parts): # Test built-in sum aggregation with nans nan_grouped_ds = ( ray.data.from_items(xs + [None]) - .repartition(num_parts) - .groupby(lambda x: int(x or 0) % 3) + .repartition(num_parts) + .groupby(lambda x: int(x or 0) % 3) ) nan_agg_ds = nan_grouped_ds.sum() assert nan_agg_ds.count() == 3 @@ -2902,9 +2902,9 @@ def test_groupby_simple_sum(ray_start_regular_shared, num_parts): # Test all nans nan_agg_ds = ( ray.data.from_items([None] * len(xs)) - .repartition(num_parts) - .groupby(lambda x: 0) - .sum() + .repartition(num_parts) + .groupby(lambda x: 0) + .sum() ) assert nan_agg_ds.count() == 1 assert nan_agg_ds.sort(key=lambda r: r[0]).take(1) == [(0, None)] @@ -2983,9 +2983,9 @@ def test_groupby_map_groups_returning_empty_result(ray_start_regular_shared, num xs = list(range(100)) mapped = ( ray.data.from_items(xs) - .repartition(num_parts) - .groupby(lambda x: x % 3) - .map_groups(lambda x: []) + .repartition(num_parts) + .groupby(lambda x: x % 3) + .map_groups(lambda x: []) ) assert mapped.count() == 0 assert mapped.take_all() == [] @@ -3000,9 +3000,9 @@ def test_groupby_map_groups_for_list(ray_start_regular_shared, num_parts): random.shuffle(xs) mapped = ( ray.data.from_items(xs) - .repartition(num_parts) - .groupby(lambda x: x % 3) - .map_groups(lambda x: [min(x) * min(x)]) + .repartition(num_parts) + .groupby(lambda x: x % 3) + .map_groups(lambda x: [min(x) * min(x)]) ) assert mapped.count() == 3 assert mapped.take_all() == [0, 1, 4] @@ -3072,8 +3072,8 @@ def test_groupby_simple_min(ray_start_regular_shared, num_parts): # Test built-in min aggregation with nans nan_grouped_ds = ( ray.data.from_items(xs + [None]) - .repartition(num_parts) - .groupby(lambda x: int(x or 0) % 3) + .repartition(num_parts) + .groupby(lambda x: int(x or 0) % 3) ) nan_agg_ds = nan_grouped_ds.min() assert nan_agg_ds.count() == 3 @@ -3085,9 +3085,9 @@ def test_groupby_simple_min(ray_start_regular_shared, num_parts): # Test all nans nan_agg_ds = ( ray.data.from_items([None] * len(xs)) - .repartition(num_parts) - .groupby(lambda x: 0) - .min() + .repartition(num_parts) + .groupby(lambda x: 0) + .min() ) assert nan_agg_ds.count() == 1 assert nan_agg_ds.sort(key=lambda r: r[0]).take(1) == [(0, None)] @@ -3123,8 +3123,8 @@ def test_groupby_simple_max(ray_start_regular_shared, num_parts): # Test built-in max aggregation with nans nan_grouped_ds = ( 
ray.data.from_items(xs + [None]) - .repartition(num_parts) - .groupby(lambda x: int(x or 0) % 3) + .repartition(num_parts) + .groupby(lambda x: int(x or 0) % 3) ) nan_agg_ds = nan_grouped_ds.max() assert nan_agg_ds.count() == 3 @@ -3136,9 +3136,9 @@ def test_groupby_simple_max(ray_start_regular_shared, num_parts): # Test all nans nan_agg_ds = ( ray.data.from_items([None] * len(xs)) - .repartition(num_parts) - .groupby(lambda x: 0) - .max() + .repartition(num_parts) + .groupby(lambda x: 0) + .max() ) assert nan_agg_ds.count() == 1 assert nan_agg_ds.sort(key=lambda r: r[0]).take(1) == [(0, None)] @@ -3174,8 +3174,8 @@ def test_groupby_simple_mean(ray_start_regular_shared, num_parts): # Test built-in mean aggregation with nans nan_grouped_ds = ( ray.data.from_items(xs + [None]) - .repartition(num_parts) - .groupby(lambda x: int(x or 0) % 3) + .repartition(num_parts) + .groupby(lambda x: int(x or 0) % 3) ) nan_agg_ds = nan_grouped_ds.mean() assert nan_agg_ds.count() == 3 @@ -3195,9 +3195,9 @@ def test_groupby_simple_mean(ray_start_regular_shared, num_parts): # Test all nans nan_agg_ds = ( ray.data.from_items([None] * len(xs)) - .repartition(num_parts) - .groupby(lambda x: 0) - .mean() + .repartition(num_parts) + .groupby(lambda x: 0) + .mean() ) assert nan_agg_ds.count() == 1 assert nan_agg_ds.sort(key=lambda r: r[0]).take(1) == [(0, None)] @@ -3239,9 +3239,9 @@ def test_groupby_simple_std(ray_start_regular_shared, num_parts): # ddof of 0 agg_ds = ( ray.data.from_items(xs) - .repartition(num_parts) - .groupby(lambda x: x % 3) - .std(ddof=0) + .repartition(num_parts) + .groupby(lambda x: x % 3) + .std(ddof=0) ) assert agg_ds.count() == 3 df = pd.DataFrame({"A": [x % 3 for x in xs], "B": xs}) @@ -3255,8 +3255,8 @@ def test_groupby_simple_std(ray_start_regular_shared, num_parts): # Test built-in std aggregation with nans nan_grouped_ds = ( ray.data.from_items(xs + [None]) - .repartition(num_parts) - .groupby(lambda x: int(x or 0) % 3) + .repartition(num_parts) + .groupby(lambda x: int(x or 0) % 3) ) nan_agg_ds = nan_grouped_ds.std() assert nan_agg_ds.count() == 3 @@ -3280,9 +3280,9 @@ def test_groupby_simple_std(ray_start_regular_shared, num_parts): # Test all nans nan_agg_ds = ( ray.data.from_items([None] * len(xs)) - .repartition(num_parts) - .groupby(lambda x: 0) - .std(ignore_nulls=False) + .repartition(num_parts) + .groupby(lambda x: 0) + .std(ignore_nulls=False) ) assert nan_agg_ds.count() == 1 expected = pd.Series([None], name="B") @@ -3328,9 +3328,9 @@ def test_groupby_simple_multilambda(ray_start_regular_shared, num_parts): random.shuffle(xs) agg_ds = ( ray.data.from_items([[x, 2 * x] for x in xs]) - .repartition(num_parts) - .groupby(lambda x: x[0] % 3) - .mean([lambda x: x[0], lambda x: x[1]]) + .repartition(num_parts) + .groupby(lambda x: x[0] % 3) + .mean([lambda x: x[0], lambda x: x[1]]) ) assert agg_ds.count() == 3 assert agg_ds.sort(key=lambda r: r[0]).take(3) == [ @@ -3357,9 +3357,9 @@ def test_groupby_simple_multi_agg(ray_start_regular_shared, num_parts): df = pd.DataFrame({"A": [x % 3 for x in xs], "B": xs}) agg_ds = ( ray.data.from_items(xs) - .repartition(num_parts) - .groupby(lambda x: x % 3) - .aggregate( + .repartition(num_parts) + .groupby(lambda x: x % 3) + .aggregate( Count(), Sum(), Min(), @@ -3396,8 +3396,8 @@ def test_groupby_simple_multi_agg(ray_start_regular_shared, num_parts): # Test built-in global multi-aggregation result_row = ( ray.data.from_items(xs) - .repartition(num_parts) - .aggregate( + .repartition(num_parts) + .aggregate( Sum(), Min(), Max(), @@ 
-3497,7 +3497,7 @@ def test_random_shuffle_check_random(shutdown_only): ds = ray.data.from_items(items, parallelism=num_files) out = ds.random_shuffle().take(num_files * num_rows) for i in range(num_files): - part = out[i * num_rows : (i + 1) * num_rows] + part = out[i * num_rows: (i + 1) * num_rows] seen = set() num_contiguous = 1 prev = -1 @@ -3523,7 +3523,7 @@ def test_random_shuffle_check_random(shutdown_only): ds = ray.data.from_items(items, parallelism=num_files) out = ds.random_shuffle().take(num_files * num_rows) for i in range(num_files): - part = out[i * num_rows : (i + 1) * num_rows] + part = out[i * num_rows: (i + 1) * num_rows] num_increasing = 0 prev = -1 for x in part: @@ -3567,11 +3567,21 @@ def get_node_id(): def test_random_sample(): + + def test(dataset, sample_size=4): + r1 = ds.random_sample(sample_size) + r2 = ds.random_sample(sample_size, sampling_strategy=1) + assert len(r1) == 4 + assert len(r2) == 4 + ds = ray.data.range(10, parallelism=2) - r1 = ds.random_sample(4) - r2 = ds.random_sample(4, sampling_strategy=1) - assert len(r1) == 4 - assert len(r2) == 4 + test(ds) + + ds = ray.data.range_arrow(10, parallelism=2) + test(ds) + + ds = ray.data.range_tensor(5, parallelism=2, shape=(2, 2)) + test(ds) # "weird" datasets ds1 = ray.data.range(1, parallelism=1) @@ -3579,10 +3589,7 @@ def test_random_sample(): ds3 = ray.data.range(3, parallelism=1) # noinspection PyTypeChecker ds = ds1.union(ds2).union(ds3) - r1 = ds.random_sample(4, sampling_strategy=0) - r2 = ds.random_sample(3) - assert len(r1) == 4 - assert len(r2) == 3 + test(ds) def test_random_sample_spread(): From e25efa9559f0170a1130c6f9bef43b70f8d0988a Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Sun, 8 May 2022 09:46:56 +0800 Subject: [PATCH 15/32] Working but the tests may fail since random_sample does not guarantee the EXACT fraction to be returned --- python/ray/data/dataset.py | 98 ++++++++------------------- python/ray/data/tests/test_dataset.py | 4 +- 2 files changed, 31 insertions(+), 71 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 7761ff4df3e9..92232261d32b 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -614,11 +614,12 @@ def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args): return Dataset(plan, self._epoch, self._lazy) def random_sample( - self, number: int, *, seed: Optional[int] = None, sampling_strategy=0 - ) -> List[Any]: - """Randomly samples N elements from the dataset. + self, fraction: float, *, seed: Optional[int] = None + ) -> "Dataset[T]": + """Randomly samples a fraction of the elements of this dataset. - This uniformly samples elements from the dataset. + Note that the fraction sampled is only approximate, and may not be + exactly the fraction specified. @@ -631,18 +632,12 @@ def random_sample( Args: - number: The number of elements to sample from the dataset. + fraction: The fraction of elements to sample. seed: Seeds the python random pRNG generator. - sampling_strategy: The sampling strategy to use - 0 is the default. From each block, n elements are taken uniformly where n is proportionate - to the fraction of the rows in that block as compared to the total number - of rows in the dataset. The result is truncated to *number* elements. - 1 generates N indices to sample the data from by uniformly sampling the indices from range [0, num_rows-1] - Returns: - N elements, randomly sampled from the dataset. 
+ Returns a dataset with approximately (fraction) of the elements of the original dataset """ import random import math @@ -650,66 +645,33 @@ def random_sample( if self.num_blocks() == 0: raise ValueError("Cannot from an empty dataset") - if number < 1: - raise ValueError("Cannot sample less than 1 element.") - - count = self._meta_count() - - if number > count: - raise ValueError( - "Cannot sample more elements than there are in the dataset" - ) + if fraction < 0 or fraction > 1: + raise ValueError("Fraction must be between 0 and 1") if seed: random.seed(seed) - if sampling_strategy not in [0, 1]: - raise ValueError("Sampling strategy must be 0 or 1") - - if sampling_strategy == 0: - # Uniform sampling strategy - def process_batch(batch): - """ - Processes a batch of inputs - Args: - batch: The batch to process - - Returns: - Randomly sampled elements from the batch - This algorithm uniformly samples elements from the batch based - on how many rows that batch contains with respect to the total - number of rows - """ - - # Sample size algorithm: - # sample_size_for_this_batch = ceiling ( - # (rows_in_this_batch / - # total_rows) * samples_wanted - # ) - - sample_size = (len(batch) / count) * number - sample_size = math.ceil(sample_size) - - if not isinstance(batch, list): - # Provides handling for dataframes and tensors - return batch.sample(sample_size) - # Prevent sampling more than the batch can handle - return random.sample(batch, min(len(batch), sample_size)) - - sample_population = self.map_batches(process_batch) - sample_population.random_shuffle(seed=seed, num_blocks=None) - - return sample_population.take(number) - elif sampling_strategy == 1: - # Indices generating strategy - # TODO: This strategy may fail if the block size is 0 - indices = random.sample(range(0, count), number) - indices.sort() - spliced = self.split_at_indices(indices)[1:] - output = [] - for ds in spliced: - output.append(ds.take(1)[0]) - return output + def process_batch(batch): + """ + Processes a batch of inputs + Args: + batch: The batch to process + + Returns: + Randomly sampled elements from the batch + """ + + if isinstance(batch, list): + s_batch = sorted(batch) + weights = [random.random() for _ in range(len(s_batch))] + _result = [] + for i, w in enumerate(weights): + if w <= fraction: + _result.append(s_batch[i]) + return _result + return batch.sample(frac=fraction) + + return self.map_batches(process_batch) def split( self, n: int, *, equal: bool = False, locality_hints: Optional[List[Any]] = None diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index ff7f0285821b..9282fb1e2ffa 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -3570,9 +3570,7 @@ def test_random_sample(): def test(dataset, sample_size=4): r1 = ds.random_sample(sample_size) - r2 = ds.random_sample(sample_size, sampling_strategy=1) - assert len(r1) == 4 - assert len(r2) == 4 + assert len(r1) == sample_size ds = ray.data.range(10, parallelism=2) test(ds) From 02f2724a427efd23ace59f851ba2f17d7ded3fdb Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Mon, 9 May 2022 09:53:26 +0800 Subject: [PATCH 16/32] Fix tests --- python/ray/data/tests/test_dataset.py | 33 +++++---------------------- 1 file changed, 6 insertions(+), 27 deletions(-) diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index 9282fb1e2ffa..fb6326c5750c 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py 
@@ -3567,10 +3567,11 @@ def get_node_id(): def test_random_sample(): + import math - def test(dataset, sample_size=4): - r1 = ds.random_sample(sample_size) - assert len(r1) == sample_size + def test(dataset, sample_percent=0.5): + r1 = ds.random_sample(sample_percent) + assert math.isclose(r1.count(), int(ds.count() * sample_percent), rel_tol=2, abs_tol=2) ds = ray.data.range(10, parallelism=2) test(ds) @@ -3590,33 +3591,11 @@ def test(dataset, sample_size=4): test(ds) -def test_random_sample_spread(): - def is_continuous(x): - prev = x[0] - for e in x: - if e == prev + 1: - return True - prev = e - return False - - ds = ray.data.range(50) - # TODO: Check for non-contiguity - pass - - def test_random_sample_checks(): with pytest.raises(ValueError) as e_info: - # Obviously, you cannot sample -1 elements ray.data.range(1).random_sample(-1) - - # Neither should you be able to sample an empty dataset - ray.data.range(0).random_sample(1) - - # No sampling more elements than the dataset contains - ray.data.range(2).random_sample(3) - - # Invalid sampling strategy - ray.data.range(1).random_sample(1, sampling_strategy=42) + ray.data.range(0).random_sample(0.2) + ray.data.range(1).random_sample(10) def test_parquet_read_spread(ray_start_cluster, tmp_path): From 894114c1d0c7572d60fc7a335a4fb270ad0f1104 Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Tue, 10 May 2022 15:11:19 +0800 Subject: [PATCH 17/32] Undo IDE auto reformat --- python/ray/data/tests/test_dataset.py | 118 +++++++++++++------------- 1 file changed, 60 insertions(+), 58 deletions(-) diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index fb6326c5750c..a4f689a0d268 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -442,8 +442,8 @@ def test_tensors(ray_start_regular_shared): # Pandas conversion. res = ( ray.data.range_tensor(10) - .map_batches(lambda t: t + 2, batch_format="pandas") - .take(2) + .map_batches(lambda t: t + 2, batch_format="pandas") + .take(2) ) assert str(res) == "[{'value': array([2])}, {'value': array([3])}]" @@ -1287,7 +1287,7 @@ def test_sliding_window(): assert len(windows) == len(arr) - window_size + 1 assert all(len(window) == window_size for window in windows) assert all( - list(window) == arr[i: i + window_size] for i, window in enumerate(windows) + list(window) == arr[i : i + window_size] for i, window in enumerate(windows) ) # Test window size larger than iterable length. @@ -2021,9 +2021,9 @@ def test_groupby_arrow(ray_start_regular_shared): # Test empty dataset. 
agg_ds = ( ray.data.range_arrow(10) - .filter(lambda r: r["value"] > 10) - .groupby("value") - .count() + .filter(lambda r: r["value"] > 10) + .groupby("value") + .count() ) assert agg_ds.count() == 0 @@ -2067,8 +2067,8 @@ def test_groupby_agg_name_conflict(ray_start_regular_shared, num_parts): xs = list(range(100)) grouped_ds = ( ray.data.from_items([{"A": (x % 3), "B": x} for x in xs]) - .repartition(num_parts) - .groupby("A") + .repartition(num_parts) + .groupby("A") ) agg_ds = grouped_ds.aggregate( AggregateFn( @@ -2737,9 +2737,9 @@ def test_groupby_arrow_multi_agg(ray_start_regular_shared, num_parts): df = pd.DataFrame({"A": [x % 3 for x in xs], "B": xs}) agg_ds = ( ray.data.from_pandas(df) - .repartition(num_parts) - .groupby("A") - .aggregate( + .repartition(num_parts) + .groupby("A") + .aggregate( Count(), Sum("B"), Min("B"), @@ -2764,8 +2764,8 @@ def test_groupby_arrow_multi_agg(ray_start_regular_shared, num_parts): result_row = ( ray.data.from_pandas(df) - .repartition(num_parts) - .aggregate( + .repartition(num_parts) + .aggregate( Sum("A"), Min("A"), Max("A"), @@ -2881,8 +2881,8 @@ def test_groupby_simple_sum(ray_start_regular_shared, num_parts): # Test built-in sum aggregation with nans nan_grouped_ds = ( ray.data.from_items(xs + [None]) - .repartition(num_parts) - .groupby(lambda x: int(x or 0) % 3) + .repartition(num_parts) + .groupby(lambda x: int(x or 0) % 3) ) nan_agg_ds = nan_grouped_ds.sum() assert nan_agg_ds.count() == 3 @@ -2902,9 +2902,9 @@ def test_groupby_simple_sum(ray_start_regular_shared, num_parts): # Test all nans nan_agg_ds = ( ray.data.from_items([None] * len(xs)) - .repartition(num_parts) - .groupby(lambda x: 0) - .sum() + .repartition(num_parts) + .groupby(lambda x: 0) + .sum() ) assert nan_agg_ds.count() == 1 assert nan_agg_ds.sort(key=lambda r: r[0]).take(1) == [(0, None)] @@ -2983,9 +2983,9 @@ def test_groupby_map_groups_returning_empty_result(ray_start_regular_shared, num xs = list(range(100)) mapped = ( ray.data.from_items(xs) - .repartition(num_parts) - .groupby(lambda x: x % 3) - .map_groups(lambda x: []) + .repartition(num_parts) + .groupby(lambda x: x % 3) + .map_groups(lambda x: []) ) assert mapped.count() == 0 assert mapped.take_all() == [] @@ -3000,9 +3000,9 @@ def test_groupby_map_groups_for_list(ray_start_regular_shared, num_parts): random.shuffle(xs) mapped = ( ray.data.from_items(xs) - .repartition(num_parts) - .groupby(lambda x: x % 3) - .map_groups(lambda x: [min(x) * min(x)]) + .repartition(num_parts) + .groupby(lambda x: x % 3) + .map_groups(lambda x: [min(x) * min(x)]) ) assert mapped.count() == 3 assert mapped.take_all() == [0, 1, 4] @@ -3072,8 +3072,8 @@ def test_groupby_simple_min(ray_start_regular_shared, num_parts): # Test built-in min aggregation with nans nan_grouped_ds = ( ray.data.from_items(xs + [None]) - .repartition(num_parts) - .groupby(lambda x: int(x or 0) % 3) + .repartition(num_parts) + .groupby(lambda x: int(x or 0) % 3) ) nan_agg_ds = nan_grouped_ds.min() assert nan_agg_ds.count() == 3 @@ -3085,9 +3085,9 @@ def test_groupby_simple_min(ray_start_regular_shared, num_parts): # Test all nans nan_agg_ds = ( ray.data.from_items([None] * len(xs)) - .repartition(num_parts) - .groupby(lambda x: 0) - .min() + .repartition(num_parts) + .groupby(lambda x: 0) + .min() ) assert nan_agg_ds.count() == 1 assert nan_agg_ds.sort(key=lambda r: r[0]).take(1) == [(0, None)] @@ -3123,8 +3123,8 @@ def test_groupby_simple_max(ray_start_regular_shared, num_parts): # Test built-in max aggregation with nans nan_grouped_ds = ( 
ray.data.from_items(xs + [None]) - .repartition(num_parts) - .groupby(lambda x: int(x or 0) % 3) + .repartition(num_parts) + .groupby(lambda x: int(x or 0) % 3) ) nan_agg_ds = nan_grouped_ds.max() assert nan_agg_ds.count() == 3 @@ -3136,9 +3136,9 @@ def test_groupby_simple_max(ray_start_regular_shared, num_parts): # Test all nans nan_agg_ds = ( ray.data.from_items([None] * len(xs)) - .repartition(num_parts) - .groupby(lambda x: 0) - .max() + .repartition(num_parts) + .groupby(lambda x: 0) + .max() ) assert nan_agg_ds.count() == 1 assert nan_agg_ds.sort(key=lambda r: r[0]).take(1) == [(0, None)] @@ -3174,8 +3174,8 @@ def test_groupby_simple_mean(ray_start_regular_shared, num_parts): # Test built-in mean aggregation with nans nan_grouped_ds = ( ray.data.from_items(xs + [None]) - .repartition(num_parts) - .groupby(lambda x: int(x or 0) % 3) + .repartition(num_parts) + .groupby(lambda x: int(x or 0) % 3) ) nan_agg_ds = nan_grouped_ds.mean() assert nan_agg_ds.count() == 3 @@ -3195,9 +3195,9 @@ def test_groupby_simple_mean(ray_start_regular_shared, num_parts): # Test all nans nan_agg_ds = ( ray.data.from_items([None] * len(xs)) - .repartition(num_parts) - .groupby(lambda x: 0) - .mean() + .repartition(num_parts) + .groupby(lambda x: 0) + .mean() ) assert nan_agg_ds.count() == 1 assert nan_agg_ds.sort(key=lambda r: r[0]).take(1) == [(0, None)] @@ -3239,9 +3239,9 @@ def test_groupby_simple_std(ray_start_regular_shared, num_parts): # ddof of 0 agg_ds = ( ray.data.from_items(xs) - .repartition(num_parts) - .groupby(lambda x: x % 3) - .std(ddof=0) + .repartition(num_parts) + .groupby(lambda x: x % 3) + .std(ddof=0) ) assert agg_ds.count() == 3 df = pd.DataFrame({"A": [x % 3 for x in xs], "B": xs}) @@ -3255,8 +3255,8 @@ def test_groupby_simple_std(ray_start_regular_shared, num_parts): # Test built-in std aggregation with nans nan_grouped_ds = ( ray.data.from_items(xs + [None]) - .repartition(num_parts) - .groupby(lambda x: int(x or 0) % 3) + .repartition(num_parts) + .groupby(lambda x: int(x or 0) % 3) ) nan_agg_ds = nan_grouped_ds.std() assert nan_agg_ds.count() == 3 @@ -3280,9 +3280,9 @@ def test_groupby_simple_std(ray_start_regular_shared, num_parts): # Test all nans nan_agg_ds = ( ray.data.from_items([None] * len(xs)) - .repartition(num_parts) - .groupby(lambda x: 0) - .std(ignore_nulls=False) + .repartition(num_parts) + .groupby(lambda x: 0) + .std(ignore_nulls=False) ) assert nan_agg_ds.count() == 1 expected = pd.Series([None], name="B") @@ -3328,9 +3328,9 @@ def test_groupby_simple_multilambda(ray_start_regular_shared, num_parts): random.shuffle(xs) agg_ds = ( ray.data.from_items([[x, 2 * x] for x in xs]) - .repartition(num_parts) - .groupby(lambda x: x[0] % 3) - .mean([lambda x: x[0], lambda x: x[1]]) + .repartition(num_parts) + .groupby(lambda x: x[0] % 3) + .mean([lambda x: x[0], lambda x: x[1]]) ) assert agg_ds.count() == 3 assert agg_ds.sort(key=lambda r: r[0]).take(3) == [ @@ -3357,9 +3357,9 @@ def test_groupby_simple_multi_agg(ray_start_regular_shared, num_parts): df = pd.DataFrame({"A": [x % 3 for x in xs], "B": xs}) agg_ds = ( ray.data.from_items(xs) - .repartition(num_parts) - .groupby(lambda x: x % 3) - .aggregate( + .repartition(num_parts) + .groupby(lambda x: x % 3) + .aggregate( Count(), Sum(), Min(), @@ -3396,8 +3396,8 @@ def test_groupby_simple_multi_agg(ray_start_regular_shared, num_parts): # Test built-in global multi-aggregation result_row = ( ray.data.from_items(xs) - .repartition(num_parts) - .aggregate( + .repartition(num_parts) + .aggregate( Sum(), Min(), Max(), @@ 
-3497,7 +3497,7 @@ def test_random_shuffle_check_random(shutdown_only): ds = ray.data.from_items(items, parallelism=num_files) out = ds.random_shuffle().take(num_files * num_rows) for i in range(num_files): - part = out[i * num_rows: (i + 1) * num_rows] + part = out[i * num_rows : (i + 1) * num_rows] seen = set() num_contiguous = 1 prev = -1 @@ -3523,7 +3523,7 @@ def test_random_shuffle_check_random(shutdown_only): ds = ray.data.from_items(items, parallelism=num_files) out = ds.random_shuffle().take(num_files * num_rows) for i in range(num_files): - part = out[i * num_rows: (i + 1) * num_rows] + part = out[i * num_rows : (i + 1) * num_rows] num_increasing = 0 prev = -1 for x in part: @@ -3571,7 +3571,9 @@ def test_random_sample(): def test(dataset, sample_percent=0.5): r1 = ds.random_sample(sample_percent) - assert math.isclose(r1.count(), int(ds.count() * sample_percent), rel_tol=2, abs_tol=2) + assert math.isclose( + r1.count(), int(ds.count() * sample_percent), rel_tol=2, abs_tol=2 + ) ds = ray.data.range(10, parallelism=2) test(ds) From 7911dd3be38bace63a1be38c1f951ca4b7c61220 Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Tue, 10 May 2022 15:22:07 +0800 Subject: [PATCH 18/32] Update sampling as per comment --- python/ray/data/dataset.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 92232261d32b..001a569136fd 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -616,12 +616,9 @@ def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args): def random_sample( self, fraction: float, *, seed: Optional[int] = None ) -> "Dataset[T]": - """Randomly samples a fraction of the elements of this dataset. - - Note that the fraction sampled is only approximate, and may not be - exactly the fraction specified. - + """Randomly samples a fraction of the elements of this dataset by uniform sampling. + Note that the exact fraction of elements to sample is not guaranteed. Examples: >>> import ray @@ -637,10 +634,12 @@ def random_sample( seed: Seeds the python random pRNG generator. 
Returns: - Returns a dataset with approximately (fraction) of the elements of the original dataset + Returns a dataset with *fraction* of the elements of the original dataset """ import random import math + import pyarrow as pa + import pandas as pd if self.num_blocks() == 0: raise ValueError("Cannot from an empty dataset") @@ -662,14 +661,16 @@ def process_batch(batch): """ if isinstance(batch, list): - s_batch = sorted(batch) - weights = [random.random() for _ in range(len(s_batch))] - _result = [] - for i, w in enumerate(weights): - if w <= fraction: - _result.append(s_batch[i]) - return _result - return batch.sample(frac=fraction) + return random.sample(batch, math.ceil(len(batch) * fraction)) + if isinstance(batch, pa.Table): + # Generate a mask to select random indices + indices = random.sample(range(len(batch)), math.ceil(len(batch) * fraction)) + mask = [True if i in indices else False for i in range(len(batch))] + return batch.filter(mask) + if isinstance(batch, pd.DataFrame): + return batch.sample(frac=fraction) + + raise ValueError("Unsupported batch type") return self.map_batches(process_batch) From 088b2914960ad9dc3f6182521d784cf775ede20f Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 11 May 2022 18:56:35 -0700 Subject: [PATCH 19/32] Update dataset.py --- python/ray/data/dataset.py | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 001a569136fd..2e2e3e48e9d7 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -616,25 +616,22 @@ def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args): def random_sample( self, fraction: float, *, seed: Optional[int] = None ) -> "Dataset[T]": - """Randomly samples a fraction of the elements of this dataset by uniform sampling. + """Randomly samples a fraction of the elements of this dataset. - Note that the exact fraction of elements to sample is not guaranteed. + Note that the exact number of elements returned is not guaranteed. Examples: >>> import ray >>> ds = ray.data.range(100) # doctest: +SKIP - >>> ds.random_sample(5) # doctest: +SKIP - >>> # Sample this dataset with a fixed random seed. - >>> ds.random_sample(5, seed=12345) # doctest: +SKIP - + >>> ds.random_sample(0.1) # doctest: +SKIP + >>> ds.random_sample(0.2, seed=12345) # doctest: +SKIP Args: fraction: The fraction of elements to sample. - seed: Seeds the python random pRNG generator. Returns: - Returns a dataset with *fraction* of the elements of the original dataset + Returns a Dataset containing the sampled elements. 
""" import random import math @@ -642,24 +639,15 @@ def random_sample( import pandas as pd if self.num_blocks() == 0: - raise ValueError("Cannot from an empty dataset") + raise ValueError("Cannot sample from an empty dataset.") if fraction < 0 or fraction > 1: - raise ValueError("Fraction must be between 0 and 1") + raise ValueError("Fraction must be between 0 and 1.") if seed: random.seed(seed) def process_batch(batch): - """ - Processes a batch of inputs - Args: - batch: The batch to process - - Returns: - Randomly sampled elements from the batch - """ - if isinstance(batch, list): return random.sample(batch, math.ceil(len(batch) * fraction)) if isinstance(batch, pa.Table): @@ -670,7 +658,7 @@ def process_batch(batch): if isinstance(batch, pd.DataFrame): return batch.sample(frac=fraction) - raise ValueError("Unsupported batch type") + raise ValueError("Unsupported batch type: {}".format(type(batch))) return self.map_batches(process_batch) From 01d922f4b8ee89a410b6f485651e3a58156b1f1c Mon Sep 17 00:00:00 2001 From: Robert Date: Thu, 12 May 2022 11:58:25 +0800 Subject: [PATCH 20/32] Fix failing test Co-authored-by: Eric Liang --- python/ray/data/tests/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index a4f689a0d268..8fec0e0135f2 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -3566,7 +3566,7 @@ def get_node_id(): assert set(locations) == {node1_id, node2_id} -def test_random_sample(): +def test_random_sample(ray_start_regular_shared): import math def test(dataset, sample_percent=0.5): From ef4bcae12770cb4bf6656829cd8338be84459b91 Mon Sep 17 00:00:00 2001 From: Robert Date: Thu, 12 May 2022 11:58:34 +0800 Subject: [PATCH 21/32] Fix failing test #2 Co-authored-by: Eric Liang --- python/ray/data/tests/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index 8fec0e0135f2..d149a3f4033d 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -3593,7 +3593,7 @@ def test(dataset, sample_percent=0.5): test(ds) -def test_random_sample_checks(): +def test_random_sample_checks(ray_start_regular_shared): with pytest.raises(ValueError) as e_info: ray.data.range(1).random_sample(-1) ray.data.range(0).random_sample(0.2) From 9aa833776beb5d4650aecafd6eaef7afac385f4f Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Thu, 12 May 2022 13:54:05 +0800 Subject: [PATCH 22/32] Break up ValueError assertions --- python/ray/data/tests/test_dataset.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index 0f57c75e1354..755a95141b37 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -3595,8 +3595,13 @@ def test(dataset, sample_percent=0.5): def test_random_sample_checks(ray_start_regular_shared): with pytest.raises(ValueError) as e_info: + # Cannot sample -1 ray.data.range(1).random_sample(-1) + with pytest.raises(ValueError): + # Cannot sample from empty dataset ray.data.range(0).random_sample(0.2) + with pytest.raises(ValueError): + # Cannot sample fraction > 1 ray.data.range(1).random_sample(10) From 17e8e5cef666f1a77af4aaa22b3277d84d14a815 Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Thu, 12 May 2022 13:55:46 +0800 Subject: [PATCH 23/32] Update documentation to reflect the 
number of items being returned --- python/ray/data/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 89e503203b7a..b15938f87232 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -618,7 +618,8 @@ def random_sample( ) -> "Dataset[T]": """Randomly samples a fraction of the elements of this dataset. - Note that the exact number of elements returned is not guaranteed. + Note that the exact number of elements returned is not guaranteed, + and that the number of elements being returned is roughly fraction * total_rows. Examples: >>> import ray From 46290c1a3ce1a9762254e9f9e7c9904f2a757bd0 Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Thu, 12 May 2022 14:05:41 +0800 Subject: [PATCH 24/32] Add handling to address len(batch) * fraction < 1 --- python/ray/data/dataset.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index b15938f87232..d13f55d2c37a 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -649,13 +649,30 @@ def random_sample( random.seed(seed) def process_batch(batch): + strategy = 0 + probs = [] + if len(batch) * fraction < 1: + strategy = 1 + # Bernoulli sampler similar to Apache spark's sampler + probs = [random.random() for _ in range(len(batch))] + if isinstance(batch, list): - return random.sample(batch, math.ceil(len(batch) * fraction)) + if strategy == 0: + return random.sample(batch, math.ceil(len(batch) * fraction)) + else: + # Picks the item if the weight generated for that item <= fraction + return [batch[i] for i in range(len(batch)) if probs[i] <= fraction] + if isinstance(batch, pa.Table): - # Generate a mask to select random indices - indices = random.sample(range(len(batch)), math.ceil(len(batch) * fraction)) - mask = [True if i in indices else False for i in range(len(batch))] - return batch.filter(mask) + if strategy == 0: + # Generate a mask to select random indices + indices = random.sample(range(len(batch)), math.ceil(len(batch) * fraction)) + mask = [True if i in indices else False for i in range(len(batch))] + return batch.filter(mask) + else: + # Lets the item pass if the weight generated for that item <= fraction + mask = [True if p <= fraction else False for p in probs] + return batch.filter(mask) if isinstance(batch, pd.DataFrame): return batch.sample(frac=fraction) From c9c6eb3774c296c120f381b26a331656ed3dc7fa Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Thu, 12 May 2022 14:08:08 +0800 Subject: [PATCH 25/32] Test for 46290c1 --- python/ray/data/tests/test_dataset.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index 755a95141b37..21e07f3e7c2b 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -3592,9 +3592,14 @@ def test(dataset, sample_percent=0.5): ds = ds1.union(ds2).union(ds3) test(ds) + # Small datasets + + ds1 = ray.data.range(5, parallelism=5) + test(ds1) + def test_random_sample_checks(ray_start_regular_shared): - with pytest.raises(ValueError) as e_info: + with pytest.raises(ValueError): # Cannot sample -1 ray.data.range(1).random_sample(-1) with pytest.raises(ValueError): From 67fac90da77e04a0399bccd81d1cef3840e99c25 Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Thu, 12 May 2022 14:21:57 +0800 Subject: [PATCH 26/32] Run the format script --- 
python/ray/data/dataset.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index d13f55d2c37a..787e47002231 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -660,17 +660,19 @@ def process_batch(batch): if strategy == 0: return random.sample(batch, math.ceil(len(batch) * fraction)) else: - # Picks the item if the weight generated for that item <= fraction + # Picks the item if weight generated for that item <= fraction return [batch[i] for i in range(len(batch)) if probs[i] <= fraction] if isinstance(batch, pa.Table): if strategy == 0: # Generate a mask to select random indices - indices = random.sample(range(len(batch)), math.ceil(len(batch) * fraction)) + indices = random.sample( + range(len(batch)), math.ceil(len(batch) * fraction) + ) mask = [True if i in indices else False for i in range(len(batch))] return batch.filter(mask) else: - # Lets the item pass if the weight generated for that item <= fraction + # Lets the item pass if weight generated for that item <= fraction mask = [True if p <= fraction else False for p in probs] return batch.filter(mask) if isinstance(batch, pd.DataFrame): From 18b90c5303d8ce9ef840fb56e450ddfeac966eab Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Fri, 13 May 2022 09:04:54 +0800 Subject: [PATCH 27/32] Resolve minor issues --- python/ray/data/tests/test_dataset.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index 21e07f3e7c2b..fd21c0887d07 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -3569,33 +3569,31 @@ def get_node_id(): def test_random_sample(ray_start_regular_shared): import math - def test(dataset, sample_percent=0.5): + def ensure_sample_size_close(dataset, sample_percent=0.5): r1 = ds.random_sample(sample_percent) assert math.isclose( r1.count(), int(ds.count() * sample_percent), rel_tol=2, abs_tol=2 ) ds = ray.data.range(10, parallelism=2) - test(ds) + ensure_sample_size_close(ds) ds = ray.data.range_arrow(10, parallelism=2) - test(ds) + ensure_sample_size_close(ds) ds = ray.data.range_tensor(5, parallelism=2, shape=(2, 2)) - test(ds) + ensure_sample_size_close(ds) - # "weird" datasets + # imbalanced datasets ds1 = ray.data.range(1, parallelism=1) ds2 = ray.data.range(2, parallelism=1) ds3 = ray.data.range(3, parallelism=1) # noinspection PyTypeChecker ds = ds1.union(ds2).union(ds3) - test(ds) - + ensure_sample_size_close(ds) # Small datasets - ds1 = ray.data.range(5, parallelism=5) - test(ds1) + ensure_sample_size_close(ds1) def test_random_sample_checks(ray_start_regular_shared): From c64d4b57b6d8d71d0b1cc920463502c5c10c3aa8 Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Fri, 13 May 2022 09:10:52 +0800 Subject: [PATCH 28/32] Always use strategy = 1 --- python/ray/data/dataset.py | 28 ++++++---------------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 787e47002231..3d87c690c49f 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -649,32 +649,16 @@ def random_sample( random.seed(seed) def process_batch(batch): - strategy = 0 - probs = [] - if len(batch) * fraction < 1: - strategy = 1 - # Bernoulli sampler similar to Apache spark's sampler - probs = [random.random() for _ in range(len(batch))] + # Utilizes Bernoulli sampling, similar to Apache spark + probs = 
[random.random() for _ in range(len(batch))] if isinstance(batch, list): - if strategy == 0: - return random.sample(batch, math.ceil(len(batch) * fraction)) - else: - # Picks the item if weight generated for that item <= fraction - return [batch[i] for i in range(len(batch)) if probs[i] <= fraction] + return [batch[i] for i in range(len(batch)) if probs[i] <= fraction] if isinstance(batch, pa.Table): - if strategy == 0: - # Generate a mask to select random indices - indices = random.sample( - range(len(batch)), math.ceil(len(batch) * fraction) - ) - mask = [True if i in indices else False for i in range(len(batch))] - return batch.filter(mask) - else: - # Lets the item pass if weight generated for that item <= fraction - mask = [True if p <= fraction else False for p in probs] - return batch.filter(mask) + # Lets the item pass if weight generated for that item <= fraction + mask = [True if p <= fraction else False for p in probs] + return batch.filter(mask) if isinstance(batch, pd.DataFrame): return batch.sample(frac=fraction) From b1b45c54fe2a174406f156f17af38eeded6341e2 Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Fri, 13 May 2022 09:11:50 +0800 Subject: [PATCH 29/32] Remove unused import math --- python/ray/data/dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 3d87c690c49f..0bdd5a15d774 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -635,7 +635,6 @@ def random_sample( Returns a Dataset containing the sampled elements. """ import random - import math import pyarrow as pa import pandas as pd From a4cbbde83e7349b8f878c545b26471ca4a62f16b Mon Sep 17 00:00:00 2001 From: Robert Date: Fri, 13 May 2022 10:17:51 +0800 Subject: [PATCH 30/32] Explain mask generation This utilizes more concise terminology for the generation of the mask Co-authored-by: Clark Zinzow --- python/ray/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 0bdd5a15d774..8d83ab96acf0 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -656,7 +656,7 @@ def process_batch(batch): if isinstance(batch, pa.Table): # Lets the item pass if weight generated for that item <= fraction - mask = [True if p <= fraction else False for p in probs] + mask = [p <= fraction for p in probs] return batch.filter(mask) if isinstance(batch, pd.DataFrame): return batch.sample(frac=fraction) From a8fecf350d00d13e76f478899092d8302bfa63b3 Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Fri, 13 May 2022 11:06:52 +0800 Subject: [PATCH 31/32] Performance improvements for pyarrow --- python/ray/data/dataset.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 8d83ab96acf0..8472cce4b23e 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -648,20 +648,16 @@ def random_sample( random.seed(seed) def process_batch(batch): - # Utilizes Bernoulli sampling, similar to Apache spark - probs = [random.random() for _ in range(len(batch))] - if isinstance(batch, list): - return [batch[i] for i in range(len(batch)) if probs[i] <= fraction] - + return [row for row in batch if random.random() <= fraction] if isinstance(batch, pa.Table): # Lets the item pass if weight generated for that item <= fraction - mask = [p <= fraction for p in probs] - return batch.filter(mask) + return batch.filter( + pa.array(random.random() <= fraction for _ in 
range(len(batch))) + ) if isinstance(batch, pd.DataFrame): return batch.sample(frac=fraction) - - raise ValueError("Unsupported batch type: {}".format(type(batch))) + raise ValueError(f"Unsupported batch type: {type(batch)}") return self.map_batches(process_batch) From 88d11db7d70a2b2054e6364b3c45bd6516bbefbc Mon Sep 17 00:00:00 2001 From: Robert Xiu Date: Tue, 17 May 2022 18:09:39 +0800 Subject: [PATCH 32/32] Fixes failing test: test_parquet_read_spread --- python/ray/data/tests/test_dataset.py | 84 +++++++++++++-------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index fd21c0887d07..b5fc9f67a0fc 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -3425,6 +3425,48 @@ def test_column_name_type_check(ray_start_regular_shared): ray.data.from_pandas(df) +def test_random_sample(ray_start_regular_shared): + import math + + def ensure_sample_size_close(dataset, sample_percent=0.5): + r1 = ds.random_sample(sample_percent) + assert math.isclose( + r1.count(), int(ds.count() * sample_percent), rel_tol=2, abs_tol=2 + ) + + ds = ray.data.range(10, parallelism=2) + ensure_sample_size_close(ds) + + ds = ray.data.range_arrow(10, parallelism=2) + ensure_sample_size_close(ds) + + ds = ray.data.range_tensor(5, parallelism=2, shape=(2, 2)) + ensure_sample_size_close(ds) + + # imbalanced datasets + ds1 = ray.data.range(1, parallelism=1) + ds2 = ray.data.range(2, parallelism=1) + ds3 = ray.data.range(3, parallelism=1) + # noinspection PyTypeChecker + ds = ds1.union(ds2).union(ds3) + ensure_sample_size_close(ds) + # Small datasets + ds1 = ray.data.range(5, parallelism=5) + ensure_sample_size_close(ds1) + + +def test_random_sample_checks(ray_start_regular_shared): + with pytest.raises(ValueError): + # Cannot sample -1 + ray.data.range(1).random_sample(-1) + with pytest.raises(ValueError): + # Cannot sample from empty dataset + ray.data.range(0).random_sample(0.2) + with pytest.raises(ValueError): + # Cannot sample fraction > 1 + ray.data.range(1).random_sample(10) + + @pytest.mark.parametrize("pipelined", [False, True]) @pytest.mark.parametrize("use_push_based_shuffle", [False, True]) def test_random_shuffle(shutdown_only, pipelined, use_push_based_shuffle): @@ -3566,48 +3608,6 @@ def get_node_id(): assert set(locations) == {node1_id, node2_id} -def test_random_sample(ray_start_regular_shared): - import math - - def ensure_sample_size_close(dataset, sample_percent=0.5): - r1 = ds.random_sample(sample_percent) - assert math.isclose( - r1.count(), int(ds.count() * sample_percent), rel_tol=2, abs_tol=2 - ) - - ds = ray.data.range(10, parallelism=2) - ensure_sample_size_close(ds) - - ds = ray.data.range_arrow(10, parallelism=2) - ensure_sample_size_close(ds) - - ds = ray.data.range_tensor(5, parallelism=2, shape=(2, 2)) - ensure_sample_size_close(ds) - - # imbalanced datasets - ds1 = ray.data.range(1, parallelism=1) - ds2 = ray.data.range(2, parallelism=1) - ds3 = ray.data.range(3, parallelism=1) - # noinspection PyTypeChecker - ds = ds1.union(ds2).union(ds3) - ensure_sample_size_close(ds) - # Small datasets - ds1 = ray.data.range(5, parallelism=5) - ensure_sample_size_close(ds1) - - -def test_random_sample_checks(ray_start_regular_shared): - with pytest.raises(ValueError): - # Cannot sample -1 - ray.data.range(1).random_sample(-1) - with pytest.raises(ValueError): - # Cannot sample from empty dataset - ray.data.range(0).random_sample(0.2) - with 
pytest.raises(ValueError): - # Cannot sample fraction > 1 - ray.data.range(1).random_sample(10) - - def test_parquet_read_spread(ray_start_cluster, tmp_path): cluster = ray_start_cluster cluster.add_node(
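For reference, the per-batch logic that this series converges on in patch 31 can be sketched outside Ray as follows. This is an illustrative reconstruction under the assumptions visible in the diffs, not the library code itself, and the helper name sample_batch is invented for the example.

    import random

    import pandas as pd
    import pyarrow as pa

    def sample_batch(batch, fraction):
        # One independent Bernoulli draw per row, specialized per batch type.
        if isinstance(batch, list):
            return [row for row in batch if random.random() <= fraction]
        if isinstance(batch, pa.Table):
            # Boolean mask built row by row, then a single vectorized filter.
            mask = pa.array([random.random() <= fraction for _ in range(len(batch))])
            return batch.filter(mask)
        if isinstance(batch, pd.DataFrame):
            return batch.sample(frac=fraction)
        raise ValueError(f"Unsupported batch type: {type(batch)}")

    print(len(sample_batch(list(range(1000)), 0.5)))                       # close to 500
    print(sample_batch(pa.table({"x": list(range(1000))}), 0.5).num_rows)  # close to 500
    print(len(sample_batch(pd.DataFrame({"x": range(1000)}), 0.5)))        # exactly 500

Note that the pandas branch delegates to DataFrame.sample, which draws a fixed number of rows without replacement, while the list and Arrow branches vary around the expected size; that spread is why the test above compares counts with math.isclose and an absolute tolerance rather than asserting an exact number.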