[Data] Link PhysicalOperator to its LogicalOperator #47986

Merged

python/ray/data/_internal/execution/interfaces/physical_operator.py
@@ -12,7 +12,7 @@
    ExecutionResources,
)
from ray.data._internal.execution.interfaces.op_runtime_metrics import OpRuntimeMetrics
-from ray.data._internal.logical.interfaces import Operator
+from ray.data._internal.logical.interfaces import LogicalOperator, Operator
from ray.data._internal.stats import StatsDict
from ray.data.context import DataContext

@@ -188,6 +188,9 @@ def __init__(
        self._estimated_num_output_bundles = None
        self._estimated_output_num_rows = None
        self._execution_completed = False
        # The LogicalOperator(s) which were translated to create this PhysicalOperator.
        # Set via `PhysicalOperator.set_logical_operators()`.
        self._logical_operators: List[LogicalOperator] = []

    def __reduce__(self):
        raise ValueError("Operator is not serializable.")
@@ -205,6 +208,12 @@ def output_dependencies(self) -> List["PhysicalOperator"]:
    def post_order_iter(self) -> Iterator["PhysicalOperator"]:
        return super().post_order_iter() # type: ignore

    def set_logical_operators(
        self,
        *logical_ops: LogicalOperator,
    ):
        self._logical_operators = list(logical_ops)

    @property
    def target_max_block_size(self) -> Optional[int]:
        """
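
For orientation, here is a minimal standalone sketch of the linkage pattern this file now carries. The `SketchLogicalOp` and `SketchPhysicalOp` classes are illustrative stand-ins invented for this example, not Ray classes:

```python
from typing import List


class SketchLogicalOp:
    """Illustrative stand-in for a logical operator."""

    def __init__(self, name: str):
        self.name = name


class SketchPhysicalOp:
    """Illustrative stand-in for a physical operator with the new linkage."""

    def __init__(self, name: str):
        self.name = name
        # Logical operator(s) this physical operator was translated from.
        self._logical_operators: List[SketchLogicalOp] = []

    def set_logical_operators(self, *logical_ops: SketchLogicalOp) -> None:
        self._logical_operators = list(logical_ops)


read = SketchLogicalOp("ReadParquet")
physical = SketchPhysicalOp("ReadParquet")
physical.set_logical_operators(read)
assert [lo.name for lo in physical._logical_operators] == ["ReadParquet"]
```

The real `PhysicalOperator` keeps the same invariant: `_logical_operators` stays empty until the planner or the fusion rule calls `set_logical_operators()`.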

python/ray/data/_internal/logical/rules/operator_fusion.py
@@ -325,6 +325,7 @@ def _get_fused_map_operator(
            ray_remote_args=ray_remote_args,
            ray_remote_args_fn=ray_remote_args_fn,
        )
        op.set_logical_operators(*up_op._logical_operators, *down_op._logical_operators)

        # Build a map logical operator to be used as a reference for further fusion.
        # TODO(Scott): This is hacky, remove this once we push fusion to be purely based
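
The single added line above is what keeps the linkage intact through fusion: the fused operator inherits the logical operators of both inputs, upstream first. A hedged sketch with throwaway stand-in objects (not Ray's `MapOperator`) shows the accumulation; the fused-chain test further down asserts exactly this ordering:

```python
class SketchOp:
    """Illustrative stand-in; not a Ray operator."""

    def __init__(self, logical_ops=()):
        self._logical_operators = list(logical_ops)

    def set_logical_operators(self, *logical_ops):
        self._logical_operators = list(logical_ops)


up_op = SketchOp(logical_ops=["Read"])
down_op = SketchOp(logical_ops=["MapBatches"])

# Mirror of the fused-operator call above: upstream logical ops come first.
fused = SketchOp()
fused.set_logical_operators(*up_op._logical_operators, *down_op._logical_operators)
assert fused._logical_operators == ["Read", "MapBatches"]
```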

python/ray/data/_internal/planner/planner.py (13 additions, 0 deletions)
@@ -127,5 +127,18 @@ def _plan(self, logical_op: LogicalOperator) -> PhysicalOperator:
f"Found unknown logical operator during planning: {logical_op}"
)

# Traverse up the DAG, and set the mapping from physical to logical operators.
# At this point, all physical operators without logical operators set
# must have been created by the current logical operator.
queue = [physical_op]
while queue:
curr_physical_op = queue.pop()
# Once we find an operator with a logical operator set, we can stop.
if curr_physical_op._logical_operators:
break

curr_physical_op.set_logical_operators(logical_op)
queue.extend(physical_op.input_dependencies)

self._physical_op_to_logical_op[physical_op] = logical_op
return physical_op
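
For clarity, here is a self-contained rendering of the traversal above, using a throwaway stand-in class rather than Ray's operators. Planning `logical_op` has just produced `physical_op`, so the walk links every physical operator that has no logical operators yet and stops at the first one that is already linked, since that one came from an earlier planning step:

```python
class SketchPhysicalOp:
    """Illustrative stand-in; not a Ray operator."""

    def __init__(self, inputs=()):
        self.input_dependencies = list(inputs)
        self._logical_operators = []

    def set_logical_operators(self, *logical_ops):
        self._logical_operators = list(logical_ops)


def link_to_logical_op(physical_op, logical_op):
    # Same loop shape as the planner code above.
    queue = [physical_op]
    while queue:
        curr_physical_op = queue.pop()
        # Once we find an operator with a logical operator set, we can stop.
        if curr_physical_op._logical_operators:
            break
        curr_physical_op.set_logical_operators(logical_op)
        queue.extend(physical_op.input_dependencies)


upstream = SketchPhysicalOp()
upstream.set_logical_operators("Read")  # linked by an earlier planning step
current = SketchPhysicalOp(inputs=[upstream])

link_to_logical_op(current, "MapBatches")
assert current._logical_operators == ["MapBatches"]
assert upstream._logical_operators == ["Read"]  # unchanged
```

This is why, in the tests below, both the `MapOperator` produced for a read and its upstream `InputDataBuffer` end up linked to the same logical operator.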

python/ray/data/tests/test_execution_optimizer.py (35 additions, 6 deletions)
@@ -111,6 +111,9 @@ def test_read_operator(ray_start_regular_shared):
physical_op.actual_target_max_block_size
== DataContext.get_current().target_max_block_size
)
# Check that the linked logical operator is the same as the input op.
assert physical_op._logical_operators == [op]
assert physical_op.input_dependencies[0]._logical_operators == [op]


def test_read_operator_emits_warning_for_large_read_tasks():
@@ -182,6 +185,9 @@ def test_from_operators(ray_start_regular_shared):
assert isinstance(physical_op, InputDataBuffer)
assert len(physical_op.input_dependencies) == 0

# Check that the linked logical operator is the same as the input op.
assert physical_op._logical_operators == [op]


def test_from_items_e2e(ray_start_regular_shared):
data = ["Hello", "World"]
@@ -253,6 +259,9 @@ def test_map_batches_operator(ray_start_regular_shared):
assert len(physical_op.input_dependencies) == 1
assert isinstance(physical_op.input_dependencies[0], MapOperator)

# Check that the linked logical operator is the same as the input op.
assert physical_op._logical_operators == [op]


def test_map_batches_e2e(ray_start_regular_shared):
ds = ray.data.range(5)
@@ -393,6 +402,9 @@ def test_random_shuffle_operator(ray_start_regular_shared):
== DataContext.get_current().target_shuffle_max_block_size
)

# Check that the linked logical operator is the same as the input op.
assert physical_op._logical_operators == [op]


def test_random_shuffle_e2e(ray_start_regular_shared, use_push_based_shuffle):
ds = ray.data.range(12, override_num_blocks=4)
@@ -430,6 +442,9 @@ def test_repartition_operator(ray_start_regular_shared, shuffle):
== DataContext.get_current().target_max_block_size
)

# Check that the linked logical operator is the same as the input op.
assert physical_op._logical_operators == [op]


@pytest.mark.parametrize(
"shuffle",
@@ -506,6 +521,9 @@ def test_union_operator(ray_start_regular_shared, preserve_order):
== DataContext.get_current().target_max_block_size
)

# Check that the linked logical operator is the same as the input op.
assert physical_op._logical_operators == [union_op]


@pytest.mark.parametrize("preserve_order", (True, False))
def test_union_e2e(ray_start_regular_shared, preserve_order):
@@ -578,22 +596,23 @@ def test_read_map_batches_operator_fusion(ray_start_regular_shared):
physical_op.actual_target_max_block_size
== DataContext.get_current().target_max_block_size
)
assert physical_op._logical_operators == [read_op, op]


def test_read_map_chain_operator_fusion(ray_start_regular_shared):
# Test that a chain of different map operators are fused.
planner = Planner()
read_op = get_parquet_read_logical_op(parallelism=1)
-op = MapRows(read_op, lambda x: x)
-op = MapBatches(op, lambda x: x)
-op = FlatMap(op, lambda x: x)
-op = Filter(op, lambda x: x)
-logical_plan = LogicalPlan(op)
+map1 = MapRows(read_op, lambda x: x)
+map2 = MapBatches(map1, lambda x: x)
+map3 = FlatMap(map2, lambda x: x)
+map4 = Filter(map3, lambda x: x)
+logical_plan = LogicalPlan(map4)
physical_plan = planner.plan(logical_plan)
physical_plan = PhysicalOptimizer().optimize(physical_plan)
physical_op = physical_plan.dag

-assert op.name == "Filter(<lambda>)"
+assert map4.name == "Filter(<lambda>)"
assert (
physical_op.name == "ReadParquet->Map(<lambda>)->MapBatches(<lambda>)"
"->FlatMap(<lambda>)->Filter(<lambda>)"
@@ -605,6 +624,7 @@ def test_read_map_chain_operator_fusion(ray_start_regular_shared):
physical_op.actual_target_max_block_size
== DataContext.get_current().target_max_block_size
)
assert physical_op._logical_operators == [read_op, map1, map2, map3, map4]


def test_read_map_batches_operator_fusion_compatible_remote_args(
@@ -1009,6 +1029,9 @@ def test_write_operator(ray_start_regular_shared, tmp_path):
assert len(physical_op.input_dependencies) == 1
assert isinstance(physical_op.input_dependencies[0], MapOperator)

# Check that the linked logical operator is the same as the input op.
assert physical_op._logical_operators == [op]


def test_sort_operator(
ray_start_regular_shared,
@@ -1105,6 +1128,9 @@ def test_aggregate_operator(ray_start_regular_shared):
== DataContext.get_current().target_shuffle_max_block_size
)

# Check that the linked logical operator is the same as the input op.
assert physical_op._logical_operators == [op]


def test_aggregate_e2e(ray_start_regular_shared, use_push_based_shuffle):
ds = ray.data.range(100, override_num_blocks=4)
@@ -1171,6 +1197,9 @@ def test_zip_operator(ray_start_regular_shared):
== DataContext.get_current().target_max_block_size
)

# Check that the linked logical operator is the same as the input op.
assert physical_op._logical_operators == [op]


@pytest.mark.parametrize(
"num_blocks1,num_blocks2",