[Data] Improve state initialization for ActorPoolMapOperator (ray-project#34037)

ActorPoolMapOperator takes a callable class whose constructor initializes state to be reused for every batch.

In the current implementation, this state is initialized lazily on the first batch rather than during actor init.

In this PR, we separate the state initialization into its own init_fn and call it during actor init. This allows state to be initialized for fixed-size actor pools even when tasks are not yet ready to be dispatched, enabling better pipelining. It also supports multithreaded actors, since state is initialized once per actor instead of once per thread.
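For illustration, here is a minimal sketch of the kind of stateful callable-class UDF this change affects. The TextEmbedder class, its constructor argument, and the tiny dataset are hypothetical; map_batches, fn_constructor_args, and ActorPoolStrategy are the Ray Data APIs exercised elsewhere in this PR.

import ray
from ray.data import ActorPoolStrategy

class TextEmbedder:
    # Hypothetical stateful UDF: __init__ stands in for expensive setup
    # (e.g. loading a model) that should run once per actor.
    def __init__(self, model_name: str):
        self.model = lambda text: float(len(text) + len(model_name))

    def __call__(self, batch):
        batch["score"] = [self.model(t) for t in batch["text"]]
        return batch

ds = ray.data.from_items([{"text": "hello"}, {"text": "world"}])
# Before this PR, TextEmbedder was constructed on each actor's first
# batch; with this PR it is constructed during actor init, so a
# fixed-size pool can warm up before any task is dispatched.
ds = ds.map_batches(
    TextEmbedder,
    fn_constructor_args=("my-model",),
    compute=ActorPoolStrategy(min_size=2, max_size=2),
)
print(ds.take(2))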

---------

Signed-off-by: amogkam <[email protected]>
Signed-off-by: elliottower <[email protected]>
amogkam authored and elliottower committed Apr 22, 2023
1 parent b9df7dc commit b480937
Showing 5 changed files with 68 additions and 21 deletions.
16 changes: 8 additions & 8 deletions python/ray/data/_internal/execution/legacy_compat.py
@@ -259,25 +259,24 @@ def _stage_to_operator(stage: Stage, input_op: PhysicalOperator) -> PhysicalOperator:
             fn_ = stage.fn
 
             def fn(item: Any) -> Any:
-                # Wrapper providing cached instantiation of stateful callable class
-                # UDFs.
+                assert ray.data._cached_fn is not None
+                assert ray.data._cached_cls == fn_
+                return ray.data._cached_fn(item)
+
+            def init_fn():
                 if ray.data._cached_fn is None:
                     ray.data._cached_cls = fn_
                     ray.data._cached_fn = fn_(
                         *fn_constructor_args, **fn_constructor_kwargs
                     )
-                else:
-                    # A worker is destroyed when its actor is killed, so we
-                    # shouldn't have any worker reuse across different UDF
-                    # applications (i.e. different map operators).
-                    assert ray.data._cached_cls == fn_
-                return ray.data._cached_fn(item)
 
         else:
             fn = stage.fn
+            init_fn = None
         fn_args = (fn,)
     else:
         fn_args = ()
+        init_fn = None
     if stage.fn_args:
         fn_args += stage.fn_args
     fn_kwargs = stage.fn_kwargs or {}
@@ -288,6 +287,7 @@ def do_map(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
         return MapOperator.create(
             do_map,
             input_op,
+            init_fn=init_fn,
             name=stage.name,
             compute_strategy=compute,
             min_rows_per_bundle=stage.target_block_size,
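The fn/init_fn split above relies on a per-process cache stored on the ray.data module: init_fn constructs the callable class once per actor, and fn only consumes the cached instance. A standalone sketch of that contract, assuming a hypothetical make_udf helper (the real cache fields are ray.data._cached_cls and ray.data._cached_fn):

_cached_cls = None
_cached_fn = None

def make_udf(cls, *ctor_args, **ctor_kwargs):
    def init_fn():
        # Runs once, from the actor's __init__: constructing the
        # callable class is the expensive, stateful step.
        global _cached_cls, _cached_fn
        if _cached_fn is None:
            _cached_cls = cls
            _cached_fn = cls(*ctor_args, **ctor_kwargs)

    def fn(item):
        # Runs per batch: the cache must already be populated, and an
        # actor is never reused across different UDFs, so cls matches.
        assert _cached_fn is not None
        assert _cached_cls == cls
        return _cached_fn(item)

    return init_fn, fn

class Doubler:
    def __call__(self, x):
        return 2 * x

init_fn, fn = make_udf(Doubler)
init_fn()      # what _MapWorker.__init__ now does once per actor
print(fn(21))  # 42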
14 changes: 11 additions & 3 deletions python/ray/data/_internal/execution/operators/actor_pool_map_operator.py
@@ -3,7 +3,7 @@
 from typing import Dict, Any, Iterator, Callable, List, Tuple, Union, Optional
 
 import ray
-from ray.data.block import Block, BlockMetadata
+from ray.data.block import Block, BlockMetadata, _CallableClassProtocol
 from ray.data.context import DatasetContext, DEFAULT_SCHEDULING_STRATEGY
 from ray.data._internal.compute import ActorPoolStrategy
 from ray.data._internal.dataset_logger import DatasetLogger
@@ -37,6 +37,7 @@ class ActorPoolMapOperator(MapOperator):
     def __init__(
         self,
         transform_fn: Callable[[Iterator[Block]], Iterator[Block]],
+        init_fn: Callable[[], None],
         input_op: PhysicalOperator,
         autoscaling_policy: "AutoscalingPolicy",
         name: str = "ActorPoolMap",
@@ -47,6 +48,7 @@ def __init__(
         Args:
             transform_fn: The function to apply to each ref bundle input.
+            init_fn: The callable class to instantiate on each actor.
             input_op: Operator generating input data for this op.
             autoscaling_policy: A policy controlling when the actor pool should be
                 scaled up and scaled down.
@@ -60,6 +62,7 @@ def __init__(
         super().__init__(
             transform_fn, input_op, name, min_rows_per_bundle, ray_remote_args
         )
+        self._init_fn = init_fn
         self._ray_remote_args = self._apply_default_remote_args(self._ray_remote_args)
 
         # Create autoscaling policy from compute strategy.
@@ -105,7 +108,7 @@ def _start_actor(self):
         """Start a new actor and add it to the actor pool as a pending actor."""
         assert self._cls is not None
         ctx = DatasetContext.get_current()
-        actor = self._cls.remote(ctx, src_fn_name=self.name)
+        actor = self._cls.remote(ctx, src_fn_name=self.name, init_fn=self._init_fn)
         self._actor_pool.add_pending_actor(actor, actor.get_location.remote())
 
     def _add_bundled_input(self, bundle: RefBundle):
@@ -307,10 +310,15 @@ def _apply_default_remote_args(ray_remote_args: Dict[str, Any]) -> Dict[str, Any]:
 class _MapWorker:
     """An actor worker for MapOperator."""
 
-    def __init__(self, ctx: DatasetContext, src_fn_name: str):
+    def __init__(
+        self, ctx: DatasetContext, src_fn_name: str, init_fn: _CallableClassProtocol
+    ):
         DatasetContext._set_current(ctx)
         self.src_fn_name: str = src_fn_name
 
+        # Initialize state for this actor.
+        init_fn()
+
     def get_location(self) -> NodeIdStr:
         return ray.get_runtime_context().get_node_id()
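Since init_fn() now runs inside the actor's __init__, a failing initializer kills the actor and the error surfaces as a RayActorError on the first call into it, which is what the new test below asserts. A standalone sketch of that Ray mechanism, with a hypothetical Worker actor and bad_init callable:

import ray
from ray.exceptions import RayActorError

def bad_init():
    raise ValueError("init_failed")

@ray.remote
class Worker:
    # Mirrors _MapWorker: run the state-initialization callable inside
    # __init__, so failures happen at actor startup.
    def __init__(self, init_fn):
        init_fn()

    def ping(self):
        return "ok"

ray.init()
worker = Worker.remote(bad_init)
try:
    # The constructor's ValueError surfaces here as a RayActorError.
    ray.get(worker.ping.remote())
except RayActorError as err:
    print("actor failed to start:", err)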
18 changes: 16 additions & 2 deletions python/ray/data/_internal/execution/operators/map_operator.py
@@ -2,10 +2,15 @@
 import copy
 from dataclasses import dataclass
 import itertools
-from typing import List, Iterator, Any, Dict, Optional, Union
+from typing import Callable, List, Iterator, Any, Dict, Optional, Union
 
 import ray
-from ray.data.block import Block, BlockAccessor, BlockMetadata, BlockExecStats
+from ray.data.block import (
+    Block,
+    BlockAccessor,
+    BlockMetadata,
+    BlockExecStats,
+)
 from ray.data._internal.compute import (
     ComputeStrategy,
     TaskPoolStrategy,
@@ -66,6 +71,7 @@ def create(
         cls,
         transform_fn: MapTransformFn,
         input_op: PhysicalOperator,
+        init_fn: Optional[Callable[[], None]] = None,
         name: str = "Map",
         # TODO(ekl): slim down ComputeStrategy to only specify the compute
         # config and not contain implementation code.
@@ -83,6 +89,7 @@ def create(
         Args:
             transform_fn: The function to apply to each ref bundle input.
             input_op: Operator generating input data for this op.
+            init_fn: The callable class to instantiate if using ActorPoolMapOperator.
             name: The name of this operator.
             compute_strategy: Customize the compute strategy for this op.
             min_rows_per_bundle: The number of rows to gather per batch passed to the
@@ -117,8 +124,15 @@ def create(
             compute_strategy
         )
         autoscaling_policy = AutoscalingPolicy(autoscaling_config)
+
+        if init_fn is None:
+
+            def init_fn():
+                pass
+
         return ActorPoolMapOperator(
             transform_fn,
+            init_fn,
             input_op,
             autoscaling_policy=autoscaling_policy,
             name=name,
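A note on the default above: when the caller passes no init_fn, create() substitutes a no-op so ActorPoolMapOperator can call it unconditionally at actor startup. Per the init_fn docstring, the argument is only meaningful for actor pools; the task-pool branch of create() does not use it.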
15 changes: 7 additions & 8 deletions python/ray/data/_internal/planner/plan_udf_map_op.py
@@ -61,20 +61,18 @@ def _plan_udf_map_op(
         fn_ = op._fn
 
         def fn(item: Any) -> Any:
-            # Wrapper providing cached instantiation of stateful callable class
-            # UDFs.
+            assert ray.data._cached_fn is not None
+            assert ray.data._cached_cls == fn_
+            return ray.data._cached_fn(item)
+
+        def init_fn():
             if ray.data._cached_fn is None:
                 ray.data._cached_cls = fn_
                 ray.data._cached_fn = fn_(*fn_constructor_args, **fn_constructor_kwargs)
-            else:
-                # A worker is destroyed when its actor is killed, so we
-                # shouldn't have any worker reuse across different UDF
-                # applications (i.e. different map operators).
-                assert ray.data._cached_cls == fn_
-            return ray.data._cached_fn(item)
 
     else:
         fn = op._fn
+        init_fn = None
     fn_args = (fn,)
     if op._fn_args:
         fn_args += op._fn_args
@@ -86,6 +84,7 @@ def do_map(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
     return MapOperator.create(
         do_map,
         input_physical_dag,
+        init_fn=init_fn,
         name=op.name,
         compute_strategy=compute,
         min_rows_per_bundle=op._target_block_size,
26 changes: 26 additions & 0 deletions python/ray/data/tests/test_operators.py
@@ -506,6 +506,32 @@ def _sleep(block_iter: Iterable[Block]) -> Iterable[Block]:
     wait_for_condition(lambda: (ray.available_resources().get("GPU", 0) == 1.0))
 
 
+def test_actor_pool_map_operator_init(ray_start_regular_shared):
+    """Tests that ActorPoolMapOperator runs init_fn on start."""
+
+    from ray.exceptions import RayActorError
+
+    def _sleep(block_iter: Iterable[Block]) -> Iterable[Block]:
+        time.sleep(999)
+
+    def _fail():
+        raise ValueError("init_failed")
+
+    input_op = InputDataBuffer(make_ref_bundles([[i] for i in range(10)]))
+    compute_strategy = ActorPoolStrategy(min_size=1)
+
+    op = MapOperator.create(
+        _sleep,
+        input_op=input_op,
+        init_fn=_fail,
+        name="TestMapper",
+        compute_strategy=compute_strategy,
+    )
+
+    with pytest.raises(RayActorError, match=r"init_failed"):
+        op.start(ExecutionOptions())
+
+
 @pytest.mark.parametrize(
     "compute,expected",
     [
