From b48093775d6be82aa34cedfbd500f6459a362ad5 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Mon, 10 Apr 2023 13:23:00 -0700 Subject: [PATCH] [Data] Improve state initialization for `ActorPoolMapOperator` (#34037) ActorPoolMapOperator takes in a Callable class which initializes some state to be reused for every batch. In the current implementation, this state is initialized on the first batch, rather than during actor init. In this PR, we separate the state initialization and actually call it during Actor init. This allows state to be initialized for fixed size actor pools, even when tasks are not ready to be dispatched for better pipelining. It also supports using multithreaded actors, so state gets initialized once per actor instead of once per thread. --------- Signed-off-by: amogkam Signed-off-by: elliottower --- .../data/_internal/execution/legacy_compat.py | 16 ++++++------ .../operators/actor_pool_map_operator.py | 14 +++++++--- .../execution/operators/map_operator.py | 18 +++++++++++-- .../data/_internal/planner/plan_udf_map_op.py | 15 +++++------ python/ray/data/tests/test_operators.py | 26 +++++++++++++++++++ 5 files changed, 68 insertions(+), 21 deletions(-) diff --git a/python/ray/data/_internal/execution/legacy_compat.py b/python/ray/data/_internal/execution/legacy_compat.py index 8179973631d4..94d633b57432 100644 --- a/python/ray/data/_internal/execution/legacy_compat.py +++ b/python/ray/data/_internal/execution/legacy_compat.py @@ -259,25 +259,24 @@ def _stage_to_operator(stage: Stage, input_op: PhysicalOperator) -> PhysicalOper fn_ = stage.fn def fn(item: Any) -> Any: - # Wrapper providing cached instantiation of stateful callable class - # UDFs. + assert ray.data._cached_fn is not None + assert ray.data._cached_cls == fn_ + return ray.data._cached_fn(item) + + def init_fn(): if ray.data._cached_fn is None: ray.data._cached_cls = fn_ ray.data._cached_fn = fn_( *fn_constructor_args, **fn_constructor_kwargs ) - else: - # A worker is destroyed when its actor is killed, so we - # shouldn't have any worker reuse across different UDF - # applications (i.e. different map operators). 
- assert ray.data._cached_cls == fn_ - return ray.data._cached_fn(item) else: fn = stage.fn + init_fn = None fn_args = (fn,) else: fn_args = () + init_fn = None if stage.fn_args: fn_args += stage.fn_args fn_kwargs = stage.fn_kwargs or {} @@ -288,6 +287,7 @@ def do_map(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]: return MapOperator.create( do_map, input_op, + init_fn=init_fn, name=stage.name, compute_strategy=compute, min_rows_per_bundle=stage.target_block_size, diff --git a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py index 4451769b0e82..e07b0c2d02cf 100644 --- a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py @@ -3,7 +3,7 @@ from typing import Dict, Any, Iterator, Callable, List, Tuple, Union, Optional import ray -from ray.data.block import Block, BlockMetadata +from ray.data.block import Block, BlockMetadata, _CallableClassProtocol from ray.data.context import DatasetContext, DEFAULT_SCHEDULING_STRATEGY from ray.data._internal.compute import ActorPoolStrategy from ray.data._internal.dataset_logger import DatasetLogger @@ -37,6 +37,7 @@ class ActorPoolMapOperator(MapOperator): def __init__( self, transform_fn: Callable[[Iterator[Block]], Iterator[Block]], + init_fn: Callable[[], None], input_op: PhysicalOperator, autoscaling_policy: "AutoscalingPolicy", name: str = "ActorPoolMap", @@ -47,6 +48,7 @@ def __init__( Args: transform_fn: The function to apply to each ref bundle input. + init_fn: The callable class to instantiate on each actor. input_op: Operator generating input data for this op. autoscaling_policy: A policy controlling when the actor pool should be scaled up and scaled down. @@ -60,6 +62,7 @@ def __init__( super().__init__( transform_fn, input_op, name, min_rows_per_bundle, ray_remote_args ) + self._init_fn = init_fn self._ray_remote_args = self._apply_default_remote_args(self._ray_remote_args) # Create autoscaling policy from compute strategy. @@ -105,7 +108,7 @@ def _start_actor(self): """Start a new actor and add it to the actor pool as a pending actor.""" assert self._cls is not None ctx = DatasetContext.get_current() - actor = self._cls.remote(ctx, src_fn_name=self.name) + actor = self._cls.remote(ctx, src_fn_name=self.name, init_fn=self._init_fn) self._actor_pool.add_pending_actor(actor, actor.get_location.remote()) def _add_bundled_input(self, bundle: RefBundle): @@ -307,10 +310,15 @@ def _apply_default_remote_args(ray_remote_args: Dict[str, Any]) -> Dict[str, Any class _MapWorker: """An actor worker for MapOperator.""" - def __init__(self, ctx: DatasetContext, src_fn_name: str): + def __init__( + self, ctx: DatasetContext, src_fn_name: str, init_fn: _CallableClassProtocol + ): DatasetContext._set_current(ctx) self.src_fn_name: str = src_fn_name + # Initialize state for this actor. 
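+        # Running init_fn() here, during actor construction, means stateful
+        # UDFs are set up as soon as each pool actor starts (before any batch
+        # is dispatched) and only once per actor, even for multithreaded actors.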
+ init_fn() + def get_location(self) -> NodeIdStr: return ray.get_runtime_context().get_node_id() diff --git a/python/ray/data/_internal/execution/operators/map_operator.py b/python/ray/data/_internal/execution/operators/map_operator.py index 6115bdf02f53..a50b6e86acc3 100644 --- a/python/ray/data/_internal/execution/operators/map_operator.py +++ b/python/ray/data/_internal/execution/operators/map_operator.py @@ -2,10 +2,15 @@ import copy from dataclasses import dataclass import itertools -from typing import List, Iterator, Any, Dict, Optional, Union +from typing import Callable, List, Iterator, Any, Dict, Optional, Union import ray -from ray.data.block import Block, BlockAccessor, BlockMetadata, BlockExecStats +from ray.data.block import ( + Block, + BlockAccessor, + BlockMetadata, + BlockExecStats, +) from ray.data._internal.compute import ( ComputeStrategy, TaskPoolStrategy, @@ -66,6 +71,7 @@ def create( cls, transform_fn: MapTransformFn, input_op: PhysicalOperator, + init_fn: Optional[Callable[[], None]] = None, name: str = "Map", # TODO(ekl): slim down ComputeStrategy to only specify the compute # config and not contain implementation code. @@ -83,6 +89,7 @@ def create( Args: transform_fn: The function to apply to each ref bundle input. input_op: Operator generating input data for this op. + init_fn: The callable class to instantiate if using ActorPoolMapOperator. name: The name of this operator. compute_strategy: Customize the compute strategy for this op. min_rows_per_bundle: The number of rows to gather per batch passed to the @@ -117,8 +124,15 @@ def create( compute_strategy ) autoscaling_policy = AutoscalingPolicy(autoscaling_config) + + if init_fn is None: + + def init_fn(): + pass + return ActorPoolMapOperator( transform_fn, + init_fn, input_op, autoscaling_policy=autoscaling_policy, name=name, diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index f7ede0231f6c..418ed83d2f88 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -61,20 +61,18 @@ def _plan_udf_map_op( fn_ = op._fn def fn(item: Any) -> Any: - # Wrapper providing cached instantiation of stateful callable class - # UDFs. + assert ray.data._cached_fn is not None + assert ray.data._cached_cls == fn_ + return ray.data._cached_fn(item) + + def init_fn(): if ray.data._cached_fn is None: ray.data._cached_cls = fn_ ray.data._cached_fn = fn_(*fn_constructor_args, **fn_constructor_kwargs) - else: - # A worker is destroyed when its actor is killed, so we - # shouldn't have any worker reuse across different UDF - # applications (i.e. different map operators). 
- assert ray.data._cached_cls == fn_ - return ray.data._cached_fn(item) else: fn = op._fn + init_fn = None fn_args = (fn,) if op._fn_args: fn_args += op._fn_args @@ -86,6 +84,7 @@ def do_map(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]: return MapOperator.create( do_map, input_physical_dag, + init_fn=init_fn, name=op.name, compute_strategy=compute, min_rows_per_bundle=op._target_block_size, diff --git a/python/ray/data/tests/test_operators.py b/python/ray/data/tests/test_operators.py index 64e7c020a1a0..e238213aa3c1 100644 --- a/python/ray/data/tests/test_operators.py +++ b/python/ray/data/tests/test_operators.py @@ -506,6 +506,32 @@ def _sleep(block_iter: Iterable[Block]) -> Iterable[Block]: wait_for_condition(lambda: (ray.available_resources().get("GPU", 0) == 1.0)) +def test_actor_pool_map_operator_init(ray_start_regular_shared): + """Tests that ActorPoolMapOperator runs init_fn on start.""" + + from ray.exceptions import RayActorError + + def _sleep(block_iter: Iterable[Block]) -> Iterable[Block]: + time.sleep(999) + + def _fail(): + raise ValueError("init_failed") + + input_op = InputDataBuffer(make_ref_bundles([[i] for i in range(10)])) + compute_strategy = ActorPoolStrategy(min_size=1) + + op = MapOperator.create( + _sleep, + input_op=input_op, + init_fn=_fail, + name="TestMapper", + compute_strategy=compute_strategy, + ) + + with pytest.raises(RayActorError, match=r"init_failed"): + op.start(ExecutionOptions()) + + @pytest.mark.parametrize( "compute,expected", [
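
For reference, below is a minimal usage sketch of the new `init_fn` hook, mirroring the test added above. It is illustrative only: the import paths and the `expensive_setup` / `transform` names are assumptions rather than part of this patch, and the internal module locations may differ across Ray versions.

# Illustrative sketch -- the module paths below are assumptions based on Ray's
# internal layout around this patch and may differ in other releases.
# Assumes a Ray cluster is already initialized (e.g. via ray.init()).
from ray.data._internal.compute import ActorPoolStrategy
from ray.data._internal.execution.interfaces import ExecutionOptions
from ray.data._internal.execution.operators.input_data_buffer import InputDataBuffer
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.execution.util import make_ref_bundles


def expensive_setup():
    # Hypothetical one-time initialization (e.g. loading a model); with this
    # patch it runs in each pool actor's __init__, not on the first batch.
    pass


def transform(block_iter):
    # Per-bundle transform that reuses whatever expensive_setup() prepared.
    yield from block_iter


input_op = InputDataBuffer(make_ref_bundles([[i] for i in range(10)]))
op = MapOperator.create(
    transform,
    input_op=input_op,
    init_fn=expensive_setup,
    name="ExampleMapper",
    compute_strategy=ActorPoolStrategy(min_size=1),
)
# Starting the operator spins up the actor pool; init_fn runs as each actor is
# constructed, so setup errors surface here instead of mid-execution.
op.start(ExecutionOptions())

Because `init_fn` runs in each `_MapWorker.__init__`, a failing initializer surfaces as a `RayActorError` when the pool starts (as the new test asserts), rather than on the first dispatched batch.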