diff --git a/python/ray/data/BUILD b/python/ray/data/BUILD index 079240296c75..f532dba22afc 100644 --- a/python/ray/data/BUILD +++ b/python/ray/data/BUILD @@ -579,6 +579,14 @@ py_test( deps = ["//:ray_lib", ":conftest"], ) +py_test( + name = "test_autoscaler", + size = "small", + srcs = ["tests/test_autoscaler.py"], + tags = ["team:data", "exclusive"], + deps = ["//:ray_lib", ":conftest"], +) + py_test( name = "test_lance", size = "small", diff --git a/python/ray/data/_internal/execution/autoscaler/__init__.py b/python/ray/data/_internal/execution/autoscaler/__init__.py new file mode 100644 index 000000000000..c167c14fa1f3 --- /dev/null +++ b/python/ray/data/_internal/execution/autoscaler/__init__.py @@ -0,0 +1,15 @@ +from .autoscaler import Autoscaler +from .autoscaling_actor_pool import AutoscalingActorPool +from .default_autoscaler import DefaultAutoscaler + + +def create_autoscaler(topology, resource_manager, execution_id): + return DefaultAutoscaler(topology, resource_manager, execution_id) + + +__all__ = [ + "Autoscaler", + "DefaultAutoscaler", + "create_autoscaler", + "AutoscalingActorPool", +] diff --git a/python/ray/data/_internal/execution/autoscaler/autoscaler.py b/python/ray/data/_internal/execution/autoscaler/autoscaler.py new file mode 100644 index 000000000000..0e50b3a704ed --- /dev/null +++ b/python/ray/data/_internal/execution/autoscaler/autoscaler.py @@ -0,0 +1,38 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from ray.util.annotations import DeveloperAPI + +if TYPE_CHECKING: + from ray.data._internal.execution.resource_manager import ResourceManager + from ray.data._internal.execution.streaming_executor_state import Topology + + +@DeveloperAPI +class Autoscaler(ABC): + """Abstract interface for Ray Data autoscaler.""" + + def __init__( + self, + topology: "Topology", + resource_manager: "ResourceManager", + execution_id: str, + ): + self._topology = topology + self._resource_manager = resource_manager + self._execution_id = execution_id + + @abstractmethod + def try_trigger_scaling(self): + """Try trigger autoscaling. + + This method will be called each time when StreamingExecutor makes + a scheduling decision. A subclass should override this method to + handle the autoscaling of both the cluster and `AutoscalingActorPool`s. + """ + ... + + @abstractmethod + def on_executor_shutdown(self): + """Callback when the StreamingExecutor is shutting down.""" + ... diff --git a/python/ray/data/_internal/execution/autoscaler/autoscaling_actor_pool.py b/python/ray/data/_internal/execution/autoscaler/autoscaling_actor_pool.py new file mode 100644 index 000000000000..5fa417d34c6d --- /dev/null +++ b/python/ray/data/_internal/execution/autoscaler/autoscaling_actor_pool.py @@ -0,0 +1,88 @@ +from abc import ABC, abstractmethod + +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +class AutoscalingActorPool(ABC): + """Abstract interface of an autoscaling actor pool. + + A `PhysicalOperator` can manage one or more `AutoscalingActorPool`s. + `Autoscaler` is responsible for deciding autoscaling of these actor + pools. + """ + + @abstractmethod + def min_size(self) -> int: + """Min size of the actor pool.""" + ... + + @abstractmethod + def max_size(self) -> int: + """Max size of the actor pool.""" + ... + + @abstractmethod + def current_size(self) -> int: + """Current size of the actor pool.""" + ... + + @abstractmethod + def num_running_actors(self) -> int: + """Number of running actors.""" + ... 
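For reference, a subclass of the Autoscaler interface introduced above only needs to provide the two hooks shown in autoscaler.py. The following is an illustrative, do-nothing sketch, not part of the diff; the class name NoopAutoscaler is hypothetical, and the import assumes a Ray build that contains the new autoscaler package added here.

from ray.data._internal.execution.autoscaler import Autoscaler


class NoopAutoscaler(Autoscaler):
    """Illustrative subclass: keeps topology/resource manager state but never scales."""

    def try_trigger_scaling(self):
        # Called on every StreamingExecutor scheduling step; intentionally a no-op.
        pass

    def on_executor_shutdown(self):
        # Nothing to cancel or release on shutdown.
        pass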
+ + @abstractmethod + def num_active_actors(self) -> int: + """Number of actors with at least one active task.""" + ... + + @abstractmethod + def num_pending_actors(self) -> int: + """Number of actors pending creation.""" + ... + + @abstractmethod + def max_tasks_in_flight_per_actor(self) -> int: + """Max number of in-flight tasks per actor.""" + ... + + @abstractmethod + def current_in_flight_tasks(self) -> int: + """Number of current in-flight tasks.""" + ... + + def num_total_task_slots(self) -> int: + """Total number of task slots.""" + return self.max_tasks_in_flight_per_actor() * self.current_size() + + def num_free_task_slots(self) -> int: + """Number of free slots to run tasks.""" + return ( + self.max_tasks_in_flight_per_actor() * self.current_size() + - self.current_in_flight_tasks() + ) + + @abstractmethod + def scale_up(self, num_actors: int) -> int: + """Request the actor pool to scale up by the given number of actors. + + The number of actually added actors may be less than the requested + number. + + Returns: + The number of actors actually added. + """ + ... + + @abstractmethod + def scale_down(self, num_actors: int) -> int: + """Request actor pool to scale down by the given number of actors. + + The number of actually removed actors may be less than the requested + number. + + Returns: + The number of actors actually removed. + """ + ... diff --git a/python/ray/data/_internal/execution/autoscaler/default_autoscaler.py b/python/ray/data/_internal/execution/autoscaler/default_autoscaler.py new file mode 100644 index 000000000000..71a2346c96eb --- /dev/null +++ b/python/ray/data/_internal/execution/autoscaler/default_autoscaler.py @@ -0,0 +1,184 @@ +import math +import time +from typing import TYPE_CHECKING, Dict + +from .autoscaler import Autoscaler +from .autoscaling_actor_pool import AutoscalingActorPool +from ray.data._internal.execution.autoscaling_requester import ( + get_or_create_autoscaling_requester_actor, +) +from ray.data._internal.execution.interfaces.execution_options import ExecutionResources + +if TYPE_CHECKING: + from ray.data._internal.execution.interfaces import PhysicalOperator + from ray.data._internal.execution.resource_manager import ResourceManager + from ray.data._internal.execution.streaming_executor_state import OpState, Topology + + +class DefaultAutoscaler(Autoscaler): + + # Default threshold of actor pool utilization to trigger scaling up. + DEFAULT_ACTOR_POOL_SCALING_UP_THRESHOLD: float = 0.8 + # Default threshold of actor pool utilization to trigger scaling down. + DEFAULT_ACTOR_POOL_SCALING_DOWN_THRESHOLD: float = 0.5 + + # Min number of seconds between two autoscaling requests. + MIN_GAP_BETWEEN_AUTOSCALING_REQUESTS = 20 + + def __init__( + self, + topology: "Topology", + resource_manager: "ResourceManager", + execution_id: str, + actor_pool_scaling_up_threshold: float = DEFAULT_ACTOR_POOL_SCALING_UP_THRESHOLD, # noqa: E501 + actor_pool_scaling_down_threshold: float = DEFAULT_ACTOR_POOL_SCALING_DOWN_THRESHOLD, # noqa: E501 + ): + self._actor_pool_scaling_up_threshold = actor_pool_scaling_up_threshold + self._actor_pool_scaling_down_threshold = actor_pool_scaling_down_threshold + # Last time when a request was sent to Ray's autoscaler. 
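The derived slot accounting on AutoscalingActorPool is what the default autoscaler compares against an operator's input queue, and the two thresholds defined above bound the utilization band it targets. Below is a self-contained sketch that mirrors, rather than imports, that arithmetic; the helper names are illustrative and not part of the real API.

def total_task_slots(max_tasks_in_flight_per_actor: int, current_size: int) -> int:
    # Mirrors AutoscalingActorPool.num_total_task_slots().
    return max_tasks_in_flight_per_actor * current_size


def free_task_slots(
    max_tasks_in_flight_per_actor: int, current_size: int, current_in_flight_tasks: int
) -> int:
    # Mirrors AutoscalingActorPool.num_free_task_slots().
    return max_tasks_in_flight_per_actor * current_size - current_in_flight_tasks


def pool_util(num_active_actors: int, current_size: int) -> float:
    # Mirrors the utilization metric the default autoscaler computes (active / size).
    return 0 if current_size == 0 else num_active_actors / current_size


# 4 actors with 2 slots each and 5 tasks already in flight leave 3 free slots.
assert total_task_slots(2, 4) == 8
assert free_task_slots(2, 4, 5) == 3
# 5 of 6 busy actors exceeds the 0.8 scale-up threshold; 2 of 6 is below the
# 0.5 scale-down threshold.
assert pool_util(5, 6) > 0.8
assert pool_util(2, 6) < 0.5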
+ self._last_request_time = 0 + super().__init__(topology, resource_manager, execution_id) + + def try_trigger_scaling(self): + self._try_scale_up_cluster() + self._try_scale_up_or_down_actor_pool() + + def _calculate_actor_pool_util(self, actor_pool: AutoscalingActorPool): + """Calculate the utilization of the given actor pool.""" + if actor_pool.current_size() == 0: + return 0 + else: + return actor_pool.num_active_actors() / actor_pool.current_size() + + def _actor_pool_should_scale_up( + self, + actor_pool: AutoscalingActorPool, + op: "PhysicalOperator", + op_state: "OpState", + ): + # Do not scale up, if the op is completed or no more inputs are coming. + if op.completed() or (op._inputs_complete and op.internal_queue_size() == 0): + return False + if actor_pool.current_size() < actor_pool.min_size(): + # Scale up, if the actor pool is below min size. + return True + elif actor_pool.current_size() >= actor_pool.max_size(): + # Do not scale up, if the actor pool is already at max size. + return False + # Do not scale up, if the op still has enough resources to run. + if op_state._scheduling_status.under_resource_limits: + return False + # Do not scale up, if the op has enough free slots for the existing inputs. + if op_state.num_queued() <= actor_pool.num_free_task_slots(): + return False + # Determine whether to scale up based on the actor pool utilization. + util = self._calculate_actor_pool_util(actor_pool) + return util > self._actor_pool_scaling_up_threshold + + def _actor_pool_should_scale_down( + self, + actor_pool: AutoscalingActorPool, + op: "PhysicalOperator", + ): + # Scale down, if the op is completed or no more inputs are coming. + if op.completed() or (op._inputs_complete and op.internal_queue_size() == 0): + return True + if actor_pool.current_size() > actor_pool.max_size(): + # Scale down, if the actor pool is above max size. + return True + elif actor_pool.current_size() <= actor_pool.min_size(): + # Do not scale down, if the actor pool is already at min size. + return False + # Determine whether to scale down based on the actor pool utilization. + util = self._calculate_actor_pool_util(actor_pool) + return util < self._actor_pool_scaling_down_threshold + + def _try_scale_up_or_down_actor_pool(self): + for op, state in self._topology.items(): + actor_pools = op.get_autoscaling_actor_pools() + for actor_pool in actor_pools: + while True: + # Try to scale up or down the actor pool. + should_scale_up = self._actor_pool_should_scale_up( + actor_pool, + op, + state, + ) + should_scale_down = self._actor_pool_should_scale_down( + actor_pool, op + ) + if should_scale_up and not should_scale_down: + if actor_pool.scale_up(1) == 0: + break + elif should_scale_down and not should_scale_up: + if actor_pool.scale_down(1) == 0: + break + else: + break + + def _try_scale_up_cluster(self): + """Try to scale up the cluster to accomodate the provided in-progress workload. + + This makes a resource request to Ray's autoscaler consisting of the current, + aggregate usage of all operators in the DAG + the incremental usage of all + operators that are ready for dispatch (i.e. that have inputs queued). If the + autoscaler were to grant this resource request, it would allow us to dispatch + one task for every ready operator. + + Note that this resource request does not take the global resource limits or the + liveness policy into account; it only tries to make the existing resource usage + + one more task per ready operator feasible in the cluster. 
+ """ + # Limit the frequency of autoscaling requests. + now = time.time() + if now - self._last_request_time < self.MIN_GAP_BETWEEN_AUTOSCALING_REQUESTS: + return + + # Scale up the cluster, if no ops are allowed to run, but there are still data + # in the input queues. + no_runnable_op = all( + op_state._scheduling_status.runnable is False + for _, op_state in self._topology.items() + ) + any_has_input = any( + op_state.num_queued() > 0 for _, op_state in self._topology.items() + ) + if not (no_runnable_op and any_has_input): + return + + self._last_request_time = now + + # Get resource usage for all ops + additional resources needed to launch one + # more task for each ready op. + resource_request = [] + + def to_bundle(resource: ExecutionResources) -> Dict: + req = {} + if resource.cpu: + req["CPU"] = math.ceil(resource.cpu) + if resource.gpu: + req["GPU"] = math.ceil(resource.gpu) + return req + + for op, state in self._topology.items(): + per_task_resource = op.incremental_resource_usage() + task_bundle = to_bundle(per_task_resource) + resource_request.extend([task_bundle] * op.num_active_tasks()) + # Only include incremental resource usage for ops that are ready for + # dispatch. + if state.num_queued() > 0: + # TODO(Clark): Scale up more aggressively by adding incremental resource + # usage for more than one bundle in the queue for this op? + resource_request.append(task_bundle) + + self._send_resource_request(resource_request) + + def _send_resource_request(self, resource_request): + # Make autoscaler resource request. + actor = get_or_create_autoscaling_requester_actor() + actor.request_resources.remote(resource_request, self._execution_id) + + def on_executor_shutdown(self): + # Make request for zero resources to autoscaler for this execution. + actor = get_or_create_autoscaling_requester_actor() + actor.request_resources.remote({}, self._execution_id) diff --git a/python/ray/data/_internal/execution/interfaces/physical_operator.py b/python/ray/data/_internal/execution/interfaces/physical_operator.py index 92de1bfc9356..739262ef2b80 100644 --- a/python/ray/data/_internal/execution/interfaces/physical_operator.py +++ b/python/ray/data/_internal/execution/interfaces/physical_operator.py @@ -4,6 +4,9 @@ import ray from .ref_bundle import RefBundle from ray._raylet import ObjectRefGenerator +from ray.data._internal.execution.autoscaler.autoscaling_actor_pool import ( + AutoscalingActorPool, +) from ray.data._internal.execution.interfaces.execution_options import ( ExecutionOptions, ExecutionResources, @@ -400,30 +403,14 @@ def base_resource_usage(self) -> ExecutionResources: """ return ExecutionResources() - def incremental_resource_usage( - self, consider_autoscaling=True - ) -> ExecutionResources: + def incremental_resource_usage(self) -> ExecutionResources: """Returns the incremental resources required for processing another input. For example, an operator that launches a task per input could return ExecutionResources(cpu=1) as its incremental usage. - - Args: - consider_autoscaling: Whether to consider the possibility of autoscaling. """ return ExecutionResources() - def notify_resource_usage( - self, input_queue_size: int, under_resource_limits: bool - ) -> None: - """Called periodically by the executor. - - Args: - input_queue_size: The number of inputs queued outside this operator. - under_resource_limits: Whether this operator is under resource limits. 
- """ - pass - def notify_in_task_submission_backpressure(self, in_backpressure: bool) -> None: """Called periodically from the executor to update internal in backpressure status for stats collection purposes. @@ -435,3 +422,7 @@ def notify_in_task_submission_backpressure(self, in_backpressure: bool) -> None: if self._in_task_submission_backpressure != in_backpressure: self._metrics.on_toggle_task_submission_backpressure(in_backpressure) self._in_task_submission_backpressure = in_backpressure + + def get_autoscaling_actor_pools(self) -> List[AutoscalingActorPool]: + """Return a list of `AutoscalingActorPool`s managed by this operator.""" + return [] diff --git a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py index 228385f4ea0f..4e267541e92a 100644 --- a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py @@ -1,10 +1,11 @@ import collections import logging -from dataclasses import dataclass -from typing import Any, Dict, Iterator, List, Optional, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union import ray +from ray.actor import ActorHandle from ray.data._internal.compute import ActorPoolStrategy +from ray.data._internal.execution.autoscaler import AutoscalingActorPool from ray.data._internal.execution.interfaces import ( ExecutionOptions, ExecutionResources, @@ -51,7 +52,7 @@ def __init__( map_transformer: MapTransformer, input_op: PhysicalOperator, target_max_block_size: Optional[int], - autoscaling_policy: "AutoscalingPolicy", + compute_strategy: ActorPoolStrategy, name: str = "ActorPoolMap", min_rows_per_bundle: Optional[int] = None, ray_remote_args: Optional[Dict[str, Any]] = None, @@ -62,8 +63,7 @@ def __init__( transform_fn: The function to apply to each ref bundle input. init_fn: The callable class to instantiate on each actor. input_op: Operator generating input data for this op. - autoscaling_policy: A policy controlling when the actor pool should be - scaled up and scaled down. + compute_strategy: ComputeStrategy used for this operator. name: The name of this operator. target_max_block_size: The target maximum number of bytes to include in an output block. @@ -97,10 +97,7 @@ def __init__( ) self._min_rows_per_bundle = min_rows_per_bundle - # Create autoscaling policy from compute strategy. - self._autoscaling_policy = autoscaling_policy - # A pool of running actors on which we can execute mapper tasks. - self._actor_pool = _ActorPool(autoscaling_policy._config.max_tasks_in_flight) + self._actor_pool = _ActorPool(compute_strategy, self._start_actor) # A queue of bundles awaiting dispatch to actors. self._bundle_queue = collections.deque() # Cached actor class. @@ -117,8 +114,7 @@ def start(self, options: ExecutionOptions): # Create the actor workers and add them to the pool. self._cls = ray.remote(**self._ray_remote_args)(_MapWorker) - for _ in range(self._autoscaling_policy.min_workers): - self._start_actor() + self._actor_pool.scale_up(self._actor_pool.min_size()) refs = self._actor_pool.get_pending_actor_refs() # We synchronously wait for the initial number of actors to start. This avoids @@ -138,18 +134,6 @@ def start(self, options: ExecutionOptions): def should_add_input(self) -> bool: return self._actor_pool.num_free_slots() > 0 - # Called by streaming executor periodically to trigger autoscaling. 
- def notify_resource_usage( - self, input_queue_size: int, under_resource_limits: bool - ) -> None: - free_slots = self._actor_pool.num_free_slots() - if input_queue_size > free_slots and under_resource_limits: - # Try to scale up if work remains in the work queue. - self._scale_up_if_needed() - else: - # Try to remove any idle actors. - self._scale_down_if_needed() - def _start_actor(self): """Start a new actor and add it to the actor pool as a pending actor.""" assert self._cls is not None @@ -175,7 +159,7 @@ def _task_done_callback(res_ref): res_ref, lambda: _task_done_callback(res_ref), ) - self._actor_pool.add_pending_actor(actor, res_ref) + return actor, res_ref def _add_bundled_input(self, bundle: RefBundle): self._bundle_queue.append(bundle) @@ -229,39 +213,6 @@ def _task_done_callback(actor_to_return): lambda: _task_done_callback(actor_to_return), ) - # Needed in the bulk execution path for triggering autoscaling. This is a - # no-op in the streaming execution case. - if self._bundle_queue: - # Try to scale up if work remains in the work queue. - self._scale_up_if_needed() - else: - # Only try to scale down if the work queue has been fully consumed. - self._scale_down_if_needed() - - def _scale_up_if_needed(self): - """Try to scale up the pool if the autoscaling policy allows it.""" - while self._autoscaling_policy.should_scale_up( - num_total_workers=self._actor_pool.num_total_actors(), - num_running_workers=self._actor_pool.num_running_actors(), - ): - self._start_actor() - - def _scale_down_if_needed(self): - """Try to scale down the pool if the autoscaling policy allows it.""" - # Kill inactive workers if there's no more work to do. - self._kill_inactive_workers_if_done() - - while self._autoscaling_policy.should_scale_down( - num_total_workers=self._actor_pool.num_total_actors(), - num_idle_workers=self._actor_pool.num_idle_actors(), - ): - killed = self._actor_pool.kill_inactive_actor() - if not killed: - # This scaledown is best-effort, only killing an inactive worker if an - # inactive worker exists. If there are no inactive workers to kill, we - # break out of the scale-down loop. - break - def all_inputs_done(self): # Call base implementation to handle any leftover bundles. This may or may not # trigger task dispatch. @@ -271,15 +222,6 @@ def all_inputs_done(self): # once the bundle queue is exhausted. self._inputs_done = True - # Try to scale pool down. - self._scale_down_if_needed() - - def _kill_inactive_workers_if_done(self): - if self._inputs_done and not self._bundle_queue: - # No more tasks will be submitted, so we kill all current and future - # inactive workers. - self._actor_pool.kill_all_inactive_actors() - def shutdown(self): # We kill all actors in the pool on shutdown, even if they are busy doing work. self._actor_pool.kill_all_actors() @@ -288,7 +230,7 @@ def shutdown(self): # Warn if the user specified a batch or block size that prevents full # parallelization across the actor pool. We only know this information after # execution has completed. - min_workers = self._autoscaling_policy.min_workers + min_workers = self._actor_pool.min_size() if len(self._output_metadata) < min_workers: # The user created a stream that has too few blocks to begin with. 
logger.warning( @@ -312,7 +254,7 @@ def progress_str(self) -> str: return base def base_resource_usage(self) -> ExecutionResources: - min_workers = self._autoscaling_policy.min_workers + min_workers = self._actor_pool.min_size() return ExecutionResources( cpu=self._ray_remote_args.get("num_cpus", 0) * min_workers, gpu=self._ray_remote_args.get("num_gpus", 0) * min_workers, @@ -320,33 +262,18 @@ def base_resource_usage(self) -> ExecutionResources: def current_processor_usage(self) -> ExecutionResources: # Both pending and running actors count towards our current resource usage. - num_active_workers = self._actor_pool.num_total_actors() + num_active_workers = self._actor_pool.current_size() return ExecutionResources( cpu=self._ray_remote_args.get("num_cpus", 0) * num_active_workers, gpu=self._ray_remote_args.get("num_gpus", 0) * num_active_workers, ) - def incremental_resource_usage( - self, consider_autoscaling=True - ) -> ExecutionResources: - # We would only have nonzero incremental CPU/GPU resources if a new task would - # require scale-up to run. - if consider_autoscaling and self._autoscaling_policy.should_scale_up( - num_total_workers=self._actor_pool.num_total_actors(), - num_running_workers=self._actor_pool.num_running_actors(), - ): - # A new task would trigger scale-up, so we include the actor resouce - # requests in the incremental resources. - num_cpus = self._ray_remote_args.get("num_cpus", 0) - num_gpus = self._ray_remote_args.get("num_gpus", 0) - else: - # A new task wouldn't trigger scale-up, so we consider the incremental - # compute resources to be 0. - num_cpus = 0 - num_gpus = 0 + def incremental_resource_usage(self) -> ExecutionResources: + # Submitting tasks to existing actors doesn't require additional + # CPU/GPU resources. return ExecutionResources( - cpu=num_cpus, - gpu=num_gpus, + cpu=0, + gpu=0, object_store_memory=self._metrics.obj_store_mem_max_pending_output_per_task or 0, ) @@ -377,6 +304,9 @@ def _apply_default_remote_args(ray_remote_args: Dict[str, Any]) -> Dict[str, Any ray_remote_args["max_task_retries"] = -1 return ray_remote_args + def get_autoscaling_actor_pools(self) -> List[AutoscalingActorPool]: + return [self._actor_pool] + class _MapWorker: """An actor worker for MapOperator.""" @@ -413,136 +343,7 @@ def __repr__(self): return f"MapWorker({self.src_fn_name})" -# TODO(Clark): Promote this to a public config once we deprecate the legacy compute -# strategies. -@dataclass -class AutoscalingConfig: - """Configuration for an autoscaling actor pool.""" - - # Minimum number of workers in the actor pool. - min_workers: int - # Maximum number of workers in the actor pool. - max_workers: int - # Maximum number of tasks that can be in flight for a single worker. - # TODO(Clark): Have this informed by the prefetch_batches configuration, once async - # prefetching has been ported to this new actor pool. - max_tasks_in_flight: int = DEFAULT_MAX_TASKS_IN_FLIGHT - # Minimum ratio of ready workers to the total number of workers. If the pool is - # above this ratio, it will be allowed to be scaled up. - ready_to_total_workers_ratio: float = 0.8 - # Maximum ratio of idle workers to the total number of workers. If the pool goes - # above this ratio, the pool will be scaled down. 
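One consequence of the refactor above is that the actor-pool operator now reports zero incremental CPU/GPU per additional input, since new inputs are dispatched to already-provisioned actors; the task-pool operator, as its change later in this diff shows, still reports its per-task num_cpus/num_gpus, and scale-up pressure for actor pools is expressed through the autoscaler instead. A condensed restatement with a hypothetical helper (object store memory omitted):

def incremental_cpu_gpu(is_actor_pool: bool, num_cpus: float, num_gpus: float) -> dict:
    if is_actor_pool:
        # Submitting another task to an existing actor needs no new CPU/GPU.
        return {"CPU": 0, "GPU": 0}
    # A task-pool operator launches one new task per input.
    return {"CPU": num_cpus, "GPU": num_gpus}


assert incremental_cpu_gpu(True, num_cpus=1, num_gpus=1) == {"CPU": 0, "GPU": 0}
assert incremental_cpu_gpu(False, num_cpus=1, num_gpus=0) == {"CPU": 1, "GPU": 0}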
- idle_to_total_workers_ratio: float = 0.5 - - def __post_init__(self): - if self.min_workers < 1: - raise ValueError("min_workers must be >= 1, got: ", self.min_workers) - if self.max_workers is not None and self.min_workers > self.max_workers: - raise ValueError( - "min_workers must be <= max_workers, got: ", - self.min_workers, - self.max_workers, - ) - if self.max_tasks_in_flight < 1: - raise ValueError( - "max_tasks_in_flight must be >= 1, got: ", - self.max_tasks_in_flight, - ) - - @classmethod - def from_compute_strategy(cls, compute_strategy: ActorPoolStrategy): - """Convert a legacy ActorPoolStrategy to an AutoscalingConfig.""" - # TODO(Clark): Remove this once the legacy compute strategies are deprecated. - assert isinstance(compute_strategy, ActorPoolStrategy) - return cls( - min_workers=compute_strategy.min_size, - max_workers=compute_strategy.max_size, - max_tasks_in_flight=compute_strategy.max_tasks_in_flight_per_actor - or DEFAULT_MAX_TASKS_IN_FLIGHT, - ready_to_total_workers_ratio=compute_strategy.ready_to_total_workers_ratio, - ) - - -class AutoscalingPolicy: - """Autoscaling policy for an actor pool, determining when the pool should be scaled - up and when it should be scaled down. - """ - - def __init__(self, autoscaling_config: "AutoscalingConfig"): - self._config = autoscaling_config - - @property - def min_workers(self) -> int: - """The minimum number of actors that must be in the actor pool.""" - return self._config.min_workers - - @property - def max_workers(self) -> int: - """The maximum number of actors that can be added to the actor pool.""" - return self._config.max_workers - - def should_scale_up(self, num_total_workers: int, num_running_workers: int) -> bool: - """Whether the actor pool should scale up by adding a new actor. - - Args: - num_total_workers: Total number of workers in actor pool. - num_running_workers: Number of currently running workers in actor pool. - - Returns: - Whether the actor pool should be scaled up by one actor. - """ - # TODO: Replace the ready-to-total-ratio heuristic with a a work queue - # heuristic such that scale-up is only triggered if the current pool doesn't - # have enough worker slots to process the work queue. - # TODO: Use profiling of the bundle arrival rate, worker startup - # time, and task execution time to tailor the work queue heuristic to the - # running workload and observed Ray performance. E.g. this could be done via an - # augmented EMA using a queueing model - if num_total_workers < self._config.min_workers: - # The actor pool does not reach the configured minimum size. - return True - else: - return ( - # 1. The actor pool will not exceed the configured maximum size. - num_total_workers < self._config.max_workers - # TODO: Remove this once we have a good work queue heuristic and our - # resource-based backpressure is working well. - # 2. At least 80% of the workers in the pool have already started. - # This will ensure that workers will be launched in parallel while - # bounding the worker pool to requesting 125% of the cluster's - # available resources. - and num_total_workers > 0 - and num_running_workers / num_total_workers - > self._config.ready_to_total_workers_ratio - ) - - def should_scale_down( - self, - num_total_workers: int, - num_idle_workers: int, - ) -> bool: - """Whether the actor pool should scale down by terminating an inactive actor. - - Args: - num_total_workers: Total number of workers in actor pool. - num_idle_workers: Number of currently idle workers in the actor pool. 
- - Returns: - Whether the actor pool should be scaled down by one actor. - """ - # TODO(Clark): Add an idleness timeout-based scale-down. - # TODO(Clark): Make the idleness timeout dynamically determined by bundle - # arrival rate, worker startup time, and task execution time. - return ( - # 1. The actor pool will not go below the configured minimum size. - num_total_workers > self._config.min_workers - # 2. The actor pool contains more than 50% idle workers. - and num_idle_workers / num_total_workers - > self._config.idle_to_total_workers_ratio - ) - - -class _ActorPool: +class _ActorPool(AutoscalingActorPool): """A pool of actors for map task execution. This class is in charge of tracking the number of in-flight tasks per actor, @@ -550,8 +351,23 @@ class _ActorPool: actors when the operator is done submitting work to the pool. """ - def __init__(self, max_tasks_in_flight: int = DEFAULT_MAX_TASKS_IN_FLIGHT): - self._max_tasks_in_flight = max_tasks_in_flight + def __init__( + self, + compute_strategy: ActorPoolStrategy, + create_actor_fn: Callable[[], Tuple[ActorHandle, ObjectRef[Any]]], + ): + self._min_size: int = compute_strategy.min_size + self._max_size: int = compute_strategy.max_size + self._max_tasks_in_flight: int = ( + compute_strategy.max_tasks_in_flight_per_actor + or DEFAULT_MAX_TASKS_IN_FLIGHT + ) + self._create_actor_fn = create_actor_fn + assert self._min_size >= 1 + assert self._max_size >= self._min_size + assert self._max_tasks_in_flight >= 1 + assert self._create_actor_fn is not None + # Number of tasks in flight per actor. self._num_tasks_in_flight: Dict[ray.actor.ActorHandle, int] = {} # Node id of each ready actor. @@ -565,6 +381,50 @@ def __init__(self, max_tasks_in_flight: int = DEFAULT_MAX_TASKS_IN_FLIGHT): self._locality_hits: int = 0 self._locality_misses: int = 0 + # === Overriding methods of AutoscalingActorPool === + + def min_size(self) -> int: + return self._min_size + + def max_size(self) -> int: + return self._max_size + + def current_size(self) -> int: + return self.num_pending_actors() + self.num_running_actors() + + def num_running_actors(self) -> int: + return len(self._num_tasks_in_flight) + + def num_active_actors(self) -> int: + return sum( + 1 if num_tasks_in_flight > 0 else 0 + for num_tasks_in_flight in self._num_tasks_in_flight.values() + ) + + def num_pending_actors(self) -> int: + return len(self._pending_actors) + + def max_tasks_in_flight_per_actor(self) -> int: + return self._max_tasks_in_flight + + def current_in_flight_tasks(self) -> int: + return sum(num for _, num in self._num_tasks_in_flight.items()) + + def scale_up(self, num_actors: int) -> int: + for _ in range(num_actors): + actor, ready_ref = self._create_actor_fn() + self.add_pending_actor(actor, ready_ref) + return num_actors + + def scale_down(self, num_actors: int) -> int: + num_killed = 0 + for _ in range(num_actors): + if self.kill_inactive_actor(): + num_killed += 1 + return num_killed + + # === End of overriding methods of AutoscalingActorPool === + def add_pending_actor(self, actor: ray.actor.ActorHandle, ready_ref: ray.ObjectRef): """Adds a pending actor to the pool. 
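scale_up() and scale_down() are deliberately best-effort and report how many actors actually changed, which is what lets the default autoscaler drive a pool one actor at a time until either the decision flips or the pool stops changing. A condensed restatement of that loop for a single pool; adjust_pool is a hypothetical name, pool is any AutoscalingActorPool, and should_up/should_down stand in for the threshold checks shown earlier.

def adjust_pool(pool, should_up, should_down):
    # One actor per iteration; stop when neither direction applies or the pool
    # reports that it could not change size.
    while True:
        up, down = should_up(pool), should_down(pool)
        if up and not down:
            if pool.scale_up(1) == 0:
                break
        elif down and not up:
            if pool.scale_down(1) == 0:
                break
        else:
            break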
@@ -657,16 +517,6 @@ def return_actor(self, actor: ray.actor.ActorHandle): def get_pending_actor_refs(self) -> List[ray.ObjectRef]: return list(self._pending_actors.keys()) - def num_total_actors(self) -> int: - """Return the total number of actors managed by this pool, including pending - actors - """ - return self.num_pending_actors() + self.num_running_actors() - - def num_running_actors(self) -> int: - """Return the number of running actors in the pool.""" - return len(self._num_tasks_in_flight) - def num_idle_actors(self) -> int: """Return the number of idle actors in the pool.""" return sum( @@ -674,10 +524,6 @@ def num_idle_actors(self) -> int: for tasks_in_flight in self._num_tasks_in_flight.values() ) - def num_pending_actors(self) -> int: - """Return the number of pending actors in the pool.""" - return len(self._pending_actors) - def num_free_slots(self) -> int: """Return the number of free slots for task execution.""" if not self._num_tasks_in_flight: @@ -687,13 +533,6 @@ def num_free_slots(self) -> int: for num_tasks_in_flight in self._num_tasks_in_flight.values() ) - def num_active_actors(self) -> int: - """Return the number of actors in the pool with at least one active task.""" - return sum( - 1 if num_tasks_in_flight > 0 else 0 - for num_tasks_in_flight in self._num_tasks_in_flight.values() - ) - def kill_inactive_actor(self) -> bool: """Kills a single pending or idle actor, if any actors are pending/idle. diff --git a/python/ray/data/_internal/execution/operators/map_operator.py b/python/ray/data/_internal/execution/operators/map_operator.py index 069e5ff306c6..c94f039abb56 100644 --- a/python/ray/data/_internal/execution/operators/map_operator.py +++ b/python/ray/data/_internal/execution/operators/map_operator.py @@ -154,21 +154,14 @@ def create( elif isinstance(compute_strategy, ActorPoolStrategy): from ray.data._internal.execution.operators.actor_pool_map_operator import ( ActorPoolMapOperator, - AutoscalingConfig, - AutoscalingPolicy, ) - autoscaling_config = AutoscalingConfig.from_compute_strategy( - compute_strategy - ) - autoscaling_policy = AutoscalingPolicy(autoscaling_config) - return ActorPoolMapOperator( map_transformer, input_op, - autoscaling_policy=autoscaling_policy, - name=name, target_max_block_size=target_max_block_size, + compute_strategy=compute_strategy, + name=name, min_rows_per_bundle=min_rows_per_bundle, ray_remote_args=ray_remote_args, ) @@ -390,9 +383,7 @@ def base_resource_usage(self) -> ExecutionResources: raise NotImplementedError @abstractmethod - def incremental_resource_usage( - self, consider_autoscaling=True - ) -> ExecutionResources: + def incremental_resource_usage(self) -> ExecutionResources: raise NotImplementedError diff --git a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py index 2d84dd1bc111..dbf94cc83c0d 100644 --- a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py @@ -108,9 +108,7 @@ def current_processor_usage(self) -> ExecutionResources: gpu=self._ray_remote_args.get("num_gpus", 0) * num_active_workers, ) - def incremental_resource_usage( - self, consider_autoscaling=True - ) -> ExecutionResources: + def incremental_resource_usage(self) -> ExecutionResources: return ExecutionResources( cpu=self._ray_remote_args.get("num_cpus", 0), gpu=self._ray_remote_args.get("num_gpus", 0), diff --git 
a/python/ray/data/_internal/execution/resource_manager.py b/python/ray/data/_internal/execution/resource_manager.py index e8719e621960..26a6cf3681ac 100644 --- a/python/ray/data/_internal/execution/resource_manager.py +++ b/python/ray/data/_internal/execution/resource_manager.py @@ -11,9 +11,6 @@ ExecutionResources, ) from ray.data._internal.execution.interfaces.physical_operator import PhysicalOperator -from ray.data._internal.execution.operators.actor_pool_map_operator import ( - ActorPoolMapOperator, -) from ray.data._internal.execution.operators.input_data_buffer import InputDataBuffer from ray.data._internal.execution.operators.limit_operator import LimitOperator from ray.data._internal.execution.operators.map_operator import MapOperator @@ -419,16 +416,14 @@ def _update_reservation(self): ) # Calculate the minimum amount of resources to reserve. # 1. Make sure the reserved resources are at least to allow one task. - min_reserved = op.incremental_resource_usage( - consider_autoscaling=False - ).copy() + min_reserved = op.incremental_resource_usage().copy() # 2. To ensure that all GPUs are utilized, reserve enough resource budget # to launch one task for each worker. - if ( - isinstance(op, ActorPoolMapOperator) - and op.base_resource_usage().gpu > 0 - ): - min_reserved.object_store_memory *= op._autoscaling_policy.min_workers + if op.base_resource_usage().gpu > 0: + min_workers = sum( + pool.min_size() for pool in op.get_autoscaling_actor_pools() + ) + min_reserved.object_store_memory *= min_workers # Also include `reserved_for_op_outputs`. min_reserved.object_store_memory += self._reserved_for_op_outputs[op] # Total resources we want to reserve for this operator. diff --git a/python/ray/data/_internal/execution/streaming_executor.py b/python/ray/data/_internal/execution/streaming_executor.py index 9546562e36a8..023e0ed4aff8 100644 --- a/python/ray/data/_internal/execution/streaming_executor.py +++ b/python/ray/data/_internal/execution/streaming_executor.py @@ -4,9 +4,7 @@ import uuid from typing import Dict, Iterator, List, Optional -from ray.data._internal.execution.autoscaling_requester import ( - get_or_create_autoscaling_requester_actor, -) +from ray.data._internal.execution.autoscaler import create_autoscaler from ray.data._internal.execution.backpressure_policy import ( BackpressurePolicy, get_backpressure_policies, @@ -22,7 +20,6 @@ from ray.data._internal.execution.operators.input_data_buffer import InputDataBuffer from ray.data._internal.execution.resource_manager import ResourceManager from ray.data._internal.execution.streaming_executor_state import ( - AutoscalingState, OpState, Topology, build_streaming_topology, @@ -63,7 +60,6 @@ def __init__(self, options: ExecutionOptions, dataset_tag: str = "unknown_datase self._global_info: Optional[ProgressBar] = None self._execution_id = uuid.uuid4().hex - self._autoscaling_state = AutoscalingState() # The executor can be shutdown while still running. 
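The resource_manager change above generalizes the old ActorPoolMapOperator special case: for any operator that uses GPUs, the object store reservation is scaled by the combined minimum size of its autoscaling actor pools, so every minimum worker can hold one task's worth of output. A small sketch of that arithmetic with a hypothetical helper; reserved_for_op_outputs, which the real code adds on top, is omitted here.

def min_reserved_object_store(per_task_bytes: int, pool_min_sizes, uses_gpu: bool) -> int:
    reserved = per_task_bytes
    if uses_gpu:
        # One task's worth of output per minimum worker, summed over all pools.
        reserved *= sum(pool_min_sizes)
    return reserved


# A GPU op whose single pool has min_size=4 reserves 4x the per-task budget.
assert min_reserved_object_store(128 * 1024**2, [4], uses_gpu=True) == 512 * 1024**2
assert min_reserved_object_store(128 * 1024**2, [4], uses_gpu=False) == 128 * 1024**2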
self._shutdown_lock = threading.RLock() @@ -118,6 +114,11 @@ def execute( self._topology, _ = build_streaming_topology(dag, self._options) self._resource_manager = ResourceManager(self._topology, self._options) self._backpressure_policies = get_backpressure_policies(self._topology) + self._autoscaler = create_autoscaler( + self._topology, + self._resource_manager, + self._execution_id, + ) self._has_op_completed = {op: False for op in self._topology} @@ -194,9 +195,7 @@ def shutdown(self, execution_completed: bool = True): for op, state in self._topology.items(): op.shutdown() state.close_progress_bars() - # Make request for zero resources to autoscaler for this execution. - actor = get_or_create_autoscaling_requester_actor() - actor.request_resources.remote({}, self._execution_id) + self._autoscaler.on_executor_shutdown() def run(self): """Run the control loop in a helper thread. @@ -279,9 +278,8 @@ def _scheduling_loop_step(self, topology: Topology) -> bool: topology, self._resource_manager, self._backpressure_policies, + self._autoscaler, ensure_at_least_one_running=self._consumer_idling(), - execution_id=self._execution_id, - autoscaling_state=self._autoscaling_state, ) i = 0 @@ -295,9 +293,8 @@ def _scheduling_loop_step(self, topology: Topology) -> bool: topology, self._resource_manager, self._backpressure_policies, + self._autoscaler, ensure_at_least_one_running=self._consumer_idling(), - execution_id=self._execution_id, - autoscaling_state=self._autoscaling_state, ) update_operator_states(topology) diff --git a/python/ray/data/_internal/execution/streaming_executor_state.py b/python/ray/data/_internal/execution/streaming_executor_state.py index 2ebfb83e956c..fc5d7341a3b4 100644 --- a/python/ray/data/_internal/execution/streaming_executor_state.py +++ b/python/ray/data/_internal/execution/streaming_executor_state.py @@ -12,9 +12,7 @@ from typing import Dict, List, Optional, Tuple import ray -from ray.data._internal.execution.autoscaling_requester import ( - get_or_create_autoscaling_requester_actor, -) +from ray.data._internal.execution.autoscaler import Autoscaler from ray.data._internal.execution.backpressure_policy import BackpressurePolicy from ray.data._internal.execution.interfaces import ( ExecutionOptions, @@ -42,18 +40,6 @@ # operator to tracked streaming exec state. Topology = Dict[PhysicalOperator, "OpState"] -# Min number of seconds between two autoscaling requests. -MIN_GAP_BETWEEN_AUTOSCALING_REQUESTS = 20 - - -@dataclass -class AutoscalingState: - """State of the interaction between an executor and Ray autoscaler.""" - - # The timestamp of the latest resource request made to Ray autoscaler - # by an executor. - last_request_ts: int = 0 - class OpBufferQueue: """A FIFO queue to buffer RefBundles between upstream and downstream operators. @@ -156,6 +142,26 @@ def clear(self): self._num_per_split.clear() +@dataclass +class OpSchedulingStatus: + """The scheduling status of an operator. + + This will be updated each time when StreamingExecutor makes + a scheduling decision, i.e., in each `select_operator_to_run` + call. + """ + + # Whether the op was selected to run in the last scheduling + # decision. + selected: bool = False + # Whether the op was considered runnable in the last scheduling + # decision. + runnable: bool = False + # Whether the resources were sufficient for the operator to run + # in the last scheduling decision. + under_resource_limits: bool = False + + class OpState: """The execution state tracked for each PhysicalOperator. 
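OpSchedulingStatus is the hand-off point between scheduling and autoscaling: select_operator_to_run records it on each step, and the default autoscaler only asks the cluster for more resources when no op was runnable but input is still queued. A condensed restatement of that gate with a hypothetical helper:

def should_request_cluster_resources(runnable_flags, queued_counts) -> bool:
    # Mirrors the gating in DefaultAutoscaler._try_scale_up_cluster(): nothing is
    # runnable under current limits, yet at least one operator has queued input.
    no_runnable_op = not any(runnable_flags)
    any_has_input = any(n > 0 for n in queued_counts)
    return no_runnable_op and any_has_input


assert should_request_cluster_resources([False, False], [0, 3])
assert not should_request_cluster_resources([True, False], [0, 3])
assert not should_request_cluster_resources([False, False], [0, 0])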
@@ -186,6 +192,7 @@ def __init__(self, op: PhysicalOperator, inqueues: List[OpBufferQueue]): # Used for StreamingExecutor to signal exception or end of execution self._finished: bool = False self._exception: Optional[Exception] = None + self._scheduling_status = OpSchedulingStatus() def __repr__(self): return f"OpState({self.op.name})" @@ -497,9 +504,8 @@ def select_operator_to_run( topology: Topology, resource_manager: ResourceManager, backpressure_policies: List[BackpressurePolicy], + autoscaler: Autoscaler, ensure_at_least_one_running: bool, - execution_id: str, - autoscaling_state: AutoscalingState, ) -> Optional[PhysicalOperator]: """Select an operator to run, if possible. @@ -526,6 +532,7 @@ def select_operator_to_run( in_backpressure = not under_resource_limits or any( not p.can_add_input(op) for p in backpressure_policies ) + op_runnable = False if ( not in_backpressure and not op.completed() @@ -533,21 +540,15 @@ def select_operator_to_run( and op.should_add_input() ): ops.append(op) + op_runnable = True + state._scheduling_status = OpSchedulingStatus( + selected=False, + runnable=op_runnable, + under_resource_limits=under_resource_limits, + ) + # Signal whether op in backpressure for stats collections op.notify_in_task_submission_backpressure(in_backpressure) - # Update the op in all cases to enable internal autoscaling, etc. - op.notify_resource_usage(state.num_queued(), under_resource_limits) - - # If no ops are allowed to execute due to resource constraints, try to trigger - # cluster scale-up. - if not ops and any(state.num_queued() > 0 for state in topology.values()): - now = time.time() - if ( - now - > autoscaling_state.last_request_ts + MIN_GAP_BETWEEN_AUTOSCALING_REQUESTS - ): - autoscaling_state.last_request_ts = now - _try_to_scale_up_cluster(topology, execution_id) # To ensure liveness, allow at least 1 op to run regardless of limits. This is # gated on `ensure_at_least_one_running`, which is set if the consumer is blocked. @@ -569,58 +570,16 @@ def select_operator_to_run( # Run metadata-only operators first. After that, choose the operator with the least # memory usage. - return min( + selected_op = min( ops, key=lambda op: ( not op.throttling_disabled(), resource_manager.get_op_usage(op).object_store_memory, ), ) - - -def _try_to_scale_up_cluster(topology: Topology, execution_id: str): - """Try to scale up the cluster to accomodate the provided in-progress workload. - - This makes a resource request to Ray's autoscaler consisting of the current, - aggregate usage of all operators in the DAG + the incremental usage of all operators - that are ready for dispatch (i.e. that have inputs queued). If the autoscaler were - to grant this resource request, it would allow us to dispatch one task for every - ready operator. - - Note that this resource request does not take the global resource limits or the - liveness policy into account; it only tries to make the existing resource usage + - one more task per ready operator feasible in the cluster. - - Args: - topology: The execution state of the in-progress workload for which we wish to - request more resources. - """ - # Get resource usage for all ops + additional resources needed to launch one more - # task for each ready op. 
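The selection rule itself is unchanged: metadata-only ops with throttling disabled win first, then the op with the least object store usage; what is new is that the chosen op is marked selected and autoscaler.try_trigger_scaling() runs on every scheduling step. A tiny illustration of the min() key used above, with made-up tuples standing in for operators:

# (name, throttling_disabled, object_store_memory_usage)
ops = [("limit", True, 0), ("map_a", False, 200), ("map_b", False, 50)]
selected = min(ops, key=lambda o: (not o[1], o[2]))
assert selected[0] == "limit"

# With no metadata-only op in the mix, the smallest memory user wins.
selected = min(ops[1:], key=lambda o: (not o[1], o[2]))
assert selected[0] == "map_b"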
- resource_request = [] - - def to_bundle(resource: ExecutionResources) -> Dict: - req = {} - if resource.cpu: - req["CPU"] = math.ceil(resource.cpu) - if resource.gpu: - req["GPU"] = math.ceil(resource.gpu) - return req - - for op, state in topology.items(): - per_task_resource = op.incremental_resource_usage() - task_bundle = to_bundle(per_task_resource) - resource_request.extend([task_bundle] * op.num_active_tasks()) - # Only include incremental resource usage for ops that are ready for - # dispatch. - if state.num_queued() > 0: - # TODO(Clark): Scale up more aggressively by adding incremental resource - # usage for more than one bundle in the queue for this op? - resource_request.append(task_bundle) - - # Make autoscaler resource request. - actor = get_or_create_autoscaling_requester_actor() - actor.request_resources.remote(resource_request, execution_id) + topology[selected_op]._scheduling_status.selected = True + autoscaler.try_trigger_scaling() + return selected_op def _execution_allowed(op: PhysicalOperator, resource_manager: ResourceManager) -> bool: diff --git a/python/ray/data/context.py b/python/ray/data/context.py index de1464e56e90..6c68595cf024 100644 --- a/python/ray/data/context.py +++ b/python/ray/data/context.py @@ -73,7 +73,9 @@ DEFAULT_TRACE_ALLOCATIONS = bool(int(os.environ.get("RAY_DATA_TRACE_ALLOCATIONS", "0"))) -DEFAULT_LOG_INTERNAL_STACK_TRACE_TO_STDOUT = False +DEFAULT_LOG_INTERNAL_STACK_TRACE_TO_STDOUT = env_bool( + "RAY_DATA_LOG_INTERNAL_STACK_TRACE_TO_STDOUT", False +) DEFAULT_USE_RAY_TQDM = bool(int(os.environ.get("RAY_TQDM", "1"))) diff --git a/python/ray/data/tests/test_actor_pool_map_operator.py b/python/ray/data/tests/test_actor_pool_map_operator.py index d094aba9b309..07eebfe0dccb 100644 --- a/python/ray/data/tests/test_actor_pool_map_operator.py +++ b/python/ray/data/tests/test_actor_pool_map_operator.py @@ -1,15 +1,12 @@ import collections import time +import unittest import pytest import ray from ray.data._internal.compute import ActorPoolStrategy -from ray.data._internal.execution.operators.actor_pool_map_operator import ( - AutoscalingConfig, - AutoscalingPolicy, - _ActorPool, -) +from ray.data._internal.execution.operators.actor_pool_map_operator import _ActorPool from ray.data._internal.execution.util import make_ref_bundles from ray.tests.conftest import * # noqa @@ -23,28 +20,72 @@ def get_location(self) -> str: return self.node_id -class TestActorPool: - def _add_ready_worker(self, pool: _ActorPool) -> ray.actor.ActorHandle: - actor = PoolWorker.remote() +class TestActorPool(unittest.TestCase): + def setup_class(self): + self._last_created_actor_and_ready_ref = (None, None) + self._actor_node_id = "node1" + ray.init(num_cpus=4) + + def teardown_class(self): + ray.shutdown() + + def _create_actor_fn(self): + actor = PoolWorker.remote(self._actor_node_id) ready_ref = actor.get_location.remote() - pool.add_pending_actor(actor, ready_ref) - # Wait until actor has started. 
+ self._last_created_actor_and_ready_ref = (actor, ready_ref) + return actor, ready_ref + + def _create_actor_pool( + self, + min_size=1, + max_size=4, + max_tasks_in_flight=4, + ): + pool = _ActorPool( + compute_strategy=ActorPoolStrategy( + min_size=min_size, + max_size=max_size, + max_tasks_in_flight_per_actor=max_tasks_in_flight, + ), + create_actor_fn=self._create_actor_fn, + ) + return pool + + def _add_pending_actor(self, pool: _ActorPool, node_id="node1"): + self._actor_node_id = node_id + assert pool.scale_up(1) == 1 + actor, ready_ref = self._last_created_actor_and_ready_ref + return actor, ready_ref + + def _wait_for_actor_ready(self, pool: _ActorPool, ready_ref): ray.get(ready_ref) - # Mark actor as running. - has_actor = pool.pending_to_running(ready_ref) - assert has_actor + pool.pending_to_running(ready_ref) + + def _add_ready_actor(self, pool: _ActorPool, node_id="node1"): + actor, ready_ref = self._add_pending_actor(pool, node_id) + self._wait_for_actor_ready(pool, ready_ref) return actor - def test_add_pending(self, ray_start_regular_shared): + def test_basic_config(self): + pool = self._create_actor_pool( + min_size=1, + max_size=4, + max_tasks_in_flight=4, + ) + assert pool.min_size() == 1 + assert pool.max_size() == 4 + assert pool.current_size() == 0 + assert pool.max_tasks_in_flight_per_actor() == 4 + + def test_add_pending(self): # Test that pending actor is added in the correct state. - pool = _ActorPool() - actor = PoolWorker.remote() - ready_ref = actor.get_location.remote() - pool.add_pending_actor(actor, ready_ref) + pool = self._create_actor_pool() + _, ready_ref = self._add_pending_actor(pool) # Check that the pending actor is not pickable. + assert pool.pick_actor() is None # Check that the per-state pool sizes are as expected. - assert pool.num_total_actors() == 1 + assert pool.current_size() == 1 assert pool.num_pending_actors() == 1 assert pool.num_running_actors() == 0 assert pool.num_active_actors() == 0 @@ -53,33 +94,33 @@ def test_add_pending(self, ray_start_regular_shared): # Check that ready future is returned. assert pool.get_pending_actor_refs() == [ready_ref] - def test_pending_to_running(self, ray_start_regular_shared): + def test_pending_to_running(self): # Test that pending actor is correctly transitioned to running. - pool = _ActorPool() - actor = self._add_ready_worker(pool) + pool = self._create_actor_pool() + actor = self._add_ready_actor(pool) # Check that the actor is pickable. picked_actor = pool.pick_actor() assert picked_actor == actor # Check that the per-state pool sizes are as expected. - assert pool.num_total_actors() == 1 + assert pool.current_size() == 1 assert pool.num_pending_actors() == 0 assert pool.num_running_actors() == 1 assert pool.num_active_actors() == 1 assert pool.num_idle_actors() == 0 assert pool.num_free_slots() == 3 - def test_repeated_picking(self, ray_start_regular_shared): + def test_repeated_picking(self): # Test that we can repeatedly pick the same actor. - pool = _ActorPool(max_tasks_in_flight=999) - actor = self._add_ready_worker(pool) + pool = self._create_actor_pool(max_tasks_in_flight=999) + actor = self._add_ready_actor(pool) for _ in range(10): picked_actor = pool.pick_actor() assert picked_actor == actor - def test_return_actor(self, ray_start_regular_shared): + def test_return_actor(self): # Test that we can return an actor as many times as we've picked it. 
- pool = _ActorPool(max_tasks_in_flight=999) - self._add_ready_worker(pool) + pool = self._create_actor_pool(max_tasks_in_flight=999) + self._add_ready_actor(pool) for _ in range(10): picked_actor = pool.pick_actor() # Return the actor as many times as it was picked. @@ -90,17 +131,17 @@ def test_return_actor(self, ray_start_regular_shared): with pytest.raises(AssertionError): pool.return_actor(picked_actor) # Check that the per-state pool sizes are as expected. - assert pool.num_total_actors() == 1 + assert pool.current_size() == 1 assert pool.num_pending_actors() == 0 assert pool.num_running_actors() == 1 assert pool.num_active_actors() == 0 assert pool.num_idle_actors() == 1 # Actor should now be idle. assert pool.num_free_slots() == 999 - def test_pick_max_tasks_in_flight(self, ray_start_regular_shared): + def test_pick_max_tasks_in_flight(self): # Test that we can't pick an actor beyond the max_tasks_in_flight cap. - pool = _ActorPool(max_tasks_in_flight=2) - actor = self._add_ready_worker(pool) + pool = self._create_actor_pool(max_tasks_in_flight=2) + actor = self._add_ready_actor(pool) assert pool.num_free_slots() == 2 assert pool.pick_actor() == actor assert pool.num_free_slots() == 1 @@ -109,40 +150,40 @@ def test_pick_max_tasks_in_flight(self, ray_start_regular_shared): # Check that the 3rd pick doesn't return the actor. assert pool.pick_actor() is None - def test_pick_ordering_lone_idle(self, ray_start_regular_shared): + def test_pick_ordering_lone_idle(self): # Test that a lone idle actor is the one that's picked. - pool = _ActorPool() - self._add_ready_worker(pool) + pool = self._create_actor_pool() + self._add_ready_actor(pool) # Ensure that actor has been picked once. pool.pick_actor() # Add a new, idle actor. - actor2 = self._add_ready_worker(pool) + actor2 = self._add_ready_actor(pool) # Check that picked actor is the idle newly added actor. picked_actor = pool.pick_actor() assert picked_actor == actor2 - def test_pick_ordering_full_order(self, ray_start_regular_shared): + def test_pick_ordering_full_order(self): # Test that the least loaded actor is always picked. - pool = _ActorPool() + pool = self._create_actor_pool() # Add 4 actors to the pool. - actors = [self._add_ready_worker(pool) for _ in range(4)] + actors = [self._add_ready_actor(pool) for _ in range(4)] # Pick 4 actors. picked_actors = [pool.pick_actor() for _ in range(4)] # Check that the 4 distinct actors that were added to the pool were all # returned. assert set(picked_actors) == set(actors) # Check that the per-state pool sizes are as expected. - assert pool.num_total_actors() == 4 + assert pool.current_size() == 4 assert pool.num_pending_actors() == 0 assert pool.num_running_actors() == 4 assert pool.num_active_actors() == 4 assert pool.num_idle_actors() == 0 - def test_pick_all_max_tasks_in_flight(self, ray_start_regular_shared): + def test_pick_all_max_tasks_in_flight(self): # Test that max_tasks_in_flight cap applies to all actors in pool. - pool = _ActorPool(max_tasks_in_flight=2) + pool = self._create_actor_pool(max_tasks_in_flight=2) # Add 4 actors to the pool. - actors = [self._add_ready_worker(pool) for _ in range(4)] + actors = [self._add_ready_actor(pool) for _ in range(4)] picked_actors = [pool.pick_actor() for _ in range(8)] pick_counts = collections.Counter(picked_actors) # Check that picks were evenly distributed over the pool. 
@@ -150,14 +191,14 @@ def test_pick_all_max_tasks_in_flight(self, ray_start_regular_shared): for actor, count in pick_counts.items(): assert actor in actors assert count == 2 - # Check that the next pick doesn't return an ctor. + # Check that the next pick doesn't return an actor. assert pool.pick_actor() is None - def test_pick_ordering_with_returns(self, ray_start_regular_shared): + def test_pick_ordering_with_returns(self): # Test that pick ordering works with returns. - pool = _ActorPool() - actor1 = self._add_ready_worker(pool) - actor2 = self._add_ready_worker(pool) + pool = self._create_actor_pool() + actor1 = self._add_ready_actor(pool) + actor2 = self._add_ready_actor(pool) picked_actors = [pool.pick_actor() for _ in range(2)] # Double-check that both actors were picked. assert set(picked_actors) == {actor1, actor2} @@ -166,12 +207,10 @@ def test_pick_ordering_with_returns(self, ray_start_regular_shared): # Check that actor 2 is the next actor that's picked. assert pool.pick_actor() == actor2 - def test_kill_inactive_pending_actor(self, ray_start_regular_shared): + def test_kill_inactive_pending_actor(self): # Test that a pending actor is killed on the kill_inactive_actor() call. - pool = _ActorPool() - actor = PoolWorker.remote() - ready_ref = actor.get_location.remote() - pool.add_pending_actor(actor, ready_ref) + pool = self._create_actor_pool() + actor, _ = self._add_pending_actor(pool) # Kill inactive actor. killed = pool.kill_inactive_actor() # Check that an actor was killed. @@ -184,17 +223,17 @@ def test_kill_inactive_pending_actor(self, ray_start_regular_shared): with pytest.raises(ray.exceptions.RayActorError): ray.get(actor.get_location.remote()) # Check that the per-state pool sizes are as expected. - assert pool.num_total_actors() == 0 + assert pool.current_size() == 0 assert pool.num_pending_actors() == 0 assert pool.num_running_actors() == 0 assert pool.num_active_actors() == 0 assert pool.num_idle_actors() == 0 assert pool.num_free_slots() == 0 - def test_kill_inactive_idle_actor(self, ray_start_regular_shared): + def test_kill_inactive_idle_actor(self): # Test that a idle actor is killed on the kill_inactive_actor() call. - pool = _ActorPool() - actor = self._add_ready_worker(pool) + pool = self._create_actor_pool() + actor = self._add_ready_actor(pool) # Kill inactive actor. killed = pool.kill_inactive_actor() # Check that an actor was killed. @@ -207,17 +246,17 @@ def test_kill_inactive_idle_actor(self, ray_start_regular_shared): with pytest.raises(ray.exceptions.RayActorError): ray.get(actor.get_location.remote()) # Check that the per-state pool sizes are as expected. - assert pool.num_total_actors() == 0 + assert pool.current_size() == 0 assert pool.num_pending_actors() == 0 assert pool.num_running_actors() == 0 assert pool.num_active_actors() == 0 assert pool.num_idle_actors() == 0 assert pool.num_free_slots() == 0 - def test_kill_inactive_active_actor_not_killed(self, ray_start_regular_shared): + def test_kill_inactive_active_actor_not_killed(self): # Test that active actors are NOT killed on the kill_inactive_actor() call. - pool = _ActorPool() - actor = self._add_ready_worker(pool) + pool = self._create_actor_pool() + actor = self._add_ready_actor(pool) # Pick actor (and double-check that the actor was picked). assert pool.pick_actor() == actor # Kill inactive actor. @@ -227,16 +266,14 @@ def test_kill_inactive_active_actor_not_killed(self, ray_start_regular_shared): # Check that the active actor is still in the pool. 
assert pool.pick_actor() == actor - def test_kill_inactive_pending_over_idle(self, ray_start_regular_shared): + def test_kill_inactive_pending_over_idle(self): # Test that a killing pending actors is prioritized over killing idle actors on # the kill_inactive_actor() call. - pool = _ActorPool() + pool = self._create_actor_pool() # Add pending worker. - pending_actor = PoolWorker.remote() - ready_ref = pending_actor.get_location.remote() - pool.add_pending_actor(pending_actor, ready_ref) + pending_actor, _ = self._add_pending_actor(pool) # Add idle worker. - idle_actor = self._add_ready_worker(pool) + idle_actor = self._add_ready_actor(pool) # Kill inactive actor. killed = pool.kill_inactive_actor() # Check that an actor was killed. @@ -252,19 +289,17 @@ def test_kill_inactive_pending_over_idle(self, ray_start_regular_shared): with pytest.raises(ray.exceptions.RayActorError): ray.get(pending_actor.get_location.remote()) # Check that the per-state pool sizes are as expected. - assert pool.num_total_actors() == 1 + assert pool.current_size() == 1 assert pool.num_pending_actors() == 0 assert pool.num_running_actors() == 1 assert pool.num_active_actors() == 0 assert pool.num_idle_actors() == 1 assert pool.num_free_slots() == 4 - def test_kill_all_inactive_pending_actor_killed(self, ray_start_regular_shared): + def test_kill_all_inactive_pending_actor_killed(self): # Test that pending actors are killed on the kill_all_inactive_actors() call. - pool = _ActorPool() - actor = PoolWorker.remote() - ready_ref = actor.get_location.remote() - pool.add_pending_actor(actor, ready_ref) + pool = self._create_actor_pool() + actor, ready_ref = self._add_pending_actor(pool) # Kill inactive actors. pool.kill_all_inactive_actors() # Check that actor is not in pool. @@ -278,17 +313,17 @@ def test_kill_all_inactive_pending_actor_killed(self, ray_start_regular_shared): with pytest.raises(ray.exceptions.RayActorError): ray.get(actor.get_location.remote()) # Check that the per-state pool sizes are as expected. - assert pool.num_total_actors() == 0 + assert pool.current_size() == 0 assert pool.num_pending_actors() == 0 assert pool.num_running_actors() == 0 assert pool.num_active_actors() == 0 assert pool.num_idle_actors() == 0 assert pool.num_free_slots() == 0 - def test_kill_all_inactive_idle_actor_killed(self, ray_start_regular_shared): + def test_kill_all_inactive_idle_actor_killed(self): # Test that idle actors are killed on the kill_all_inactive_actors() call. - pool = _ActorPool() - actor = self._add_ready_worker(pool) + pool = self._create_actor_pool() + actor = self._add_ready_actor(pool) # Kill inactive actors. pool.kill_all_inactive_actors() # Check that actor is not in pool. @@ -299,17 +334,17 @@ def test_kill_all_inactive_idle_actor_killed(self, ray_start_regular_shared): with pytest.raises(ray.exceptions.RayActorError): ray.get(actor.get_location.remote()) # Check that the per-state pool sizes are as expected. - assert pool.num_total_actors() == 0 + assert pool.current_size() == 0 assert pool.num_pending_actors() == 0 assert pool.num_running_actors() == 0 assert pool.num_active_actors() == 0 assert pool.num_idle_actors() == 0 assert pool.num_free_slots() == 0 - def test_kill_all_inactive_active_actor_not_killed(self, ray_start_regular_shared): + def test_kill_all_inactive_active_actor_not_killed(self): # Test that active actors are NOT killed on the kill_all_inactive_actors() call. 
- pool = _ActorPool() - actor = self._add_ready_worker(pool) + pool = self._create_actor_pool() + actor = self._add_ready_actor(pool) # Pick actor (and double-check that the actor was picked). assert pool.pick_actor() == actor # Kill inactive actors. @@ -317,13 +352,11 @@ def test_kill_all_inactive_active_actor_not_killed(self, ray_start_regular_share # Check that the active actor is still in the pool. assert pool.pick_actor() == actor - def test_kill_all_inactive_future_idle_actors_killed( - self, ray_start_regular_shared - ): + def test_kill_all_inactive_future_idle_actors_killed(self): # Test that future idle actors are killed after the kill_all_inactive_actors() # call. - pool = _ActorPool() - actor = self._add_ready_worker(pool) + pool = self._create_actor_pool() + actor = self._add_ready_actor(pool) # Pick actor (and double-check that the actor was picked). assert pool.pick_actor() == actor # Kill inactive actors, of which there are currently none. @@ -342,29 +375,27 @@ def test_kill_all_inactive_future_idle_actors_killed( with pytest.raises(ray.exceptions.RayActorError): ray.get(actor.get_location.remote()) # Check that the per-state pool sizes are as expected. - assert pool.num_total_actors() == 0 + assert pool.current_size() == 0 assert pool.num_pending_actors() == 0 assert pool.num_running_actors() == 0 assert pool.num_active_actors() == 0 assert pool.num_idle_actors() == 0 assert pool.num_free_slots() == 0 - def test_kill_all_inactive_mixture(self, ray_start_regular_shared): + def test_kill_all_inactive_mixture(self): # Test that in a mixture of pending, idle, and active actors, only the pending # and idle actors are killed on the kill_all_inactive_actors() call. - pool = _ActorPool() + pool = self._create_actor_pool() # Add active actor. - actor1 = self._add_ready_worker(pool) + actor1 = self._add_ready_actor(pool) # Pick actor (and double-check that the actor was picked). assert pool.pick_actor() == actor1 # Add idle actor. - self._add_ready_worker(pool) + self._add_ready_actor(pool) # Add pending actor. - actor3 = PoolWorker.remote() - ready_ref = actor3.get_location.remote() - pool.add_pending_actor(actor3, ready_ref) + actor3, ready_ref = self._add_pending_actor(pool) # Check that the per-state pool sizes are as expected. - assert pool.num_total_actors() == 3 + assert pool.current_size() == 3 assert pool.num_pending_actors() == 1 assert pool.num_running_actors() == 2 assert pool.num_active_actors() == 1 @@ -393,20 +424,20 @@ def test_kill_all_inactive_mixture(self, ray_start_regular_shared): with pytest.raises(ray.exceptions.RayActorError): ray.get(actor1.get_location.remote()) # Check that the per-state pool sizes are as expected. - assert pool.num_total_actors() == 0 + assert pool.current_size() == 0 assert pool.num_pending_actors() == 0 assert pool.num_running_actors() == 0 assert pool.num_active_actors() == 0 assert pool.num_idle_actors() == 0 assert pool.num_free_slots() == 0 - def test_all_actors_killed(self, ray_start_regular_shared): + def test_all_actors_killed(self): # Test that all actors are killed after the kill_all_actors() call. - pool = _ActorPool() - active_actor = self._add_ready_worker(pool) + pool = self._create_actor_pool() + active_actor = self._add_ready_actor(pool) # Pick actor (and double-check that the actor was picked). assert pool.pick_actor() == active_actor - idle_actor = self._add_ready_worker(pool) + idle_actor = self._add_ready_actor(pool) # Kill all actors, including active actors. pool.kill_all_actors() # Check that the pool is empty. 
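The hunks above and below lean on three test helpers (`_create_actor_pool`, `_add_ready_actor`, `_add_pending_actor`) whose definitions sit earlier in the test file and are not part of this diff. Judging from the inline code they replace (create a `PoolWorker`, register it as pending, then promote it once its readiness ref resolves), they presumably look roughly like the sketch below; the exact signatures and constructor wiring are inferred, not copied from the file.

# Hypothetical reconstruction of the helpers on the test class, inferred from the
# removed inline code in this diff; assumes the test module's existing imports
# (ray, _ActorPool, PoolWorker).
def _create_actor_pool(self, **kwargs):
    # Thin wrapper around the pool under test; the real helper may also wire up
    # a create-actor callback and min/max size limits.
    return _ActorPool(**kwargs)

def _add_pending_actor(self, pool, node_id="node1"):
    # Start a PoolWorker and register it with the pool as pending.
    actor = PoolWorker.remote(node_id=node_id)
    ready_ref = actor.get_location.remote()
    pool.add_pending_actor(actor, ready_ref)
    return actor, ready_ref

def _add_ready_actor(self, pool, node_id="node1"):
    # Add a pending actor, wait for it to start, and promote it to running.
    actor, ready_ref = self._add_pending_actor(pool, node_id=node_id)
    ray.get(ready_ref)
    pool.pending_to_running(ready_ref)
    return actor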
@@ -419,7 +450,7 @@ def test_all_actors_killed(self, ray_start_regular_shared): with pytest.raises(ray.exceptions.RayActorError): ray.get(active_actor.get_location.remote()) # Check that the per-state pool sizes are as expected. - assert pool.num_total_actors() == 0 + assert pool.current_size() == 0 assert pool.num_pending_actors() == 0 assert pool.num_running_actors() == 0 assert pool.num_active_actors() == 0 @@ -427,7 +458,7 @@ def test_all_actors_killed(self, ray_start_regular_shared): assert pool.num_free_slots() == 0 def test_locality_manager_actor_ranking(self): - pool = _ActorPool(max_tasks_in_flight=2) + pool = self._create_actor_pool(max_tasks_in_flight=2) # Setup bundle mocks. bundles = make_ref_bundles([[0] for _ in range(10)]) @@ -437,16 +468,8 @@ def test_locality_manager_actor_ranking(self): pool._get_location = lambda b: fake_loc_map[b] # Setup an actor on each node. - actor1 = PoolWorker.remote(node_id="node1") - ready_ref = actor1.get_location.remote() - pool.add_pending_actor(actor1, ready_ref) - ray.get(ready_ref) - pool.pending_to_running(ready_ref) - actor2 = PoolWorker.remote(node_id="node2") - ready_ref = actor2.get_location.remote() - pool.add_pending_actor(actor2, ready_ref) - ray.get(ready_ref) - pool.pending_to_running(ready_ref) + actor1 = self._add_ready_actor(pool, node_id="node1") + actor2 = self._add_ready_actor(pool, node_id="node2") # Actors on node1 should be preferred. res1 = pool.pick_actor(bundles[0]) @@ -463,7 +486,7 @@ def test_locality_manager_actor_ranking(self): assert res5 is None def test_locality_manager_busyness_ranking(self): - pool = _ActorPool(max_tasks_in_flight=2) + pool = self._create_actor_pool(max_tasks_in_flight=2) # Setup bundle mocks. bundles = make_ref_bundles([[0] for _ in range(10)]) @@ -474,16 +497,8 @@ def test_locality_manager_busyness_ranking(self): pool._get_location = lambda b: fake_loc_map[b] # Setup two actors on the same node. - actor1 = PoolWorker.remote(node_id="node1") - ready_ref = actor1.get_location.remote() - pool.add_pending_actor(actor1, ready_ref) - ray.get(ready_ref) - pool.pending_to_running(ready_ref) - actor2 = PoolWorker.remote(node_id="node1") - ready_ref = actor2.get_location.remote() - pool.add_pending_actor(actor2, ready_ref) - ray.get(ready_ref) - pool.pending_to_running(ready_ref) + actor1 = self._add_ready_actor(pool, node_id="node1") + actor2 = self._add_ready_actor(pool, node_id="node2") # Fake actor 2 as more busy. pool._num_tasks_in_flight[actor2] = 1 @@ -500,166 +515,36 @@ def test_locality_manager_busyness_ranking(self): assert res3 is None -class TestAutoscalingConfig: - def test_min_workers_validation(self): - # Test min_workers positivity validation. - with pytest.raises(ValueError): - AutoscalingConfig(min_workers=0, max_workers=2) - - def test_max_workers_validation(self): - # Test max_workers not being less than min_workers validation. - with pytest.raises(ValueError): - AutoscalingConfig(min_workers=3, max_workers=2) - - def test_max_tasks_in_flight_validation(self): - # Test max_tasks_in_flight positivity validation. - with pytest.raises(ValueError): - AutoscalingConfig(min_workers=1, max_workers=2, max_tasks_in_flight=0) - - def test_full_specification(self): - # Basic regression test for full specification. 
- config = AutoscalingConfig( - min_workers=2, - max_workers=100, - max_tasks_in_flight=3, - ready_to_total_workers_ratio=0.8, - idle_to_total_workers_ratio=0.25, - ) - assert config.min_workers == 2 - assert config.max_workers == 100 - assert config.max_tasks_in_flight == 3 - assert config.ready_to_total_workers_ratio == 0.8 - assert config.idle_to_total_workers_ratio == 0.25 - - def test_from_compute(self): - # Test that construction from ActorPoolStrategy works as expected. - compute = ActorPoolStrategy( - min_size=2, max_size=5, max_tasks_in_flight_per_actor=3 - ) - config = AutoscalingConfig.from_compute_strategy(compute) - assert config.min_workers == 2 - assert config.max_workers == 5 - assert config.max_tasks_in_flight == 3 - assert config.ready_to_total_workers_ratio == 0.8 - assert config.idle_to_total_workers_ratio == 0.5 - - -class TestAutoscalingPolicy: - def test_min_workers(self): - # Test that the autoscaling policy forwards the config's min_workers. - config = AutoscalingConfig(min_workers=1, max_workers=4) - policy = AutoscalingPolicy(config) - assert policy.min_workers == 1 - - def test_max_workers(self): - # Test that the autoscaling policy forwards the config's max_workers. - config = AutoscalingConfig(min_workers=1, max_workers=4) - policy = AutoscalingPolicy(config) - assert policy.max_workers == 4 - - def test_should_scale_up_over_min_workers(self): - config = AutoscalingConfig(min_workers=1, max_workers=4) - policy = AutoscalingPolicy(config) - num_total_workers = 0 - num_running_workers = 0 - # Should scale up since under pool min workers. - assert policy.should_scale_up(num_total_workers, num_running_workers) - - def test_should_scale_up_over_max_workers(self): - # Test that scale-up is blocked if the pool would go over the configured max - # workers. - config = AutoscalingConfig(min_workers=1, max_workers=4) - policy = AutoscalingPolicy(config) - num_total_workers = 4 - num_running_workers = 4 - # Shouldn't scale up due to pool max workers. - assert not policy.should_scale_up(num_total_workers, num_running_workers) - - num_total_workers = 3 - num_running_workers = 3 - # Should scale up since under pool max workers. - assert policy.should_scale_up(num_total_workers, num_running_workers) - - def test_should_scale_up_ready_to_total_ratio(self): - # Test that scale-up is blocked if under the ready workers to total workers - # ratio. - config = AutoscalingConfig( - min_workers=1, max_workers=4, ready_to_total_workers_ratio=0.5 - ) - policy = AutoscalingPolicy(config) - - num_total_workers = 2 - num_running_workers = 1 - # Shouldn't scale up due to being under ready workers to total workers ratio. - assert not policy.should_scale_up(num_total_workers, num_running_workers) - - num_total_workers = 3 - num_running_workers = 2 - # Shouldn scale up due to being over ready workers to total workers ratio. - assert policy.should_scale_up(num_total_workers, num_running_workers) - - def test_should_scale_down_min_workers(self): - # Test that scale-down is blocked if the pool would go under the configured min - # workers. - config = AutoscalingConfig(min_workers=2, max_workers=4) - policy = AutoscalingPolicy(config) - num_total_workers = 2 - num_idle_workers = 2 - # Shouldn't scale down due to pool min workers. - assert not policy.should_scale_down(num_total_workers, num_idle_workers) - - num_total_workers = 3 - num_idle_workers = 3 - # Should scale down since over pool min workers. 
- assert policy.should_scale_down(num_total_workers, num_idle_workers) - - def test_should_scale_down_idle_to_total_ratio(self): - # Test that scale-down is blocked if under the idle workers to total workers - # ratio. - config = AutoscalingConfig( - min_workers=1, max_workers=4, idle_to_total_workers_ratio=0.5 - ) - policy = AutoscalingPolicy(config) - num_total_workers = 4 - num_idle_workers = 1 - # Shouldn't scale down due to being under idle workers to total workers ratio. - assert not policy.should_scale_down(num_total_workers, num_idle_workers) - - num_total_workers = 4 - num_idle_workers = 3 - # Should scale down due to being over idle workers to total workers ratio. - assert policy.should_scale_down(num_total_workers, num_idle_workers) - - def test_start_actor_timeout(ray_start_regular_shared): - """Tests that ActorPoolMapOperator raises an exception on - timeout while waiting for actors.""" - - class UDFClass: - def __call__(self, x): - return x - - from ray.data._internal.execution.operators import actor_pool_map_operator - from ray.exceptions import GetTimeoutError - - original_timeout = actor_pool_map_operator.DEFAULT_WAIT_FOR_MIN_ACTORS_SEC - actor_pool_map_operator.DEFAULT_WAIT_FOR_MIN_ACTORS_SEC = 1 - - with pytest.raises( - GetTimeoutError, - match=( - "Timed out while starting actors. This may mean that the cluster " - "does not have enough resources for the requested actor pool." - ), - ): - # Specify an unachievable resource requirement to ensure - # we timeout while waiting for actors. - ray.data.range(10).map_batches( - UDFClass, - batch_size=1, - compute=ray.data.ActorPoolStrategy(size=5), - num_gpus=100, - ).take_all() - actor_pool_map_operator.DEFAULT_WAIT_FOR_MIN_ACTORS_SEC = original_timeout +def test_start_actor_timeout(ray_start_regular_shared): + """Tests that ActorPoolMapOperator raises an exception on + timeout while waiting for actors.""" + + class UDFClass: + def __call__(self, x): + return x + + from ray.data._internal.execution.operators import actor_pool_map_operator + from ray.exceptions import GetTimeoutError + + original_timeout = actor_pool_map_operator.DEFAULT_WAIT_FOR_MIN_ACTORS_SEC + actor_pool_map_operator.DEFAULT_WAIT_FOR_MIN_ACTORS_SEC = 1 + + with pytest.raises( + GetTimeoutError, + match=( + "Timed out while starting actors. This may mean that the cluster " + "does not have enough resources for the requested actor pool." + ), + ): + # Specify an unachievable resource requirement to ensure + # we timeout while waiting for actors. 
+ ray.data.range(10).map_batches( + UDFClass, + batch_size=1, + compute=ray.data.ActorPoolStrategy(size=5), + num_gpus=100, + ).take_all() + actor_pool_map_operator.DEFAULT_WAIT_FOR_MIN_ACTORS_SEC = original_timeout if __name__ == "__main__": diff --git a/python/ray/data/tests/test_autoscaler.py b/python/ray/data/tests/test_autoscaler.py new file mode 100644 index 000000000000..e65fbdf07efb --- /dev/null +++ b/python/ray/data/tests/test_autoscaler.py @@ -0,0 +1,189 @@ +from contextlib import contextmanager +from unittest.mock import MagicMock + +import pytest + +from ray.data import ExecutionResources +from ray.data._internal.execution.autoscaler.default_autoscaler import DefaultAutoscaler + + +def test_actor_pool_scaling(): + """Test `_actor_pool_should_scale_up` and `_actor_pool_should_scale_down` + in `DefaultAutoscaler`""" + + autoscaler = DefaultAutoscaler( + topology=MagicMock(), + resource_manager=MagicMock(), + execution_id="execution_id", + actor_pool_scaling_up_threshold=0.8, + actor_pool_scaling_down_threshold=0.5, + ) + + # Current actor pool utilization is 0.9, which is above the threshold. + actor_pool = MagicMock( + min_size=MagicMock(return_value=5), + max_size=MagicMock(return_value=15), + current_size=MagicMock(return_value=10), + num_active_actors=MagicMock(return_value=9), + num_free_task_slots=MagicMock(return_value=5), + ) + + op = MagicMock( + completed=MagicMock(return_value=False), + _inputs_complete=False, + internal_queue_size=MagicMock(return_value=1), + ) + op_state = MagicMock(num_queued=MagicMock(return_value=10)) + op_scheduling_status = MagicMock(under_resource_limits=False) + op_state._scheduling_status = op_scheduling_status + + @contextmanager + def patch(mock, attr, value, is_method=True): + original = getattr(mock, attr) + if is_method: + value = MagicMock(return_value=value) + setattr(mock, attr, value) + yield + setattr(mock, attr, original) + + # === Test scaling up === + + def assert_should_scale_up(expected): + nonlocal actor_pool, op, op_state + + assert ( + autoscaler._actor_pool_should_scale_up( + actor_pool=actor_pool, + op=op, + op_state=op_state, + ) + == expected + ) + + # Should scale up since the util is above the threshold. + assert autoscaler._calculate_actor_pool_util(actor_pool) == 0.9 + assert_should_scale_up(True) + + # Shouldn't scale up since the util is below the threshold. + with patch(actor_pool, "num_active_actors", 7): + assert autoscaler._calculate_actor_pool_util(actor_pool) == 0.7 + assert_should_scale_up(False) + + # Shouldn't scale up since we have reached the max size. + with patch(actor_pool, "current_size", 15): + assert_should_scale_up(False) + + # Should scale up since the pool is below the min size. + with patch(actor_pool, "current_size", 4): + assert_should_scale_up(True) + + # Shouldn't scale up if the op is completed, or + # if the op has no more inputs. + with patch(op, "completed", True): + assert_should_scale_up(False) + with patch(op, "_inputs_complete", True, is_method=False): + with patch(op, "internal_queue_size", 0): + assert_should_scale_up(False) + + # Shouldn't scale up since the op is under resource limits. + with patch( + op_scheduling_status, + "under_resource_limits", + True, + is_method=False, + ): + assert_should_scale_up(False) + + # Shouldn't scale up since the op has enough free slots for + # the existing inputs.
+ with patch(op_state, "num_queued", 5): + assert_should_scale_up(False) + + # === Test scaling down === + + def assert_should_scale_down(expected): + assert ( + autoscaler._actor_pool_should_scale_down( + actor_pool=actor_pool, + op=op, + ) + == expected + ) + + # Shouldn't scale down since the util above the threshold. + assert autoscaler._calculate_actor_pool_util(actor_pool) == 0.9 + assert_should_scale_down(False) + + # Should scale down since the util is below the threshold. + with patch(actor_pool, "num_active_actors", 4): + assert autoscaler._calculate_actor_pool_util(actor_pool) == 0.4 + assert_should_scale_down(True) + + # Should scale down since the pool is above the max size. + with patch(actor_pool, "current_size", 16): + assert_should_scale_down(True) + + # Shouldn't scale down since we have reached the min size. + with patch(actor_pool, "current_size", 5): + assert_should_scale_down(False) + + # Should scale down since if the op is completed, or + # the op has no more inputs. + with patch(op, "completed", True): + assert_should_scale_down(True) + with patch(op, "_inputs_complete", True, is_method=False): + with patch(op, "internal_queue_size", 0): + assert_should_scale_down(True) + + +def test_cluster_scaling(): + """Test `_try_scale_up_cluster` in `DefaultAutoscaler`""" + op1 = MagicMock( + input_dependencies=[], + incremental_resource_usage=MagicMock( + return_value=ExecutionResources(cpu=1, gpu=0, object_store_memory=0) + ), + num_active_tasks=MagicMock(return_value=1), + ) + op_state1 = MagicMock( + num_queued=MagicMock(return_value=0), + _scheduling_status=MagicMock( + runnable=False, + ), + ) + op2 = MagicMock( + input_dependencies=[op1], + incremental_resource_usage=MagicMock( + return_value=ExecutionResources(cpu=2, gpu=0, object_store_memory=0) + ), + num_active_tasks=MagicMock(return_value=1), + ) + op_state2 = MagicMock( + num_queued=MagicMock(return_value=1), + _scheduling_status=MagicMock( + runnable=False, + ), + ) + topology = { + op1: op_state1, + op2: op_state2, + } + + autoscaler = DefaultAutoscaler( + topology=topology, + resource_manager=MagicMock(), + execution_id="execution_id", + ) + + autoscaler._send_resource_request = MagicMock() + autoscaler._try_scale_up_cluster() + + autoscaler._send_resource_request.assert_called_once_with( + [{"CPU": 1}, {"CPU": 2}, {"CPU": 2}] + ) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/data/tests/test_executor_resource_management.py b/python/ray/data/tests/test_executor_resource_management.py index f95fe1b01549..589e904171c3 100644 --- a/python/ray/data/tests/test_executor_resource_management.py +++ b/python/ray/data/tests/test_executor_resource_management.py @@ -287,8 +287,8 @@ def test_actor_pool_resource_reporting(ray_start_10_cpus_shared, restore_data_co * data_context.target_max_block_size ) assert op.base_resource_usage() == ExecutionResources(cpu=2, gpu=0) - # All actors are idle (pending creation), therefore shouldn't need to scale up when - # submitting a new task, so incremental resource usage should be 0. + # `incremental_resource_usage` should always report 0 CPU and GPU, as + # it doesn't consider scaling-up. assert op.incremental_resource_usage() == ExecutionResources( cpu=0, gpu=0, object_store_memory=inc_obj_store_mem ) @@ -300,8 +300,6 @@ def test_actor_pool_resource_reporting(ray_start_10_cpus_shared, restore_data_co # Add inputs. 
for i in range(4): - # Pool is still idle while waiting for actors to start, so additional tasks - # shouldn't trigger scale-up, so incremental resource usage should still be 0. assert op.incremental_resource_usage() == ExecutionResources( cpu=0, gpu=0, object_store_memory=inc_obj_store_mem ) @@ -318,11 +316,6 @@ def test_actor_pool_resource_reporting(ray_start_10_cpus_shared, restore_data_co assert op.num_active_tasks() == 2 run_op_tasks_sync(op, only_existing=True) - # Now that both actors have started, a new task would trigger scale-up, so - inc_usage = op.incremental_resource_usage() - assert inc_usage.cpu == 1, inc_usage - assert inc_usage.gpu == 0, inc_usage - # Actors have now started and the pool is actively running tasks. assert op.current_processor_usage() == ExecutionResources(cpu=2, gpu=0) assert op.metrics.obj_store_mem_internal_inqueue == 0 @@ -341,7 +334,9 @@ def test_actor_pool_resource_reporting(ray_start_10_cpus_shared, restore_data_co # Wait until tasks are done. run_op_tasks_sync(op) - # Work is done and the pool has been scaled down. + # Work is done, scale down the actor pool. + for pool in op.get_autoscaling_actor_pools(): + pool.scale_down(pool.current_size()) assert op.current_processor_usage() == ExecutionResources(cpu=0, gpu=0) assert op.metrics.obj_store_mem_internal_inqueue == 0 assert op.metrics.obj_store_mem_internal_outqueue == pytest.approx( @@ -355,8 +350,9 @@ def test_actor_pool_resource_reporting(ray_start_10_cpus_shared, restore_data_co while op.has_next(): op.get_next() - # Work is done, pool has been scaled down, and outputs have been consumed. - assert op.current_processor_usage() == ExecutionResources(cpu=0, gpu=0) + # Work is done, scale down the actor pool, and outputs have been consumed. + for pool in op.get_autoscaling_actor_pools(): + pool.scale_down(pool.current_size()) assert op.metrics.obj_store_mem_internal_inqueue == 0 assert op.metrics.obj_store_mem_internal_outqueue == 0 assert op.metrics.obj_store_mem_pending_task_inputs == 0 @@ -382,8 +378,8 @@ def test_actor_pool_resource_reporting_with_bundling(ray_start_10_cpus_shared): * data_context.target_max_block_size ) assert op.base_resource_usage() == ExecutionResources(cpu=2, gpu=0) - # All actors are idle (pending creation), therefore shouldn't need to scale up when - # submitting a new task, so incremental resource usage should be 0. + # `incremental_resource_usage` should always report 0 CPU and GPU, as + # it doesn't consider scaling-up. assert op.incremental_resource_usage() == ExecutionResources( cpu=0, gpu=0, object_store_memory=inc_obj_store_mem ) @@ -395,8 +391,6 @@ def test_actor_pool_resource_reporting_with_bundling(ray_start_10_cpus_shared): # Add inputs. for i in range(4): - # Pool is still idle while waiting for actors to start, so additional tasks - # shouldn't trigger scale-up, so incremental resource usage should still be 0. assert op.incremental_resource_usage() == ExecutionResources( cpu=0, gpu=0, object_store_memory=inc_obj_store_mem ) @@ -420,11 +414,6 @@ def test_actor_pool_resource_reporting_with_bundling(ray_start_10_cpus_shared): assert op.num_active_tasks() == 2 run_op_tasks_sync(op, only_existing=True) - # Now that both actors have started, a new task would trigger scale-up, so - inc_usage = op.incremental_resource_usage() - assert inc_usage.cpu == 1, inc_usage - assert inc_usage.gpu == 0, inc_usage - # Actors have now started and the pool is actively running tasks. 
assert op.current_processor_usage() == ExecutionResources(cpu=2, gpu=0) @@ -434,8 +423,9 @@ def test_actor_pool_resource_reporting_with_bundling(ray_start_10_cpus_shared): # Wait until tasks are done. run_op_tasks_sync(op) - # Work is done and the pool has been scaled down. - assert op.current_processor_usage() == ExecutionResources(cpu=0, gpu=0) + # Work is done, scale down the actor pool. + for pool in op.get_autoscaling_actor_pools(): + pool.scale_down(pool.current_size()) assert op.metrics.obj_store_mem_internal_inqueue == 0 assert op.metrics.obj_store_mem_internal_outqueue == pytest.approx(6400, rel=0.5) assert op.metrics.obj_store_mem_pending_task_inputs == 0 @@ -445,7 +435,9 @@ def test_actor_pool_resource_reporting_with_bundling(ray_start_10_cpus_shared): while op.has_next(): op.get_next() - # Work is done, pool has been scaled down, and outputs have been consumed. + # Work is done, scale down the actor pool, and outputs have been consumed. + for pool in op.get_autoscaling_actor_pools(): + pool.scale_down(pool.current_size()) assert op.current_processor_usage() == ExecutionResources(cpu=0, gpu=0) assert op.metrics.obj_store_mem_internal_inqueue == 0 assert op.metrics.obj_store_mem_internal_outqueue == 0 diff --git a/python/ray/data/tests/test_operators.py b/python/ray/data/tests/test_operators.py index c962a50a6b56..b48d94c79fb7 100644 --- a/python/ray/data/tests/test_operators.py +++ b/python/ray/data/tests/test_operators.py @@ -132,69 +132,6 @@ def dummy_all_transform(bundles: List[RefBundle]): assert op2.num_outputs_total() == 100 -@pytest.mark.parametrize("use_actors", [False, True]) -def test_map_operator_bulk(ray_start_regular_shared, use_actors): - # Create with inputs. - input_op = InputDataBuffer( - make_ref_bundles([[np.ones(1024) * i] for i in range(100)]) - ) - compute_strategy = ActorPoolStrategy(size=1) if use_actors else TaskPoolStrategy() - op = MapOperator.create( - _mul2_map_data_prcessor, - input_op=input_op, - name="TestMapper", - compute_strategy=compute_strategy, - ) - - # Feed data and block on exec. - op.start(ExecutionOptions(actor_locality_enabled=False)) - if use_actors: - # Actor will be pending after starting the operator. - assert op.progress_str() == "0 actors (1 pending) [locality off]" - assert op.internal_queue_size() == 0 - i = 0 - while input_op.has_next(): - op.add_input(input_op.get_next(), 0) - i += 1 - if use_actors: - assert op.internal_queue_size() == i - else: - assert op.internal_queue_size() == 0 - op.all_inputs_done() - - tasks = op.get_active_tasks() - while tasks: - run_op_tasks_sync(op, only_existing=True) - tasks = op.get_active_tasks() - if use_actors and tasks: - # After actor is ready (first work ref resolved), actor will remain ready - # while there is work to do. - assert op.progress_str() == "1 actors [locality off]" - - assert op.internal_queue_size() == 0 - if use_actors: - # After all work is done, actor will have been killed to free up resources.. - assert op.progress_str() == "0 actors [locality off]" - else: - assert op.progress_str() == "" - - # Check we return transformed bundles in order. - assert not op.completed() - assert np.array_equal( - _take_outputs(op), [[np.ones(1024) * i * 2] for i in range(100)] - ) - assert op.completed() - - # Check dataset stats. - stats = op.get_stats() - assert "TestMapper" in stats, stats - assert len(stats["TestMapper"]) == 100, stats - - # Check memory stats. 
- metrics = op.metrics.as_dict() - assert metrics["obj_store_mem_freed"] == pytest.approx(832200, 0.5), metrics - - @pytest.mark.parametrize("use_actors", [False, True]) def test_map_operator_streamed(ray_start_regular_shared, use_actors): # Create with inputs. diff --git a/python/ray/data/tests/test_streaming_executor.py b/python/ray/data/tests/test_streaming_executor.py index 82883568e4f9..487e19880b30 100644 --- a/python/ray/data/tests/test_streaming_executor.py +++ b/python/ray/data/tests/test_streaming_executor.py @@ -6,7 +6,7 @@ import pytest import ray -from ray._private.test_utils import run_string_as_driver_nonblocking, wait_for_condition +from ray._private.test_utils import run_string_as_driver_nonblocking from ray.data._internal.execution.interfaces import ( ExecutionOptions, ExecutionResources, @@ -24,7 +24,6 @@ _validate_dag, ) from ray.data._internal.execution.streaming_executor_state import ( - AutoscalingState, OpBufferQueue, OpState, _execution_allowed, @@ -58,6 +57,10 @@ def mock_resource_manager( ) +def mock_autoscaler(): + return MagicMock() + + @ray.remote def sleep(): time.sleep(999) @@ -200,48 +203,32 @@ def test_select_operator_to_run(): side_effect=lambda op: ExecutionResources(0, 0, memory_usage[op]) ) - # Test empty. - assert ( - select_operator_to_run( - topo, resource_manager, [], True, "dummy", AutoscalingState() + def _select_op_to_run(): + nonlocal topo, resource_manager + + return select_operator_to_run( + topo, resource_manager, [], mock_autoscaler(), True ) - is None - ) + + # Test empty. + assert _select_op_to_run() is None # Test backpressure based on memory_usage of each operator. topo[o1].outqueue.append(make_ref_bundle("dummy1")) memory_usage[o1] += 1 - assert ( - select_operator_to_run( - topo, resource_manager, [], True, "dummy", AutoscalingState() - ) - == o2 - ) + assert _select_op_to_run() == o2 + topo[o1].outqueue.append(make_ref_bundle("dummy2")) memory_usage[o1] += 1 - assert ( - select_operator_to_run( - topo, resource_manager, [], True, "dummy", AutoscalingState() - ) - == o2 - ) + assert _select_op_to_run() == o2 + topo[o2].outqueue.append(make_ref_bundle("dummy3")) memory_usage[o2] += 1 - assert ( - select_operator_to_run( - topo, resource_manager, [], True, "dummy", AutoscalingState() - ) - == o3 - ) + assert _select_op_to_run() == o3 # Test prioritization of nothrottle ops. o2.throttling_disabled = MagicMock(return_value=True) - assert ( - select_operator_to_run( - topo, resource_manager, [], True, "dummy", AutoscalingState() - ) - == o2 - ) + assert _select_op_to_run() == o2 def test_dispatch_next_task(): @@ -404,160 +391,6 @@ def test_execution_allowed(): ) -@pytest.mark.skip( - reason="Temporarily disable to deflake rest of test suite. Started being flaky " - "after moving to civ2? Needs further investigation to confirm." 
-) -def test_resource_constrained_triggers_autoscaling(monkeypatch): - RESOURCE_REQUEST_TIMEOUT = 5 - monkeypatch.setattr( - ray.data._internal.execution.autoscaling_requester, - "RESOURCE_REQUEST_TIMEOUT", - RESOURCE_REQUEST_TIMEOUT, - ) - monkeypatch.setattr( - ray.data._internal.execution.autoscaling_requester, - "PURGE_INTERVAL", - RESOURCE_REQUEST_TIMEOUT, - ) - from ray.data._internal.execution.autoscaling_requester import ( - get_or_create_autoscaling_requester_actor, - ) - - ray.shutdown() - ray.init(num_cpus=3, num_gpus=1) - - def run_execution( - execution_id: str, incremental_cpu: int = 1, autoscaling_state=None - ): - if autoscaling_state is None: - autoscaling_state = AutoscalingState() - opt = ExecutionOptions() - inputs = make_ref_bundles([[x] for x in range(20)]) - o1 = InputDataBuffer(inputs) - o2 = MapOperator.create( - make_map_transformer(lambda block: [b * -1 for b in block]), - o1, - ) - o2.num_active_tasks = MagicMock(return_value=1) - o3 = MapOperator.create( - make_map_transformer(lambda block: [b * 2 for b in block]), - o2, - ) - o3.num_active_tasks = MagicMock(return_value=1) - o4 = MapOperator.create( - make_map_transformer(lambda block: [b * 3 for b in block]), - o3, - compute_strategy=ray.data.ActorPoolStrategy(min_size=1, max_size=2), - ray_remote_args={"num_gpus": incremental_cpu}, - ) - o4.num_active_tasks = MagicMock(return_value=1) - o4.incremental_resource_usage = MagicMock( - return_value=ExecutionResources(gpu=1) - ) - topo = build_streaming_topology(o4, opt)[0] - # Make sure only two operator's inqueues has data. - topo[o2].inqueues[0].append(make_ref_bundle("dummy")) - topo[o4].inqueues[0].append(make_ref_bundle("dummy")) - resource_manager = mock_resource_manager( - global_usage=ExecutionResources(cpu=2, gpu=1, object_store_memory=1000), - global_limits=ExecutionResources.for_limits( - cpu=2, gpu=1, object_store_memory=1000 - ), - ) - selected_op = select_operator_to_run( - topo, - resource_manager, - [], - True, - execution_id, - autoscaling_state, - ) - assert selected_op is None - for op in topo: - op.shutdown() - - test_timeout = 3 - ac = get_or_create_autoscaling_requester_actor() - ray.get(ac._test_set_timeout.remote(test_timeout)) - - run_execution("1") - assert ray.get(ac._aggregate_requests.remote()) == [ - {"CPU": 1}, - {"CPU": 1}, - {"CPU": 1}, - {"GPU": 1}, - {"GPU": 1}, - {"CPU": 1}, - ] - - # For the same execution_id, the later request overrides the previous one. - run_execution("1") - assert ray.get(ac._aggregate_requests.remote()) == [ - {"CPU": 1}, - {"CPU": 1}, - {"CPU": 1}, - {"GPU": 1}, - {"GPU": 1}, - {"CPU": 1}, - ] - - # Having another execution, so the resource bundles expanded. - run_execution("2") - assert ray.get(ac._aggregate_requests.remote()) == [ - {"CPU": 1}, - {"CPU": 1}, - {"CPU": 1}, - {"GPU": 1}, - {"GPU": 1}, - {"CPU": 1}, - {"CPU": 1}, - {"CPU": 1}, - {"GPU": 1}, - {"GPU": 1}, - ] - - # Requesting for existing execution again, so no change in resource bundles. - run_execution("1") - assert ray.get(ac._aggregate_requests.remote()) == [ - {"CPU": 1}, - {"CPU": 1}, - {"CPU": 1}, - {"GPU": 1}, - {"GPU": 1}, - {"CPU": 1}, - {"CPU": 1}, - {"CPU": 1}, - {"GPU": 1}, - {"GPU": 1}, - ] - - # After the timeout, all requests should have been purged. - time.sleep(test_timeout + 1) - ray.get(ac._purge.remote()) - assert ray.get(ac._aggregate_requests.remote()) == [] - - # Test throttling by sending 100 requests: only one request actually - # got sent to the actor. 
- autoscaling_state = AutoscalingState() - for i in range(5): - run_execution("1", 1, autoscaling_state) - assert ray.get(ac._aggregate_requests.remote()) == [ - {"CPU": 1}, - {"CPU": 1}, - {"CPU": 1}, - {"GPU": 1}, - {"GPU": 1}, - {"CPU": 1}, - ] - - # Test that the resource requests will be purged after the timeout. - wait_for_condition( - lambda: ray.get(ac._aggregate_requests.remote()) == [], - timeout=RESOURCE_REQUEST_TIMEOUT * 2, - ) - - def test_select_ops_ensure_at_least_one_live_operator(): opt = ExecutionOptions() inputs = make_ref_bundles([[x] for x in range(20)]) @@ -577,40 +410,18 @@ def test_select_ops_ensure_at_least_one_live_operator(): global_usage=ExecutionResources(cpu=1), global_limits=ExecutionResources.for_limits(cpu=1), ) - assert ( - select_operator_to_run( - topo, - resource_manager, - [], - True, - "dummy", - AutoscalingState(), + + def _select_op_to_run(ensure_at_least_one_running): + nonlocal topo, resource_manager + + return select_operator_to_run( + topo, resource_manager, [], mock_autoscaler(), ensure_at_least_one_running ) - is None - ) + + assert _select_op_to_run(True) is None o1.num_active_tasks = MagicMock(return_value=0) - assert ( - select_operator_to_run( - topo, - resource_manager, - [], - True, - "dummy", - AutoscalingState(), - ) - is o3 - ) - assert ( - select_operator_to_run( - topo, - resource_manager, - [], - False, - "dummy", - AutoscalingState(), - ) - is None - ) + assert _select_op_to_run(True) is o3 + assert _select_op_to_run(False) is None def test_configure_output_locality():