Implement ScalarLoop in torch backend #958
base: main
Conversation
@Ch0ronomato thanks for taking a stab, I left some comments above
Codecov Report

Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #958      +/-   ##
==========================================
+ Coverage   81.90%   81.96%   +0.06%
==========================================
  Files         182      182
  Lines       47879    47914      +35
  Branches     8617     8632      +15
==========================================
+ Hits        39214    39272      +58
+ Misses       6492     6474      -18
+ Partials     2173     2168       -5
            carry = update(*carry, *constants)
        return torch.stack(carry)

    return torch.compiler.disable(scalar_loop)
Can you do recursive=False?
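For reference, a minimal standalone sketch of the suggestion (the scalar_loop body here is a toy stand-in, not the PR's function):

import torch


def scalar_loop(steps, x):
    # Toy loop body standing in for the compiled inner update
    for _ in range(int(steps)):
        x = x + 1.0
    return x


# recursive=False makes dynamo skip compiling scalar_loop itself while still
# allowing functions it calls to be traced/compiled
scalar_loop = torch.compiler.disable(scalar_loop, recursive=False)

print(scalar_loop(3, torch.tensor(0.0)))  # tensor(3.)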
@ricardoV94 - these failures in the CI look a bit strange; I'll look into them before merging... hopefully they go away after merging main 😓
@ricardoV94 #1031 is blocking the Elemwise test - how do you want to proceed with this PR?
If we can't Elemwise it, there's not much point to the ScalarLoop. Maybe we need to loop manually instead of vmap for this Op.
I suspect it's in the right direction, but I need a bit more help to understand the new code, if you can provide it :)
torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise")


@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]])
I set this up so we can try different shapes, but I stuck with this one to get started. If you think we should add more, let me know.
    np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0]))

    expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0)
    expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count))
I'm bullish on itertools, but I think I saw a mention earlier that list comprehensions are preferred. I can refactor it if so.
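If it helps decide, here is a small self-contained comparison of the two spellings (the values below are stand-ins, not the test's real tensors):

from itertools import repeat, starmap
from unittest.mock import call

expected_args = (1.0, 1.0, 0.0)  # stand-in for the unbound expected tensors
n_calls = 4                      # stand-in for mock_inner_func.f.call_count

# itertools version, as in the test
itertools_calls = list(starmap(call, repeat(expected_args, n_calls)))

# equivalent list comprehension
comprehension_calls = [call(*expected_args) for _ in range(n_calls)]

assert itertools_calls == comprehension_calls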
from torch import is_tensor

if is_tensor(out):
    return out.cpu()
FYI, this will probably create a conflict when one of my other PRs gets merged.
            final_inputs[i] = list(layer)

    # make sure we still have the same number of things
    assert len(final_inputs) == len(shaped_inputs)
I can put these into the unit test if that's preferred now.
If the assert is executed every time at runtime, then yes, let's not do it here.
        torch.zeros(*input_shapes[-1])
    ]
    mock_inner_func = MagicMock()
    ret_value = torch.rand(2, 2).unbind(0)
Maybe rename to expected
    mock_inner_func.f.return_value = ret_value
    elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None)
    result = elemwise_fn(*args)
    for actual, expected in zip(ret_value, result):
These are backwards, FYI.
def elemwise_scalar_loop(base_fn, op, node, **kwargs):
    """
    ScalarLoop + Elemwise is too common
    to not work, but @1031, vmap won't allow it.
Include full link instead of @1031
    Elemwise._check_runtime_broadcast(node, inputs)
    shaped_inputs = torch.broadcast_tensors(*inputs)
    expected_size = shaped_inputs[0].numel()
    final_inputs = [s.clone() for s in shaped_inputs]
Why .clone()?
This might be unnecessary now. We need the original number of dimensions for the outer loop. I could just grab that count instead.
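A rough standalone sketch of that alternative with toy shapes (not the PR's code): record the rank up front and rely on unbind returning views, so no clone is needed:

from itertools import chain

import torch

shaped_inputs = torch.broadcast_tensors(torch.rand(5, 3, 2), torch.rand(3, 2))

# Remember the rank instead of cloning; unbind(0) returns views, so no copy is
# made just to peel dimensions off one at a time
outer_ndim = shaped_inputs[0].dim()
final_inputs = [list(s.unbind(0)) for s in shaped_inputs]
for _ in range(outer_ndim - 1):
    final_inputs = [
        list(chain.from_iterable(t.unbind(0) for t in tensors))
        for tensors in final_inputs
    ]

# every remaining element is a 0-d tensor, ready for the scalar inner function
assert all(t.dim() == 0 for tensors in final_inputs for t in tensors)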
    for _ in range(shaped_inputs[0].dim() - 1):
        for i, _ in enumerate(shaped_inputs):
            layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
            final_inputs[i] = list(layer)
What is more performant? Doing this nesting, or raveling all the inputs after broadcasting and doing a single unbind loop?
Either way, it doesn't avoid the explicit broadcasting copy, or does it?
Ahhhh, this is basically like ravel, you're right!
According to the torch docs, ravel only copies if needed, so there may be cases where no copying happens.
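A quick standalone check of that point: ravel returns a view when the input is already contiguous and only copies otherwise, e.g. for the expanded views that broadcasting produces:

import torch

a = torch.rand(3, 2)                 # contiguous tensor
b = a.unsqueeze(0).expand(4, 3, 2)   # broadcast-style expanded view

assert a.ravel().data_ptr() == a.data_ptr()  # no copy: same storage
assert b.ravel().data_ptr() != b.data_ptr()  # expanded view must be copied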
    assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
    res = [base_fn(*args) for args in zip(*final_inputs)]

    return [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))]
Will this reintroduce the original shape? Say, if the Elemwise of the ScalarLoop had output shape == (5, 3, 2)?
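A toy illustration of the concern (shapes and values made up): stacking the flattened per-element results gives a 1-D tensor, so the broadcast shape would have to be restored explicitly, e.g. with reshape:

import torch

out_shape = (5, 3, 2)
flat_results = [torch.tensor(float(i)) for i in range(5 * 3 * 2)]

stacked = torch.stack(flat_results)    # shape (30,), not (5, 3, 2)
restored = stacked.reshape(out_shape)  # original broadcast shape restored

assert stacked.shape == (30,)
assert restored.shape == out_shape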
        if len(node.outputs) == 2:
            return carry[0], done
        else:
            return carry, done
Does this work?
Suggested change:
-        if len(node.outputs) == 2:
-            return carry[0], done
-        else:
-            return carry, done
+        return *carry, done
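A quick standalone check of what the suggested star-unpacking return produces for a single and for multiple carries (toy values, Python 3.8+):

def step(carry, done):
    # mirrors the suggested `return *carry, done`
    return *carry, done

assert step((1.0,), True) == (1.0, True)              # single carry
assert step((1.0, 2.0), False) == (1.0, 2.0, False)   # multiple carries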
@@ -343,3 +380,44 @@ def test_pytorch_OpFromGraph():

    f = FunctionGraph([x, y, z], [out])
    compare_pytorch_and_py(f, [xv, yv, zv])


def test_ScalarLoop_Elemwise():
Since there's a special condition for one vs. multiple carries, please also test both kinds of loop, with single and multiple updates.
torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise")


@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]])
I don't like this test. Just use these shapes in the test above and let the numerical checks do their job.
Okay, sounds good. I made this to try to lock down the implementation a bit. I also added it to aid understanding; does the method make sense now?
    n_steps = pt.scalar("n_steps", dtype="int32")
    x0 = pt.vector("x0", dtype="float32")
Add a second carry, say of type tensor(shape=(7, 3, 1)), so it broadcasts with the vector x0.
This will make sure multiple carries are working and that we are getting outputs of the right shape.
Or just use the shapes you had in the test below, that's fine
How is unbind(0) different than
https://discuss.pytorch.org/t/the-purpose-of-unbind/98648 It's essentially the same, maybe faster.
But if we index in the loop after raveling we don't need all the slices in memory. This is looking like a custom Elemwise with explicit broadcasting:

bcasted_inputs = torch.broadcast_tensors(*inputs)
raveled_inputs = [inp.ravel() for inp in bcasted_inputs]

out_shape = bcasted_inputs[0].size()
out_size = out_shape.numel()
n_outputs = len(node.outputs)
raveled_outputs = [torch.empty(out_size, dtype=out.dtype) for out in node.outputs]

for i in range(out_size):
    core_outs = core_func(*(inp[i] for inp in raveled_inputs))
    if n_outputs == 1:
        raveled_outputs[0][i] = core_outs
    else:
        for o in range(n_outputs):
            raveled_outputs[o][i] = core_outs[o]

outputs = tuple(out.view(out_shape) for out in raveled_outputs)

if n_outputs == 1:
    return outputs[0]
else:
    return outputs

Also note that nothing is specific to ScalarLoop, so it can be a (non-performant) fallback for all sorts of Elemwise.
That looks great. I think we'll still need some dispatch logic to know what can't be vmap'd; do we want to keep the current method? How does your approach merge with #1032?
Yes, this can be a fallback only for registered Ops (and specifically only ScalarLoop for the time being).
If my suggestion works, it should be better than the nested unbind, unless torch is really weird.
Description

Adds ScalarLoop for pytorch. I implement it as a loop, as opposed to trying to vectorize it; let me know if I should take the vectorized approach instead.

Related Issue

Checklist

Type of change