Fuse operations with different numbers of tasks #368
Conversation
```diff
@@ -47,6 +47,7 @@ def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False):
         dtype=dtype,
         keepdims=keepdims,
         use_new_impl=use_new_impl,
+        split_every=split_every,
```
Is this `split_every` argument a temporary implementation detail whilst we figure out fusing heuristics? It would be nice to relate the meaning of this argument back to the discussion in #284 (as I currently don't quite understand what it means).
The `split_every` argument is what we referred to as "fan-in" in #284 - the number of chunks read by one task doing the reduction step. It's the same in Dask and can be a dictionary indicating the number of chunks to read in each dimension.
I hope it is something that we can get better heuristics for (or at least good defaults) - possibly by measuring trade-offs like they did in the Primula paper, see #331.
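
To make the fan-in idea concrete, here is a small illustrative calculation (not from this PR; the chunk count of 500 is a hypothetical input, chosen so the result matches the 50, 5, 1 reduce rounds quoted later in this thread) showing how a scalar `split_every` determines the size of each reduce round:

```python
import math

def reduce_round_sizes(n_chunks, split_every):
    """Number of tasks in each successive reduce round when every task
    reads at most `split_every` input chunks (scalar fan-in)."""
    sizes = []
    while n_chunks > 1:
        n_chunks = math.ceil(n_chunks / split_every)
        sizes.append(n_chunks)
    return sizes

# 500 chunks along the reduced axis with split_every=10
# gives rounds of 50, 5, and 1 task(s).
print(reduce_round_sizes(500, 10))  # [50, 5, 1]
```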
This is very exciting!
I've generated a few plan visualizations to give an idea of what the code does.

Running this cut-down quadratic means example gives the following plan using the old optimization algorithm (for [...]):

*(plan image omitted; the original comment described the reduce rounds here)*

If we use the new optimization algorithm then it will fuse the first [...]:

```python
m = xp.mean(uv, axis=0, split_every=10, use_new_impl=True)
m.visualize(
    filename=tmp_path / "quad_means",
    optimize_function=multiple_inputs_optimize_dag,
)
```

*(plan image omitted)*

Notice that the number of tasks has gone down from 160 to 59, and the amount of intermediate data stored has gone down from 16.8GB to 7.7GB. The reduce rounds have 50, 5, and 1 task(s) - this is controlled by the `split_every` argument.

If we now fuse the first two operations (purple boxes with rounded corners) then we can make the plan even smaller. Note that in the code below we have to explicitly ask for this operation to be fused. This uses the logic from this PR, which is able to fuse operations with a different number of tasks. The first operation (marked `op-008`) [...]:

```python
m = xp.mean(uv, axis=0, split_every=10, use_new_impl=True)
m.visualize(
    filename=tmp_path / "quad_means",
    optimize_function=partial(multiple_inputs_optimize_dag, always_fuse=["op-008"]),
)
```

*(plan image omitted)*

There are now only two reduce rounds, with 5 and 1 task(s), and a total of only 8 tasks and 166.8MB of intermediate data - a significant saving.
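
For readers who want to try something similar, here is a minimal, hypothetical stand-in for the setup. The original example's array sizes, chunking, memory spec, and the import path for `multiple_inputs_optimize_dag` are all my assumptions, chosen so that the reduced axis has 500 chunks (matching the 50, 5, 1 reduce rounds above); the op id `op-008` is plan-specific and will differ for other plans.

```python
from functools import partial

import cubed
import cubed.array_api as xp
from cubed.core.optimization import multiple_inputs_optimize_dag  # assumed import path

# Assumed setup (not the PR's actual benchmark code): two 500x500x500 arrays
# chunked so there are 500 chunks along axis 0, the axis being reduced.
spec = cubed.Spec(allowed_mem="2GB")  # assumed memory budget
u = xp.ones((500, 500, 500), chunks=(1, 500, 500), spec=spec)
v = xp.ones((500, 500, 500), chunks=(1, 500, 500), spec=spec)
uv = u * v  # elementwise product

m = xp.mean(uv, axis=0, split_every=10, use_new_impl=True)
m.visualize(
    filename="quad_means",
    # always_fuse takes op ids from the plan; "op-008" is only an example here
    optimize_function=partial(multiple_inputs_optimize_dag, always_fuse=["op-008"]),
)
```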
Add tests to check num_input_blocks
The mypy failures suggested to me that [...]
Wow that's a big difference. Thanks for writing this out @tomwhite. I was about to try and estimate the projected speedup here (assuming IO dominates the execution time), but without being able to see the number of tasks in each stage it's not trivial to reason about. (Half the number of zarr stores to be written, but not necessarily twice as fast because [...])
It's definitely hard to reason about, even with the number of tasks in each stage. (BTW I have listed the number of tasks in each stage in the comment above.) The total number of stages is probably the biggest factor in how long the computation takes. I plan to run some benchmarks to measure the performance of these workloads. But I agree that it would be useful to put the number of tasks in the visualization. It does appear in the tooltip for nodes, but for some reason that doesn't appear when the SVG is embedded in another page.
Fixes #284
This adds the necessary logic to the fuse code to fuse two operations that have a different number of tasks, but fusion must be controlled manually using the `always_fuse` argument to the optimization functions. This is fine for experimenting with the optimizer, but we'll probably want to add better heuristics about when to fuse these types of operations automatically in the future.
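
As a thought experiment only (not part of this PR), an automatic heuristic might end up looking something like the sketch below, where fusion across differing task counts is allowed only when the implied fan-in stays small; the function name, signature, and threshold are all hypothetical.

```python
def should_fuse(predecessor_tasks: int, successor_tasks: int, max_fan_in: int = 10) -> bool:
    """Hypothetical rule of thumb: fuse an operation with its predecessor,
    even when their task counts differ, only if each fused task would read
    a modest number of predecessor outputs (akin to bounding split_every)."""
    if successor_tasks <= 0:
        return False
    fan_in = predecessor_tasks / successor_tasks
    return fan_in <= max_fan_in
```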