[PERF] Remove calls to remote_len_partition (#1660)
This PR refactors `PartitionSet.len_of_partitions` to avoid using our
`remote_len_partition` Ray remote function, which has been observed to
cause problems when run on dataframes with large amounts of spilling.
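To illustrate the idea (a minimal sketch with hypothetical names, not the real Daft internals): the old path ran a Ray remote function per partition just to get its length, which forces spilled partitions back into memory; the new path sums row counts from metadata already held on the driver.

```python
# Hypothetical illustration of the change; `num_rows` on the metadata object is an
# assumption that mirrors PartitionMetadata in daft/runners/partitioning.py.
from dataclasses import dataclass


@dataclass
class FakePartitionMetadata:
    num_rows: int


@dataclass
class FakeMaterializedResult:
    rows: list[int]

    def metadata(self) -> FakePartitionMetadata:
        # Metadata is cheap to keep around; no partition data needs to be fetched.
        return FakePartitionMetadata(num_rows=len(self.rows))


def total_len(results: list[FakeMaterializedResult]) -> int:
    # New approach: length comes from metadata alone.
    # Old approach (roughly): ray.get([remote_len_partition.remote(ref) for ref in refs]),
    # which re-materializes every (possibly spilled) partition.
    return sum(r.metadata().num_rows for r in results)


print(total_len([FakeMaterializedResult([1, 2, 3]), FakeMaterializedResult([4, 5, 6, 7])]))  # 7
```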

A few refactors had to be performed to achieve this:

1. The `RayPartitionSet` was refactored to hold `RayMaterializedResult`
objects instead of raw `ray.ObjectRef[Table]` objects.
- This gives us access to the `.metadata()` method, which exposes the
length of each partition without touching the partition data.
- The underlying `ray.ObjectRef[Table]` is still available via the
`.partition()` method.
2. As part of (1), `PartitionSet.set_partition` had to be refactored to
take a `MaterializedResult` as input instead of a plain `PartitionT`.
3. On the execution end, we refactored the code mainly around
`MaterializedPhysicalPlan`, which now yields
`MaterializedResult[PartitionT]` instead of just `PartitionT` when
indicating "done" tasks (see the sketch below).
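Putting the three pieces together, a rough end-to-end sketch of the new shape (simplified stand-ins, not the real classes): a completed plan yields `MaterializedResult` objects, the runner stores them via `set_partition`, and lengths are answered from data held locally.

```python
# Simplified stand-ins that mirror the interfaces touched by this PR; the real
# implementations live in daft/runners/partitioning.py and daft/runners/pyrunner.py.
from __future__ import annotations

from typing import Iterator


class LocalMaterializedResult:
    """Toy MaterializedResult: the PyRunner flavor holds a Table, the Ray flavor an ObjectRef."""

    def __init__(self, table: list[int]) -> None:
        self._table = table

    def partition(self) -> list[int]:
        return self._table

    def num_rows(self) -> int:  # stands in for .metadata().num_rows
        return len(self._table)


class LocalPartitionSet:
    def __init__(self) -> None:
        self._partitions: dict[int, list[int]] = {}

    def set_partition(self, idx: int, part: LocalMaterializedResult) -> None:
        # Takes a MaterializedResult now (refactor 2); unwraps to the raw partition for storage.
        self._partitions[idx] = part.partition()

    def __len__(self) -> int:
        # No remote call needed; every partition's length is known locally.
        return sum(len(p) for p in self._partitions.values())


def materialized_plan() -> Iterator[LocalMaterializedResult]:
    # Refactor 3: "done" items are MaterializedResult objects, not bare partitions.
    yield LocalMaterializedResult([1, 2, 3])
    yield LocalMaterializedResult([4, 5])


pset = LocalPartitionSet()
for i, result in enumerate(materialized_plan()):
    pset.set_partition(i, result)
print(len(pset))  # 5
```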

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
jaychia and Jay Chia authored Nov 22, 2023
1 parent 3dc440c commit ff218e7
Showing 10 changed files with 114 additions and 115 deletions.
2 changes: 1 addition & 1 deletion daft/daft.pyi
@@ -934,7 +934,7 @@ class PhysicalPlanScheduler:
def num_partitions(self) -> int: ...
def to_partition_tasks(
self, psets: dict[str, list[PartitionT]], is_ray_runner: bool
) -> physical_plan.MaterializedPhysicalPlan: ...
) -> physical_plan.InProgressPhysicalPlan: ...

class LogicalPlanBuilder:
"""
5 changes: 3 additions & 2 deletions daft/dataframe/dataframe.py
@@ -198,8 +198,9 @@ def iter_partitions(self) -> Iterator[Union[Table, "RayObjectRef"]]:
else:
# Execute the dataframe in a streaming fashion.
context = get_context()
partitions_iter = context.runner().run_iter(self._builder)
yield from partitions_iter
results_iter = context.runner().run_iter(self._builder)
for result in results_iter:
yield result.partition()

@DataframePublicAPI
def __repr__(self) -> str:
51 changes: 12 additions & 39 deletions daft/execution/execution_step.py
@@ -4,7 +4,7 @@
import pathlib
import sys
from dataclasses import dataclass, field
from typing import Generic, TypeVar
from typing import Generic

if sys.version_info < (3, 8):
from typing_extensions import Protocol
@@ -26,14 +26,15 @@
from daft.logical.map_partition_ops import MapPartitionOp
from daft.logical.schema import Schema
from daft.runners.partitioning import (
MaterializedResult,
PartialPartitionMetadata,
PartitionMetadata,
PartitionT,
TableParseCSVOptions,
TableReadOptions,
)
from daft.table import Table, table_io

PartitionT = TypeVar("PartitionT")
ID_GEN = itertools.count()


@@ -175,28 +176,29 @@ def set_result(self, result: list[MaterializedResult[PartitionT]]) -> None:
def done(self) -> bool:
return self._result is not None

def result(self) -> MaterializedResult[PartitionT]:
assert self._result is not None, "Cannot call .result() on a PartitionTask that is not done"
return self._result

def cancel(self) -> None:
# Currently only implemented for Ray tasks.
if self._result is not None:
self._result.cancel()
if self.done():
self.result().cancel()

def partition(self) -> PartitionT:
"""Get the PartitionT resulting from running this PartitionTask."""
assert self._result is not None
return self._result.partition()
return self.result().partition()

def partition_metadata(self) -> PartitionMetadata:
"""Get the metadata of the result partition.
(Avoids retrieving the actual partition itself if possible.)
"""
assert self._result is not None
return self._result.metadata()
return self.result().metadata()

def vpartition(self) -> Table:
"""Get the raw vPartition of the result."""
assert self._result is not None
return self._result.vpartition()
return self.result().vpartition()

def __str__(self) -> str:
return super().__str__()
@@ -251,35 +253,6 @@ def __repr__(self) -> str:
return super().__str__()


class MaterializedResult(Protocol[PartitionT]):
"""A protocol for accessing the result partition of a PartitionTask.
Different Runners can fill in their own implementation here.
"""

def partition(self) -> PartitionT:
"""Get the partition of this result."""
...

def vpartition(self) -> Table:
"""Get the vPartition of this result."""
...

def metadata(self) -> PartitionMetadata:
"""Get the metadata of the partition in this result."""
...

def cancel(self) -> None:
"""If possible, cancel execution of this PartitionTask."""
...

def _noop(self, _: PartitionT) -> None:
"""Implement this as a no-op.
https://peps.python.org/pep-0544/#overriding-inferred-variance-of-protocol-classes
"""
...


class Instruction(Protocol):
"""An instruction is a function to run over a list of partitions.
13 changes: 8 additions & 5 deletions daft/execution/physical_plan.py
@@ -38,19 +38,22 @@
)
from daft.expressions import ExpressionsProjection
from daft.logical.schema import Schema
from daft.runners.partitioning import PartialPartitionMetadata
from daft.runners.partitioning import (
MaterializedResult,
PartialPartitionMetadata,
PartitionT,
)

logger = logging.getLogger(__name__)

PartitionT = TypeVar("PartitionT")
T = TypeVar("T")


# A PhysicalPlan that is still being built - may yield both PartitionTaskBuilders and PartitionTasks.
InProgressPhysicalPlan = Iterator[Union[None, PartitionTask[PartitionT], PartitionTaskBuilder[PartitionT]]]

# A PhysicalPlan that is complete and will only yield PartitionTasks or final PartitionTs.
MaterializedPhysicalPlan = Iterator[Union[None, PartitionTask[PartitionT], PartitionT]]
MaterializedPhysicalPlan = Iterator[Union[None, PartitionTask[PartitionT], MaterializedResult[PartitionT]]]


def _stage_id_counter():
Expand Down Expand Up @@ -108,7 +111,7 @@ def file_read(
for i in range(len(vpartition)):
file_read_step = PartitionTaskBuilder[PartitionT](
inputs=[done_task.partition()],
partial_metadatas=[done_task.partition_metadata()],
partial_metadatas=None, # Child's metadata doesn't really matter for a file read
).add_instruction(
instruction=execution_step.ReadFile(
index=i,
@@ -738,7 +741,7 @@ def materialize(
# Check if any inputs finished executing.
while len(materializations) > 0 and materializations[0].done():
done_task = materializations.popleft()
yield done_task.partition()
yield done_task.result()

# Materialize a single dependency.
try:
6 changes: 2 additions & 4 deletions daft/execution/rust_physical_plan_shim.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterator, TypeVar, cast
from typing import Iterator, cast

from daft.daft import (
FileFormat,
@@ -19,11 +19,9 @@
from daft.expressions import Expression, ExpressionsProjection
from daft.logical.map_partition_ops import MapPartitionOp
from daft.logical.schema import Schema
from daft.runners.partitioning import PartialPartitionMetadata
from daft.runners.partitioning import PartialPartitionMetadata, PartitionT
from daft.table import Table

PartitionT = TypeVar("PartitionT")


def scan_with_tasks(
scan_tasks: list[ScanTask],
40 changes: 35 additions & 5 deletions daft/runners/partitioning.py
@@ -92,6 +92,40 @@ def from_table(cls, table: Table) -> PartitionMetadata:
PartitionT = TypeVar("PartitionT")


class MaterializedResult(Generic[PartitionT]):
"""A protocol for accessing the result partition of a PartitionTask.
Different Runners can fill in their own implementation here.
"""

@abstractmethod
def partition(self) -> PartitionT:
"""Get the partition of this result."""
...

@abstractmethod
def vpartition(self) -> Table:
"""Get the vPartition of this result."""
...

@abstractmethod
def metadata(self) -> PartitionMetadata:
"""Get the metadata of the partition in this result."""
...

@abstractmethod
def cancel(self) -> None:
"""If possible, cancel execution of this PartitionTask."""
...

@abstractmethod
def _noop(self, _: PartitionT) -> None:
"""Implement this as a no-op.
https://peps.python.org/pep-0544/#overriding-inferred-variance-of-protocol-classes
"""
...


class PartitionSet(Generic[PartitionT]):
def _get_merged_vpartition(self) -> Table:
raise NotImplementedError()
@@ -126,7 +160,7 @@ def get_partition(self, idx: PartID) -> PartitionT:
raise NotImplementedError()

@abstractmethod
def set_partition(self, idx: PartID, part: PartitionT) -> None:
def set_partition(self, idx: PartID, part: MaterializedResult[PartitionT]) -> None:
raise NotImplementedError()

@abstractmethod
@@ -139,10 +173,6 @@ def has_partition(self, idx: PartID) -> bool:

@abstractmethod
def __len__(self) -> int:
return sum(self.len_of_partitions())

@abstractmethod
def len_of_partitions(self) -> list[int]:
raise NotImplementedError()

@abstractmethod
36 changes: 19 additions & 17 deletions daft/runners/pyrunner.py
@@ -16,13 +16,14 @@
StorageConfig,
)
from daft.execution import physical_plan
from daft.execution.execution_step import Instruction, MaterializedResult, PartitionTask
from daft.execution.execution_step import Instruction, PartitionTask
from daft.filesystem import glob_path_with_stats
from daft.internal.gpu import cuda_device_count
from daft.logical.builder import LogicalPlanBuilder
from daft.logical.schema import Schema
from daft.runners import runner_io
from daft.runners.partitioning import (
MaterializedResult,
PartID,
PartitionCacheEntry,
PartitionMetadata,
@@ -52,8 +53,8 @@ def _get_merged_vpartition(self) -> Table:
def get_partition(self, idx: PartID) -> Table:
return self._partitions[idx]

def set_partition(self, idx: PartID, part: Table) -> None:
self._partitions[idx] = part
def set_partition(self, idx: PartID, part: MaterializedResult[Table]) -> None:
self._partitions[idx] = part.partition()

def delete_partition(self, idx: PartID) -> None:
del self._partitions[idx]
@@ -62,11 +63,7 @@ def has_partition(self, idx: PartID) -> bool:
return idx in self._partitions

def __len__(self) -> int:
return sum(self.len_of_partitions())

def len_of_partitions(self) -> list[int]:
partition_ids = sorted(list(self._partitions.keys()))
return [len(self._partitions[pid]) for pid in partition_ids]
return sum([len(self._partitions[pid]) for pid in self._partitions])

def num_partitions(self) -> int:
return len(self._partitions)
@@ -119,11 +116,11 @@ def runner_io(self) -> PyRunnerIO:
return PyRunnerIO()

def run(self, builder: LogicalPlanBuilder) -> PartitionCacheEntry:
partitions = list(self.run_iter(builder))
results = list(self.run_iter(builder))

result_pset = LocalPartitionSet({})
for i, partition in enumerate(partitions):
result_pset.set_partition(i, partition)
for i, result in enumerate(results):
result_pset.set_partition(i, result)

pset_entry = self.put_partition_set_into_cache(result_pset)
return pset_entry
@@ -133,7 +130,7 @@ def run_iter(
builder: LogicalPlanBuilder,
# NOTE: PyRunner does not run any async execution, so it ignores `results_buffer_size` which is essentially 0
results_buffer_size: int | None = None,
) -> Iterator[Table]:
) -> Iterator[PyMaterializedResult]:
# Optimize the logical plan.
builder = builder.optimize()
# Finalize the logical plan and get a physical plan scheduler for translating the
@@ -147,13 +144,16 @@
# Get executable tasks from planner.
tasks = plan_scheduler.to_partition_tasks(psets, is_ray_runner=False)
with profiler("profile_PyRunner.run_{datetime.now().isoformat()}.json"):
partitions_gen = self._physical_plan_to_partitions(tasks)
yield from partitions_gen
results_gen = self._physical_plan_to_partitions(tasks)
yield from results_gen

def run_iter_tables(self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None) -> Iterator[Table]:
return self.run_iter(builder, results_buffer_size=results_buffer_size)
for result in self.run_iter(builder, results_buffer_size=results_buffer_size):
yield result.partition()

def _physical_plan_to_partitions(self, plan: physical_plan.MaterializedPhysicalPlan) -> Iterator[Table]:
def _physical_plan_to_partitions(
self, plan: physical_plan.MaterializedPhysicalPlan[Table]
) -> Iterator[PyMaterializedResult]:
inflight_tasks: dict[str, PartitionTask] = dict()
inflight_tasks_resources: dict[str, ResourceRequest] = dict()
future_to_task: dict[futures.Future, str] = dict()
@@ -171,7 +171,9 @@ def _physical_plan_to_partitions(self, plan: physical_plan.MaterializedPhysicalP
# Blocked on already dispatched tasks; await some tasks.
break

elif isinstance(next_step, Table):
elif isinstance(next_step, MaterializedResult):
assert isinstance(next_step, PyMaterializedResult)

# A final result.
yield next_step
next_step = next(plan)