Skip to content

Commit

Permalink
[Data] Add support for using actors with map_groups (#45310)
Browse files Browse the repository at this point in the history
See #41406

Signed-off-by: Balaji Veeramani <[email protected]>
  • Loading branch information
bveeramani authored May 14, 2024
1 parent 2fa6663 commit b7d6d18
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 8 deletions.
43 changes: 35 additions & 8 deletions python/ray/data/grouped_data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from ray.data._internal.compute import ComputeStrategy
from ray.data._internal.logical.interfaces import LogicalPlan
from ray.data._internal.logical.operators.all_to_all_operator import Aggregate
from ray.data.aggregate import AggregateFn, Count, Max, Mean, Min, Std, Sum
from ray.data.block import BlockAccessor, UserDefinedFunction
from ray.data.block import BlockAccessor, CallableClass, UserDefinedFunction
from ray.data.dataset import DataBatch, Dataset
from ray.util.annotations import PublicAPI

Expand Down Expand Up @@ -109,8 +109,11 @@ def map_groups(
batch_format: Optional[str] = "default",
fn_args: Optional[Iterable[Any]] = None,
fn_kwargs: Optional[Dict[str, Any]] = None,
fn_constructor_args: Optional[Iterable[Any]] = None,
fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
num_cpus: Optional[float] = None,
num_gpus: Optional[float] = None,
concurrency: Optional[Union[int, Tuple[int, int]]] = None,
**ray_remote_args,
) -> "Dataset":
"""Apply the given function to each group of records of this dataset.
Expand Down Expand Up @@ -164,6 +167,12 @@ def map_groups(
exactly as is with no additional formatting.
fn_args: Arguments to `fn`.
fn_kwargs: Keyword arguments to `fn`.
fn_constructor_args: Positional arguments to pass to ``fn``'s constructor.
You can only provide this if ``fn`` is a callable class. These arguments
are top-level arguments in the underlying Ray actor construction task.
fn_constructor_kwargs: Keyword arguments to pass to ``fn``'s constructor.
This can only be provided if ``fn`` is a callable class. These arguments
are top-level arguments in the underlying Ray actor construction task.
num_cpus: The number of CPUs to reserve for each parallel map worker.
num_gpus: The number of GPUs to reserve for each parallel map worker. For
example, specify `num_gpus=1` to request 1 GPU for each parallel map
Expand Down Expand Up @@ -206,7 +215,7 @@ def get_key_boundaries(block_accessor: BlockAccessor) -> List[int]:

# The batch is the entire block, because we have batch_size=None for
# map_batches() below.
def group_fn(batch, *args, **kwargs):
def apply_udf_to_groups(udf, batch, *args, **kwargs):
block = BlockAccessor.batch_to_block(batch)
block_accessor = BlockAccessor.for_block(block)
if self._key:
Expand All @@ -221,25 +230,43 @@ def group_fn(batch, *args, **kwargs):
# block format here can be different from batch format
# (e.g. block is Arrow format, and batch is NumPy format).
group_batch = group_block_accessor.to_batch_format(batch_format)
applied = fn(group_batch, *args, **kwargs)
applied = udf(group_batch, *args, **kwargs)
yield applied
start = end

if isinstance(fn, CallableClass):

class wrapped_fn:
def __init__(self, *args, **kwargs):
self.fn = fn(*args, **kwargs)

def __call__(self, batch, *args, **kwargs):
yield from apply_udf_to_groups(self.fn, batch, *args, **kwargs)

else:

def wrapped_fn(batch, *args, **kwargs):
yield from apply_udf_to_groups(fn, batch, *args, **kwargs)

# Change the name of the wrapped function so that users see the name of their
# function rather than `wrapped_fn` in the progress bar.
wrapped_fn.__name__ = fn.__name__

# Note we set batch_size=None here, so it will use the entire block as a batch,
# which ensures that each group will be contained within a batch in entirety.
return sorted_ds._map_batches_without_batch_size_validation(
group_fn,
wrapped_fn,
batch_size=None,
compute=compute,
batch_format=batch_format,
zero_copy_batch=False,
fn_args=fn_args,
fn_kwargs=fn_kwargs,
fn_constructor_args=None,
fn_constructor_kwargs=None,
fn_constructor_args=fn_constructor_args,
fn_constructor_kwargs=fn_constructor_kwargs,
num_cpus=num_cpus,
num_gpus=num_gpus,
concurrency=None,
concurrency=concurrency,
ray_remote_args_fn=None,
**ray_remote_args,
)
Expand Down
39 changes: 39 additions & 0 deletions python/ray/data/tests/test_all_to_all.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import random
import time
from typing import Optional
from unittest.mock import patch

import numpy as np
Expand Down Expand Up @@ -149,6 +150,44 @@ def test_map_groups_with_gpus(shutdown_only):
assert rows == [{"id": 0}]


def test_map_groups_with_actors(ray_start_regular_shared):
class Identity:
def __call__(self, batch):
return batch

rows = (
ray.data.range(1).groupby("id").map_groups(Identity, concurrency=1).take_all()
)

assert rows == [{"id": 0}]


def test_map_groups_with_actors_and_args(ray_start_regular_shared):
class Fn:
def __init__(self, x: int, y: Optional[int] = None):
self.x = x
self.y = y

def __call__(self, batch, q: int, r: Optional[int] = None):
return {"x": [self.x], "y": [self.y], "q": [q], "r": [r]}

rows = (
ray.data.range(1)
.groupby("id")
.map_groups(
Fn,
concurrency=1,
fn_constructor_args=[0],
fn_constructor_kwargs={"y": 1},
fn_args=[2],
fn_kwargs={"r": 3},
)
.take_all()
)

assert rows == [{"x": 0, "y": 1, "q": 2, "r": 3}]


def test_groupby_large_udf_returns(ray_start_regular_shared):
# Test for https://github.com/ray-project/ray/issues/44861.

Expand Down

0 comments on commit b7d6d18

Please sign in to comment.