Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fuse operations with different numbers of tasks #368

Merged
merged 8 commits into from
Feb 5, 2024

Conversation

tomwhite
Copy link
Member

Fixes #284

This adds the necessary logic to the fuse code to fuse two operations that have a different number of tasks, but fusion must be controlled manually using the always_fuse argument to the optimization functions. This is fine for experimenting with the optimizer, but we'll probably want to add better heuristics about when to fuse these types of operations automatically in the future.

@@ -47,6 +47,7 @@ def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False):
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this split_every argument a temporary implementation detail whilst we figure out fusing heuristics? It would be nice to relate the meaning of this argument back to the discussion in #284 (as I currently don't quite understand what it means)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The split_every argument is what we referred to as "fan-in" in #284 - the number of chunks read by one task doing the reduction step. It's the same in Dask and can be a dictionary indicating the number of chunks to read in each dimension.

I hope it is something that we can get better heuristics for (or at least good defaults) - possibly by measuring trade offs like they did in the Primula paper, see #331.

@TomNicholas
Copy link
Member

This is very exciting!

@tomwhite
Copy link
Member Author

I've generated a few plan visualizations to give an idea of what the code does.

Running this cut-down quadratic means example

https://github.com/tomwhite/cubed/blob/1762d3c685e7a1333c2809d2d9e419ff6508c5ce/cubed/tests/test_core.py#L527-L537

gives the following plan using the old optimization algorithm (for t_length=500):

quad_means_current

The reduce rounds (for the mean operations) have 50, 4, and 1 task(s).

If we use the new optimization algorithm then it will fuse the first mean operation that has two inputs into one:

    m = xp.mean(uv, axis=0, split_every=10, use_new_impl=True)

    m.visualize(
        filename=tmp_path / "quad_means",
        optimize_function=multiple_inputs_optimize_dag
    )

quad_means_fuse_multiple_inputs

Notice that the number of tasks has gone down from 160 to 59, and the amount of intermediate data stored has gone down from 16.8GB to 7.7GB. The reduce rounds have 50, 5, and 1 task(s) - this is controlled by the split_every=10 argument.

If we now fuse the first two operations (purple boxes with rounded corners) then we get make the plan even smaller. Note in the code below we have to explicitly ask for this operation to be fused.

This is using the logic from this PR which is able to fuse operations with a different number of tasks. The first operation (marked __mul__) has 50 tasks, and the second (mean) has 5.

    m = xp.mean(uv, axis=0, split_every=10, use_new_impl=True)

    m.visualize(
        filename=tmp_path / "quad_means",
        optimize_function=partial(multiple_inputs_optimize_dag, always_fuse=["op-008"]),
    )

quad_means_fuse_diff_num_tasks

There are now only two reduce rounds: with 5 and 1 task(s). And a total of only 8 tasks, and 166.8MB of intermediate data - a significant saving.

Add tests to check num_input_blocks
@tomwhite
Copy link
Member Author

The mypy failures suggested to me that num_input_blocks belongs to BlockwiseSpec, not PrimitiveOperation, since it doesn't apply to a rechunk operation, for example. This change means that num_input_blocks also sits next to function_nargs, which it is related to, so that seems better.

@TomNicholas
Copy link
Member

TomNicholas commented Feb 1, 2024

Wow that's a big difference. Thanks for writing this out @tomwhite.

I was about to try and estimate the projected speedup here (assuming IO dominates the execution time), but without being able to see the number of tasks in each stage it's not trivial to reason about. (Half the number of zarr stores to be written, but not necessarily twice as fast because do array-008 and array-007 have the same amount of parallelism that they did before optimization?) Not seeing this information in the visualization also made coiled/feedback#271 (comment) less clear - perhaps this information should be included in the graph visualization itself?

@tomwhite
Copy link
Member Author

tomwhite commented Feb 1, 2024

It's definitely hard to reason about, even with the number of tasks in each stage. (BTW I have listed the number of tasks in each stage in the comment above.) The total number of stages is probably the biggest factor in how long the computation takes. I plan to run some benchmarks to measure the performance of these workloads.

But I agree that it would be useful to put the number of tasks in the visualization. It does appear in the tooltip for nodes, but for some reason that doesn't appear when the SVG is embedded in another page.

@tomwhite tomwhite merged commit 8831b94 into main Feb 5, 2024
7 checks passed
@tomwhite tomwhite deleted the fuse-different-num-tasks branch February 5, 2024 08:23
@tomwhite tomwhite mentioned this pull request Mar 11, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Fuse pipelines with different numbers of tasks
2 participants