diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index a7dd27f355..a0aa7dc31e 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -1,12 +1,13 @@ from __future__ import annotations import threading +import time import uuid from collections import defaultdict from dataclasses import dataclass from datetime import datetime from queue import Queue -from typing import TYPE_CHECKING, Any, Iterable, Iterator +from typing import TYPE_CHECKING, Any, Generator, Iterable, Iterator import pyarrow as pa from loguru import logger @@ -368,6 +369,28 @@ def get_meta(partition: Table) -> PartitionMetadata: return PartitionMetadata.from_table(partition) +def _ray_num_cpus_provider(ttl_seconds: int = 1) -> Generator[int, None, None]: + """Helper that gets the number of CPUs from Ray + + Used as a generator as it provides a guard against calling ray.cluster_resources() + more than once per `ttl_seconds`. + + Example: + >>> p = _ray_num_cpus_provider() + >>> next(p) + """ + last_checked_time = time.time() + last_num_cpus_queried = int(ray.cluster_resources()["CPU"]) + while True: + currtime = time.time() + if currtime - last_checked_time < ttl_seconds: + yield last_num_cpus_queried + else: + last_checked_time = currtime + last_num_cpus_queried = int(ray.cluster_resources()["CPU"]) + yield last_num_cpus_queried + + class Scheduler: def __init__(self, max_task_backlog: int | None) -> None: """ @@ -434,15 +457,11 @@ def _run_plan( # Get executable tasks from plan scheduler. tasks = plan_scheduler.to_partition_tasks(psets, is_ray_runner=True) - # Note: For autoscaling clusters, we will probably want to query cores dynamically. - # Keep in mind this call takes about 0.3ms. - cores = int(ray.cluster_resources()["CPU"]) - self.reserved_cores - - max_inflight_tasks = cores + self.max_task_backlog - inflight_tasks: dict[str, PartitionTask[ray.ObjectRef]] = dict() inflight_ref_to_task: dict[ray.ObjectRef, str] = dict() + num_cpus_provider = _ray_num_cpus_provider() + start = datetime.now() profile_filename = ( f"profile_RayRunner.run()_" @@ -456,6 +475,8 @@ def _run_plan( while True: # Loop: Dispatch (get tasks -> batch dispatch). tasks_to_dispatch: list[PartitionTask] = [] + cores: int = next(num_cpus_provider) - self.reserved_cores + max_inflight_tasks = cores + self.max_task_backlog dispatches_allowed = max_inflight_tasks - len(inflight_tasks) dispatches_allowed = min(cores, dispatches_allowed)