Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement random_sample() #24492

Merged
merged 42 commits into from
May 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
f0e40b8
Build random_sample feature (#24449)
bushshrub May 5, 2022
ade5959
Ran scripts/format.sh
bushshrub May 5, 2022
6c67d1c
Merge branch 'ray-project:master' into master
bushshrub May 5, 2022
d54f8ff
Make random_sample more random
bushshrub May 5, 2022
dcd8602
Merge remote-tracking branch 'origin/master'
bushshrub May 5, 2022
93f9dfb
Add some dataset validity checks
bushshrub May 5, 2022
098426c
Fix random_sample() for non-list types
bushshrub May 5, 2022
72da35d
Account for possibly attempting to sample n from a batch for x elemen…
bushshrub May 5, 2022
d76fb07
Run scripts/format.sh
bushshrub May 5, 2022
61791dd
Update sampling algorithm, updated documentation to explain sampling …
bushshrub May 6, 2022
c7b9f55
Run format script
bushshrub May 6, 2022
cfa9291
Add a random_shuffle on the sample_population since .take() will leav…
bushshrub May 6, 2022
52d3063
run format script
bushshrub May 6, 2022
b970866
Add tests for random sampling
bushshrub May 6, 2022
4d903af
Merge branch 'ray-project:master' into master
bushshrub May 6, 2022
aa57653
Merge remote-tracking branch 'origin/master'
bushshrub May 6, 2022
c31f88e
Add another sampling strategy
bushshrub May 6, 2022
e5b4bcc
Add tests for other dataset types
bushshrub May 6, 2022
e25efa9
Working but the tests may fail since random_sample does not guarantee…
bushshrub May 8, 2022
7d4ce8b
Merge branch 'ray-project:master' into master
bushshrub May 8, 2022
02f2724
Fix tests
bushshrub May 9, 2022
58b06ba
Merge branch 'ray-project:master' into master
bushshrub May 9, 2022
894114c
Undo IDE auto reformat
bushshrub May 10, 2022
422c92b
Merge remote-tracking branch 'origin/master'
bushshrub May 10, 2022
7911dd3
Update sampling as per comment
bushshrub May 10, 2022
088b291
Update dataset.py
ericl May 12, 2022
01d922f
Fix failing test
bushshrub May 12, 2022
ef4bcae
Fix failing test #2
bushshrub May 12, 2022
a6885d1
Merge branch 'ray-project:master' into master
bushshrub May 12, 2022
9aa8337
Break up ValueError assertions
bushshrub May 12, 2022
17e8e5c
Update documentation to reflect the number of items being returned
bushshrub May 12, 2022
46290c1
Add handling to address len(batch) * fraction < 1
bushshrub May 12, 2022
c9c6eb3
Test for 46290c1
bushshrub May 12, 2022
67fac90
Run the format script
bushshrub May 12, 2022
18b90c5
Resolve minor issues
bushshrub May 13, 2022
c64d4b5
Always use strategy = 1
bushshrub May 13, 2022
b1b45c5
Remove unused import math
bushshrub May 13, 2022
f772686
Merge branch 'ray-project:master' into master
bushshrub May 13, 2022
a4cbbde
Explain mask generation
bushshrub May 13, 2022
a8fecf3
Performance improvements for pyarrow
bushshrub May 13, 2022
4777439
Merge branch 'ray-project:master' into master
bushshrub May 17, 2022
88d11db
Fixes failing test: test_parquet_read_spread
bushshrub May 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,54 @@ def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args):
)
return Dataset(plan, self._epoch, self._lazy)

def random_sample(
self, fraction: float, *, seed: Optional[int] = None
) -> "Dataset[T]":
"""Randomly samples a fraction of the elements of this dataset.

Note that the exact number of elements returned is not guaranteed,
and that the number of elements being returned is roughly fraction * total_rows.

Examples:
>>> import ray
>>> ds = ray.data.range(100) # doctest: +SKIP
>>> ds.random_sample(0.1) # doctest: +SKIP
>>> ds.random_sample(0.2, seed=12345) # doctest: +SKIP

Args:
fraction: The fraction of elements to sample.
seed: Seeds the python random pRNG generator.

Returns:
Returns a Dataset containing the sampled elements.
"""
import random
import pyarrow as pa
import pandas as pd

if self.num_blocks() == 0:
raise ValueError("Cannot sample from an empty dataset.")

if fraction < 0 or fraction > 1:
raise ValueError("Fraction must be between 0 and 1.")

if seed:
random.seed(seed)

def process_batch(batch):
bushshrub marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(batch, list):
return [row for row in batch if random.random() <= fraction]
if isinstance(batch, pa.Table):
# Lets the item pass if weight generated for that item <= fraction
return batch.filter(
pa.array(random.random() <= fraction for _ in range(len(batch)))
)
if isinstance(batch, pd.DataFrame):
return batch.sample(frac=fraction)
raise ValueError(f"Unsupported batch type: {type(batch)}")

return self.map_batches(process_batch)

def split(
self, n: int, *, equal: bool = False, locality_hints: Optional[List[Any]] = None
) -> List["Dataset[T]"]:
Expand Down
42 changes: 42 additions & 0 deletions python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3425,6 +3425,48 @@ def test_column_name_type_check(ray_start_regular_shared):
ray.data.from_pandas(df)


def test_random_sample(ray_start_regular_shared):
import math

def ensure_sample_size_close(dataset, sample_percent=0.5):
r1 = ds.random_sample(sample_percent)
assert math.isclose(
r1.count(), int(ds.count() * sample_percent), rel_tol=2, abs_tol=2
)

ds = ray.data.range(10, parallelism=2)
ensure_sample_size_close(ds)

ds = ray.data.range_arrow(10, parallelism=2)
ensure_sample_size_close(ds)

ds = ray.data.range_tensor(5, parallelism=2, shape=(2, 2))
ensure_sample_size_close(ds)

# imbalanced datasets
ds1 = ray.data.range(1, parallelism=1)
ds2 = ray.data.range(2, parallelism=1)
ds3 = ray.data.range(3, parallelism=1)
# noinspection PyTypeChecker
ds = ds1.union(ds2).union(ds3)
ensure_sample_size_close(ds)
# Small datasets
ds1 = ray.data.range(5, parallelism=5)
ensure_sample_size_close(ds1)


def test_random_sample_checks(ray_start_regular_shared):
with pytest.raises(ValueError):
# Cannot sample -1
ray.data.range(1).random_sample(-1)
with pytest.raises(ValueError):
# Cannot sample from empty dataset
ray.data.range(0).random_sample(0.2)
with pytest.raises(ValueError):
# Cannot sample fraction > 1
ray.data.range(1).random_sample(10)


@pytest.mark.parametrize("pipelined", [False, True])
@pytest.mark.parametrize("use_push_based_shuffle", [False, True])
def test_random_shuffle(shutdown_only, pipelined, use_push_based_shuffle):
Expand Down