Skip to content

Commit

Permalink
[FEAT] Add RayRunner actor pool execution
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Aug 16, 2024
1 parent 1b6def2 commit b8d5fed
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 2 deletions.
156 changes: 154 additions & 2 deletions daft/runners/ray_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
IOConfig,
PyDaftExecutionConfig,
ResourceRequest,
extract_partial_stateful_udf_py,
)
from daft.datatype import DataType
from daft.execution.execution_step import (
Expand All @@ -43,6 +44,7 @@
ReduceInstruction,
ScanWithTask,
SingleOutputPartitionTask,
StatefulUDFProject,
)
from daft.filesystem import glob_path_with_stats
from daft.runners import runner_io
Expand Down Expand Up @@ -444,6 +446,8 @@ def __init__(self, max_task_backlog: int | None, use_ray_tqdm: bool) -> None:
self.active_by_df: dict[str, bool] = dict()
self.results_buffer_size_by_df: dict[str, int | None] = dict()

self._actor_pools: dict[str, RayRoundRobinActorPool] = {}

self.use_ray_tqdm = use_ray_tqdm

def next(self, result_uuid: str) -> RayMaterializedResult | StopIteration:
Expand Down Expand Up @@ -501,6 +505,24 @@ def stop_plan(self, result_uuid: str) -> None:
del self.results_by_df[result_uuid]
del self.results_buffer_size_by_df[result_uuid]

def get_actor_pool(
self,
name: str,
resource_request: ResourceRequest,
num_actors: int,
projection: ExpressionsProjection,
execution_config: PyDaftExecutionConfig,
) -> str:
actor_pool = RayRoundRobinActorPool(name, num_actors, resource_request, projection, execution_config)
self._actor_pools[name] = actor_pool
self._actor_pools[name].setup()
return name

def teardown_actor_pool(self, name: str) -> None:
if name in self._actor_pools:
self._actor_pools[name].teardown()
del self._actor_pools[name]

def _run_plan(
self,
plan_scheduler: PhysicalPlanScheduler,
Expand Down Expand Up @@ -613,7 +635,12 @@ def place_in_queue(item):
break

for task in tasks_to_dispatch:
results = _build_partitions(daft_execution_config, task)
if task.actor_pool_id is None:
results = _build_partitions(daft_execution_config, task)
else:
actor_pool = self._actor_pools.get(task.actor_pool_id)
assert actor_pool is not None, "Ray actor pool must live for as long as the tasks."
results = _build_partitions_on_actor_pool(task, actor_pool)
logger.debug("%s -> %s", task, results)
inflight_tasks[task.id()] = task
for result in results:
Expand Down Expand Up @@ -743,6 +770,119 @@ def _build_partitions(
return partitions


def _build_partitions_on_actor_pool(
task: PartitionTask[ray.ObjectRef],
actor_pool: RayRoundRobinActorPool,
) -> list[ray.ObjectRef]:
"""Run a PartitionTask on an actor pool and return the resulting list of partitions."""
[metadatas_ref, *partitions] = actor_pool.submit(task.instructions, task.partial_metadatas, task.inputs)
metadatas_accessor = PartitionMetadataAccessor(metadatas_ref)
task.set_result(
[
RayMaterializedResult(
partition=partition,
metadatas=metadatas_accessor,
metadata_idx=i,
)
for i, partition in enumerate(partitions)
]
)
return partitions


@ray.remote
class DaftRayActor:
def __init__(self, daft_execution_config: PyDaftExecutionConfig, uninitialized_projection: ExpressionsProjection):
set_execution_config(daft_execution_config)

partial_stateful_udfs = {
name: psu
for expr in uninitialized_projection
for name, psu in extract_partial_stateful_udf_py(expr._expr).items()
}
logger.info("Initializing stateful UDFs: %s", ", ".join(partial_stateful_udfs.keys()))
self.initialized_stateful_udfs = {
name: partial_udf.func_cls() for name, partial_udf in partial_stateful_udfs.items()
}

@ray.method(num_returns=2)
def run(
self,
uninitialized_projection: ExpressionsProjection,
partial_metadatas: list[PartitionMetadata],
*inputs: MicroPartition,
) -> list[list[PartitionMetadata] | MicroPartition]:
assert len(inputs) == 1, "DaftRayActor can only process single partitions"
assert len(partial_metadatas) == 1, "DaftRayActor can only process single partitions (and single metadata)"
part = inputs[0]
partial = partial_metadatas[0]

# Bind the ExpressionsProjection to the initialized UDFs
initialized_projection = ExpressionsProjection(
[e._bind_stateful_udfs(self.initialized_stateful_udfs) for e in uninitialized_projection]
)
new_part = part.eval_expression_list(initialized_projection)

return [
[PartitionMetadata.from_table(new_part).merge_with_partial(partial)],
new_part,
]


class RayRoundRobinActorPool:
"""Naive implementation of an ActorPool that performs round-robin task submission to the actors"""

def __init__(
self,
pool_id: str,
num_actors: int,
resource_request: ResourceRequest,
projection: ExpressionsProjection,
execution_config: PyDaftExecutionConfig,
):
self._actors: list[DaftRayActor] | None = None
self._task_idx = 0

self._execution_config = execution_config
self._num_actors = num_actors
self._resource_request_per_actor = resource_request
self._id = pool_id
self._projection = projection

def setup(self) -> None:
self._actors = [
DaftRayActor.remote(self._execution_config, self._projection) # type: ignore
for _ in range(self._num_actors)
]

def teardown(self):
assert self._actors is not None, "Must have active Ray actors on teardown"

# Delete the actors in the old pool so Ray can tear them down
old_actors = self._actors
self._actors = None
del old_actors

def submit(
self, instruction_stack: list[Instruction], partial_metadatas: list[ray.ObjectRef], inputs: list[ray.ObjectRef]
) -> list[ray.ObjectRef]:
assert self._actors is not None, "Must have active Ray actors during submission"

assert (
len(instruction_stack) == 1
), "RayRoundRobinActorPool can only handle single StatefulUDFProject instructions"
instruction = instruction_stack[0]
assert isinstance(instruction, StatefulUDFProject)
projection = instruction.projection

# Determine which actor to schedule on in a round-robin fashion
idx = self._task_idx % self._num_actors
self._task_idx += 1
actor = self._actors[idx]

return actor.run.remote(projection, partial_metadatas, *inputs)


class RayRunner(Runner[ray.ObjectRef]):
def __init__(
self,
Expand Down Expand Up @@ -881,7 +1021,19 @@ def run_iter_tables(
def actor_pool_context(
self, name: str, resource_request: ResourceRequest, num_actors: PartID, projection: ExpressionsProjection
) -> Iterator[str]:
raise NotImplementedError("Actor pool for RayRunner not yet implemented")
execution_config = get_context().daft_execution_config
if self.ray_client_mode:
try:
yield ray.get(
self.scheduler_actor.get_actor_pool.remote(name, resource_request, num_actors, projection)
)
finally:
self.scheduler_actor.teardown_actor_pool.remote(name)
else:
try:
yield self.scheduler.get_actor_pool(name, resource_request, num_actors, projection, execution_config)
finally:
self.scheduler.teardown_actor_pool(name)

def _collect_into_cache(self, results_iter: Iterator[RayMaterializedResult]) -> PartitionCacheEntry:
result_pset = RayPartitionSet()
Expand Down
55 changes: 55 additions & 0 deletions tests/actor_pool/test_ray_actor_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import ray

import daft
from daft import DataType, ResourceRequest
from daft.daft import PyDaftExecutionConfig
from daft.execution.execution_step import StatefulUDFProject
from daft.expressions import ExpressionsProjection
from daft.runners.partitioning import PartialPartitionMetadata
from daft.runners.ray_runner import RayRoundRobinActorPool
from daft.table import MicroPartition


@daft.udf(return_dtype=DataType.int64())
class MyStatefulUDF:
def __init__(self):
self.state = 0

def __call__(self, x):
self.state += 1
return [i + self.state for i in x.to_pylist()]


def test_ray_actor_pool():
projection = ExpressionsProjection([MyStatefulUDF(daft.col("x"))])
pool = RayRoundRobinActorPool(
"my-pool", 1, ResourceRequest(num_cpus=1), projection, execution_config=PyDaftExecutionConfig.from_env()
)
initial_partition = ray.put(MicroPartition.from_pydict({"x": [1, 1, 1]}))
ppm = PartialPartitionMetadata(num_rows=None, size_bytes=None)
instr = StatefulUDFProject(projection=projection)
pool.setup()

result = pool.submit(instruction_stack=[instr], partial_metadatas=[ppm], inputs=[initial_partition])
[partial_metadata, result_data] = ray.get(result)
assert len(partial_metadata) == 1
pm = partial_metadata[0]
assert isinstance(pm, PartialPartitionMetadata)
assert pm.num_rows == 3
assert result_data.to_pydict() == {"x": [2, 2, 2]}

result = pool.submit(instruction_stack=[instr], partial_metadatas=[ppm], inputs=[initial_partition])
[partial_metadata, result_data] = ray.get(result)
assert len(partial_metadata) == 1
pm = partial_metadata[0]
assert isinstance(pm, PartialPartitionMetadata)
assert pm.num_rows == 3
assert result_data.to_pydict() == {"x": [3, 3, 3]}

result = pool.submit(instruction_stack=[instr], partial_metadatas=[ppm], inputs=[initial_partition])
[partial_metadata, result_data] = ray.get(result)
assert len(partial_metadata) == 1
pm = partial_metadata[0]
assert isinstance(pm, PartialPartitionMetadata)
assert pm.num_rows == 3
assert result_data.to_pydict() == {"x": [4, 4, 4]}

0 comments on commit b8d5fed

Please sign in to comment.