Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Materialize arrays optimizer bugfix #564

Merged
merged 4 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cubed/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def compute(
optimize_graph=optimize_graph,
optimize_function=optimize_function,
resume=resume,
array_names=tuple(a.name for a in arrays),
spec=spec,
**kwargs,
)
Expand Down Expand Up @@ -335,6 +336,7 @@ def visualize(
optimize_graph=optimize_graph,
optimize_function=optimize_function,
show_hidden=show_hidden,
array_names=tuple(a.name for a in arrays),
)


Expand Down
98 changes: 64 additions & 34 deletions cubed/core/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
logger = logging.getLogger(__name__)


def simple_optimize_dag(dag):
def simple_optimize_dag(dag, array_names=None):
"""Apply map blocks fusion."""

# note there is no need to prune the dag, since the way it is built
Expand All @@ -36,8 +36,12 @@ def can_fuse(n):
if dag.in_degree(op2) != 1:
return False

# if input is used by another node then don't fuse
# if input is one of the arrays being computed then don't fuse
op2_input = next(dag.predecessors(op2))
if array_names is not None and op2_input in array_names:
return False

# if input is used by another node then don't fuse
if dag.out_degree(op2_input) != 1:
return False

Expand Down Expand Up @@ -99,6 +103,23 @@ def predecessor_ops(dag, name):
yield pre_list[0]


def predecessor_ops_and_arrays(dag, name):
# returns op predecessors, the arrays that they produce (only one since we don't support multiple outputs yet),
# and a flag indicating if the op can be fused with each predecessor, taking into account the number of dependents for the array
nodes = dict(dag.nodes(data=True))
for input in nodes[name]["primitive_op"].source_array_names:
pre_list = list(predecessors_unordered(dag, input))
assert len(pre_list) == 1 # each array is produced by a single op
pre = pre_list[0]
can_fuse = is_primitive_op(nodes[pre]) and out_degree_unique(dag, input) == 1
yield pre, input, can_fuse


def out_degree_unique(dag, name):
"""Returns number of unique out edges"""
return len(set(post for _, post in dag.out_edges(name)))


def is_primitive_op(node_dict):
"""Return True if a node is a primitive op"""
return "primitive_op" in node_dict
Expand Down Expand Up @@ -126,6 +147,7 @@ def can_fuse_predecessors(
dag,
name,
*,
array_names=None,
max_total_source_arrays=4,
max_total_num_input_blocks=None,
always_fuse=None,
Expand All @@ -142,10 +164,25 @@ def can_fuse_predecessors(
return False

# if no predecessor ops can be fused then there is nothing to fuse
if all(not is_primitive_op(nodes[pre]) for pre in predecessor_ops(dag, name)):
# (this may be because predecessor ops produce arrays with multiple dependents)
if all(not can_fuse for _, _, can_fuse in predecessor_ops_and_arrays(dag, name)):
logger.debug("can't fuse %s since no predecessor ops can be fused", name)
return False

# if a predecessor op produces one of the arrays being computed, then don't fuse
if array_names is not None:
predecessor_array_names = set(
array_name for _, array_name, _ in predecessor_ops_and_arrays(dag, name)
)
array_names_intersect = set(array_names) & predecessor_array_names
if len(array_names_intersect) > 0:
logger.debug(
"can't fuse %s since predecessor ops produce one or more arrays being computed %s",
name,
array_names_intersect,
)
return False

# if node is in never_fuse or always_fuse list then it overrides logic below
if never_fuse is not None and name in never_fuse:
logger.debug("can't fuse %s since it is in 'never_fuse'", name)
Expand All @@ -158,8 +195,8 @@ def can_fuse_predecessors(
# the fused function would be more than an allowed maximum, then don't fuse
if len(list(predecessor_ops(dag, name))) > 1:
total_source_arrays = sum(
num_source_arrays(dag, pre) if is_primitive_op(nodes[pre]) else 1
for pre in predecessor_ops(dag, name)
num_source_arrays(dag, pre) if can_fuse else 1
for pre, _, can_fuse in predecessor_ops_and_arrays(dag, name)
)
if total_source_arrays > max_total_source_arrays:
logger.debug(
Expand All @@ -172,8 +209,8 @@ def can_fuse_predecessors(

predecessor_primitive_ops = [
nodes[pre]["primitive_op"]
for pre in predecessor_ops(dag, name)
if is_primitive_op(nodes[pre])
for pre, _, can_fuse in predecessor_ops_and_arrays(dag, name)
if can_fuse
]
return can_fuse_multiple_primitive_ops(
name,
Expand All @@ -187,6 +224,7 @@ def fuse_predecessors(
dag,
name,
*,
array_names=None,
max_total_source_arrays=4,
max_total_num_input_blocks=None,
always_fuse=None,
Expand All @@ -198,6 +236,7 @@ def fuse_predecessors(
if not can_fuse_predecessors(
dag,
name,
array_names=array_names,
max_total_source_arrays=max_total_source_arrays,
max_total_num_input_blocks=max_total_num_input_blocks,
always_fuse=always_fuse,
Expand All @@ -211,8 +250,8 @@ def fuse_predecessors(

# if a predecessor has no primitive op then just use None
predecessor_primitive_ops = [
nodes[pre]["primitive_op"] if is_primitive_op(nodes[pre]) else None
for pre in predecessor_ops(dag, name)
nodes[pre]["primitive_op"] if can_fuse else None
for pre, _, can_fuse in predecessor_ops_and_arrays(dag, name)
]

fused_primitive_op = fuse_multiple(primitive_op, *predecessor_primitive_ops)
Expand All @@ -224,35 +263,23 @@ def fuse_predecessors(
fused_nodes[name]["pipeline"] = fused_primitive_op.pipeline

# re-wire dag to remove predecessor nodes that have been fused

# 1. update edges to change inputs
for input in predecessors_unordered(dag, name):
pre = next(predecessors_unordered(dag, input))
if not is_primitive_op(fused_nodes[pre]):
# if a predecessor is not fusable then don't change the edge
continue
fused_dag.remove_edge(input, name)
for pre in predecessor_ops(dag, name):
if not is_primitive_op(fused_nodes[pre]):
# if a predecessor is not fusable then don't change the edge
continue
for input in predecessors_unordered(dag, pre):
fused_dag.add_edge(input, name)

# 2. remove predecessor nodes with no successors
# (ones with successors are needed by other nodes)
for input in predecessors_unordered(dag, name):
if fused_dag.out_degree(input) == 0:
for pre in list(predecessors_unordered(fused_dag, input)):
for pre, input, can_fuse in predecessor_ops_and_arrays(dag, name):
if can_fuse:
# check if already removed for repeated arguments
if input in fused_dag:
fused_dag.remove_node(input)
if pre in fused_dag:
fused_dag.remove_node(pre)
fused_dag.remove_node(input)
for pre_input in predecessors_unordered(dag, pre):
fused_dag.add_edge(pre_input, name)

return fused_dag


def multiple_inputs_optimize_dag(
dag,
*,
array_names=None,
max_total_source_arrays=4,
max_total_num_input_blocks=None,
always_fuse=None,
Expand All @@ -265,6 +292,7 @@ def multiple_inputs_optimize_dag(
dag = fuse_predecessors(
dag,
name,
array_names=array_names,
max_total_source_arrays=max_total_source_arrays,
max_total_num_input_blocks=max_total_num_input_blocks,
always_fuse=always_fuse,
Expand All @@ -273,18 +301,20 @@ def multiple_inputs_optimize_dag(
return dag


def fuse_all_optimize_dag(dag):
def fuse_all_optimize_dag(dag, array_names=None):
"""Force all operations to be fused."""
dag = dag.copy()
always_fuse = [op for op in dag.nodes() if op.startswith("op-")]
return multiple_inputs_optimize_dag(dag, always_fuse=always_fuse)
return multiple_inputs_optimize_dag(
dag, array_names=array_names, always_fuse=always_fuse
)


def fuse_only_optimize_dag(dag, *, only_fuse=None):
def fuse_only_optimize_dag(dag, *, array_names=None, only_fuse=None):
"""Force only specified operations to be fused, all others will be left even if they are suitable for fusion."""
dag = dag.copy()
always_fuse = only_fuse
never_fuse = set(op for op in dag.nodes() if op.startswith("op-")) - set(only_fuse)
return multiple_inputs_optimize_dag(
dag, always_fuse=always_fuse, never_fuse=never_fuse
dag, array_names=array_names, always_fuse=always_fuse, never_fuse=never_fuse
)
16 changes: 11 additions & 5 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import uuid
from datetime import datetime
from functools import lru_cache
from typing import Callable, Optional
from typing import Callable, Optional, Tuple

import networkx as nx
import zarr
Expand Down Expand Up @@ -154,10 +154,11 @@ def arrays_to_plan(cls, *arrays):
def optimize(
self,
optimize_function: Optional[Callable[..., nx.MultiDiGraph]] = None,
array_names: Optional[Tuple[str]] = None,
):
if optimize_function is None:
optimize_function = multiple_inputs_optimize_dag
dag = optimize_function(self.dag)
dag = optimize_function(self.dag, array_names=array_names)
return Plan(dag)

def _create_lazy_zarr_arrays(self, dag):
Expand Down Expand Up @@ -243,8 +244,9 @@ def _finalize(
optimize_graph: bool = True,
optimize_function=None,
compile_function: Optional[Decorator] = None,
array_names=None,
) -> "FinalizedPlan":
dag = self.optimize(optimize_function).dag if optimize_graph else self.dag
dag = self.optimize(optimize_function, array_names).dag if optimize_graph else self.dag
# create a copy since _create_lazy_zarr_arrays mutates the dag
dag = dag.copy()
if callable(compile_function):
Expand All @@ -260,11 +262,12 @@ def execute(
optimize_function=None,
compile_function=None,
resume=None,
array_names=None,
spec=None,
**kwargs,
):
finalized_plan = self._finalize(
optimize_graph, optimize_function, compile_function
optimize_graph, optimize_function, compile_function, array_names=array_names
)
dag = finalized_plan.dag

Expand Down Expand Up @@ -293,8 +296,11 @@ def visualize(
optimize_graph=True,
optimize_function=None,
show_hidden=False,
array_names=None,
):
finalized_plan = self._finalize(optimize_graph, optimize_function)
finalized_plan = self._finalize(
optimize_graph, optimize_function, array_names=array_names
)
dag = finalized_plan.dag
dag = dag.copy() # make a copy since we mutate the DAG below

Expand Down
61 changes: 42 additions & 19 deletions cubed/tests/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,34 @@ def test_fusion(spec, opt_fn):
)


@pytest.mark.parametrize(
"opt_fn", [None, simple_optimize_dag, multiple_inputs_optimize_dag]
)
def test_fusion_compute_multiple(spec, opt_fn):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = xp.negative(a)
c = xp.astype(b, np.float32)
d = xp.negative(c)

# if we compute c and d then both have to be materialized
num_created_arrays = 2 # c, d
task_counter = TaskCounter()
cubed.visualize(c, d, optimize_function=opt_fn)
c_result, d_result = cubed.compute(
c, d, optimize_function=opt_fn, callbacks=[task_counter]
)
assert task_counter.value == num_created_arrays + 8

assert_array_equal(
c_result,
np.array([[-1, -2, -3], [-4, -5, -6], [-7, -8, -9]]).astype(np.float32),
)
assert_array_equal(
d_result,
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32),
)


@pytest.mark.parametrize(
"opt_fn", [None, simple_optimize_dag, multiple_inputs_optimize_dag]
)
Expand Down Expand Up @@ -170,7 +198,7 @@ def test_custom_optimize_function(spec):
< num_tasks_with_no_optimization
)

def custom_optimize_function(dag):
def custom_optimize_function(dag, array_names=None):
# leave DAG unchanged
return dag

Expand Down Expand Up @@ -448,9 +476,9 @@ def test_fuse_diamond(spec):
# from https://github.com/cubed-dev/cubed/issues/126
#
# a -> a
# | /|
# b b |
# /| \|
# | |
# b b
# /|
# c | d
# \|
# d
Expand All @@ -469,7 +497,7 @@ def test_fuse_mixed_levels_and_diamond(spec):
expected_fused_dag = create_dag()
add_placeholder_op(expected_fused_dag, (), (a,))
add_placeholder_op(expected_fused_dag, (a,), (b,))
add_placeholder_op(expected_fused_dag, (a, b), (d,))
add_placeholder_op(expected_fused_dag, (b, b), (d,))
optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert get_num_input_blocks(d.plan.dag, d.name) == (1, 1)
Expand Down Expand Up @@ -535,35 +563,30 @@ def test_fuse_repeated_argument(spec):
assert_array_equal(result, -2 * np.ones((2, 2)))


# other dependents
# other dependents - no optimization is made in this case (cf previously)
#
# a -> a
# | / \
# b c b
# / \ |
# c d d
# a
# |
# b
# / \
# c d
#
def test_fuse_other_dependents(spec):
a = xp.ones((2, 2), chunks=(2, 2), spec=spec)
b = xp.negative(a)
c = xp.negative(b)
d = xp.negative(b)

# only fuse c; leave d unfused
# try to fuse c; leave d unfused
opt_fn = fuse_one_level(c)

# note multi-arg forms of visualize and compute below
cubed.visualize(c, d, optimize_function=opt_fn)

# check structure of optimized dag
expected_fused_dag = create_dag()
add_placeholder_op(expected_fused_dag, (), (a,))
add_placeholder_op(expected_fused_dag, (a,), (b,))
add_placeholder_op(expected_fused_dag, (a,), (c,))
add_placeholder_op(expected_fused_dag, (b,), (d,))
# optimization does nothing
plan = arrays_to_plan(c, d)
optimized_dag = plan.optimize(optimize_function=opt_fn).dag
assert structurally_equivalent(optimized_dag, expected_fused_dag)
assert structurally_equivalent(optimized_dag, plan.dag)
assert get_num_input_blocks(c.plan.dag, c.name) == (1,)
assert get_num_input_blocks(optimized_dag, c.name) == (1,)

Expand Down
Loading