
Implement ScalarLoop in torch backend #958

Open · wants to merge 18 commits into main

Conversation

Ch0ronomato (Contributor)

Description

Adds ScalarLoop for PyTorch. I implemented it as a loop rather than trying to vectorize it; let me know if I should take the other approach instead.
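
For context, a minimal sketch of the loop-based translation described above (illustrative only, not the exact diff; names like `scalar_loop_sketch` and `n_carry` are made up, and the argument convention of step count, then carries, then constants is assumed from the other backends):

```python
import torch


def scalar_loop_sketch(update, n_carry):
    """Illustrative only: build a torch callable for a non-while ScalarLoop."""

    def scalar_loop(n_steps, *args):
        # Split the runtime arguments into the loop state and the constants.
        carry, constants = list(args[:n_carry]), args[n_carry:]
        for _ in range(int(n_steps)):
            # `update` is assumed to return a sequence of carry tensors.
            carry = update(*carry, *constants)
        return torch.stack(carry)

    # The data-dependent Python loop cannot be traced by torch.compile,
    # so the dispatched callable is excluded from compilation.
    return torch.compiler.disable(scalar_loop)
```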

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

pytensor/link/pytorch/dispatch/scalar.py: 5 review threads (outdated, resolved)
ricardoV94 added the enhancement (New feature or request) and torch (PyTorch backend) labels on Aug 3, 2024
ricardoV94 (Member)

@Ch0ronomato thanks for taking a stab, I left some comments above


codecov bot commented Aug 11, 2024

Codecov Report

Attention: Patch coverage is 88.46154% with 6 lines in your changes missing coverage. Please review.

Project coverage is 81.96%. Comparing base (a377c22) to head (920f5a4).
Report is 21 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytensor/link/pytorch/dispatch/scalar.py | 84.00% | 2 Missing and 2 partials ⚠️ |
| pytensor/link/pytorch/linker.py | 50.00% | 1 Missing and 1 partial ⚠️ |
Additional details and impacted files


@@            Coverage Diff             @@
##             main     #958      +/-   ##
==========================================
+ Coverage   81.90%   81.96%   +0.06%     
==========================================
  Files         182      182              
  Lines       47879    47914      +35     
  Branches     8617     8632      +15     
==========================================
+ Hits        39214    39272      +58     
+ Misses       6492     6474      -18     
+ Partials     2173     2168       -5     
| Files with missing lines | Coverage Δ |
|---|---|
| pytensor/link/pytorch/dispatch/elemwise.py | 74.13% <100.00%> (+5.38%) ⬆️ |
| pytensor/link/pytorch/linker.py | 91.66% <50.00%> (-8.34%) ⬇️ |
| pytensor/link/pytorch/dispatch/scalar.py | 72.91% <84.00%> (+12.04%) ⬆️ |

... and 17 files with indirect coverage changes

pytensor/link/pytorch/dispatch/scalar.py: 2 review threads (outdated, resolved)
ricardoV94 changed the title from "Add torch scalar loop" to "Implement ScalarLoop in torch backend" on Sep 1, 2024
    carry = update(*carry, *constants)
    return torch.stack(carry)

return torch.compiler.disable(scalar_loop)
ricardoV94 (Member)

Can you do recursive=False?
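
For reference, `torch.compiler.disable` takes a `recursive` flag; with `recursive=False` only the wrapped function itself is skipped by the compiler, while the callables it invokes (here, the compiled inner scalar `update`) stay eligible for compilation. The suggestion would look roughly like:

```python
return torch.compiler.disable(scalar_loop, recursive=False)
```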

Ch0ronomato (Contributor Author)

@ricardoV94 - these failures in the CI look a bit strange; I'll look into them before merging... hopefully they go away with merging main 😓

Ch0ronomato (Contributor Author)

@ricardoV94 #1031 is blocking the elemwise test; how do you want to proceed with this PR?

ricardoV94 (Member)

> @ricardoV94 #1031 is blocking the elemwise test; how do you want to proceed with this PR?

If we can't elemwise it, there's not much point to the ScalarLoop. Maybe we need to loop manually instead of using vmap for this Op.

ricardoV94 (Member) left a comment

I suspect it's in the right direction, but I need a bit more help understanding the new code if you can provide it :)

pytensor/link/pytorch/dispatch/elemwise.py: 2 review threads (resolved)
tests/link/pytorch/test_basic.py: 1 review thread (outdated, resolved)
torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise")


@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]])
Ch0ronomato (Contributor Author)

I set this up so we can try different shapes, but I stuck with this one to get started. If you think we should add more, let me know.

np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0]))

expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0)
expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count))
Ch0ronomato (Contributor Author)

I'm bullish on itertools, but I think I saw a mention earlier that list comprehensions are preferred. I can refactor it if so.

from torch import is_tensor

if is_tensor(out):
    return out.cpu()
Ch0ronomato (Contributor Author)

FYI, this will probably create a conflict when one of my other PRs gets merged.

final_inputs[i] = list(layer)

# make sure we still have the same number of things
assert len(final_inputs) == len(shaped_inputs)
Ch0ronomato (Contributor Author)

I can put these into the unit test if that's preferred now.

ricardoV94 (Member) commented Nov 12, 2024

If the assert is executed every time at runtime, then yes, let's not do it here.

torch.zeros(*input_shapes[-1])
]
mock_inner_func = MagicMock()
ret_value = torch.rand(2, 2).unbind(0)
Ch0ronomato (Contributor Author)

Maybe rename to expected

mock_inner_func.f.return_value = ret_value
elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None)
result = elemwise_fn(*args)
for actual, expected in zip(ret_value, result):
Ch0ronomato (Contributor Author)

These are backwards, FYI.
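
i.e., since `ret_value` is the mocked (expected) output and `result` is what `elemwise_fn` actually produced, the loop would read:

```python
for expected, actual in zip(ret_value, result):
```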

def elemwise_scalar_loop(base_fn, op, node, **kwargs):
"""
ScalarLoop + Elemwise is too common
to not work, but @1031, vmap won't allow it.
ricardoV94 (Member)

Include full link instead of @1031

Elemwise._check_runtime_broadcast(node, inputs)
shaped_inputs = torch.broadcast_tensors(*inputs)
expected_size = shaped_inputs[0].numel()
final_inputs = [s.clone() for s in shaped_inputs]
ricardoV94 (Member)

Why .clone()?

Ch0ronomato (Contributor Author)

This might be unnecessary now. We need the original number of dimensions for the outer loop. I could just grab that count instead.
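
i.e., instead of cloning the broadcast tensors, the loop could record the broadcast rank up front and keep views (a sketch reusing this function's names):

```python
out_ndim = shaped_inputs[0].dim()   # all the outer loop needs
final_inputs = list(shaped_inputs)  # unbind yields views, so no .clone() required
```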

Comment on lines +193 to +196
for _ in range(shaped_inputs[0].dim() - 1):
    for i, _ in enumerate(shaped_inputs):
        layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
        final_inputs[i] = list(layer)
ricardoV94 (Member) commented Nov 12, 2024

What is more performant: doing this nesting, or raveling all the inputs after broadcasting and doing a single unbind loop?

Either way, it doesn't avoid the explicit broadcasting copy, or does it?

Ch0ronomato (Contributor Author)

Ahhhhh, this is basically like ravel, you're right!

According to the torch docs, ravel only copies if needed, so there may be cases where no copying happens.
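
A quick way to see this (a sketch, assuming plain CPU tensors): `ravel` on an already-contiguous tensor returns a view sharing storage, while a non-contiguous input forces a copy.

```python
import torch

x = torch.arange(6).reshape(2, 3)            # contiguous
assert x.ravel().data_ptr() == x.data_ptr()  # view, no copy

y = x.t()                                    # non-contiguous transpose
assert y.ravel().data_ptr() != y.data_ptr()  # ravel has to copy here
```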

assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
res = [base_fn(*args) for args in zip(*final_inputs)]

return [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))]
ricardoV94 (Member)

Will this reintroduce the original shape? Say, if the Elemwise of the ScalarLoop had output shape == (5, 3, 2)?
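
For what it's worth, as written the stacked result stays flat (one dimension of length `expected_size`); restoring the Elemwise output shape would need an explicit reshape, something like (a sketch reusing this function's names):

```python
out_shape = shaped_inputs[0].shape
return [
    torch.stack(tuple(out[i] for out in res)).reshape(out_shape)
    for i in range(len(res[0]))
]
```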

Comment on lines +62 to +65
if len(node.outputs) == 2:
    return carry[0], done
else:
    return carry, done
ricardoV94 (Member)

Does this work?

Suggested change:
-if len(node.outputs) == 2:
-    return carry[0], done
-else:
-    return carry, done
+return *carry, done
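
For reference, `return *carry, done` is valid starred-expression syntax (Python 3.8+) and always produces a flat tuple, so it covers both branches; the difference is that the multi-carry case is flattened rather than returned as a nested tuple:

```python
def f(carry, done):
    return *carry, done

assert f((1,), True) == (1, True)       # single carry: same as (carry[0], done)
assert f((1, 2), True) == (1, 2, True)  # multiple carries: flat, not ((1, 2), True)
```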

@@ -343,3 +380,44 @@ def test_pytorch_OpFromGraph():

f = FunctionGraph([x, y, z], [out])
compare_pytorch_and_py(f, [xv, yv, zv])


def test_ScalarLoop_Elemwise():
ricardoV94 (Member) commented Nov 12, 2024

Since there's a special condition for one vs. multiple carries, please also test both kinds of loop, with single and multiple updates.
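
Something along these lines could exercise the multiple-carry path (a sketch only: it assumes the `ScalarLoop` constructor from `pytensor.scalar.loop`, the `pt.tensor(..., shape=...)` helper, and the `FunctionGraph`/`compare_pytorch_and_py` utilities already used in this test file; exact signatures may differ):

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from pytensor.scalar import float32 as scalar_float32
from pytensor.scalar.loop import ScalarLoop
from pytensor.tensor.elemwise import Elemwise


def test_ScalarLoop_Elemwise_multiple_carries():
    # Inner scalar graph with two carries.
    init_x = scalar_float32("init_x")
    init_y = scalar_float32("init_y")
    op = ScalarLoop(init=[init_x, init_y], update=[init_x * 2, init_y + init_x])

    # Outer Elemwise over broadcastable inputs.
    n_steps = pt.scalar("n_steps", dtype="int32")
    x0 = pt.vector("x0", dtype="float32")
    y0 = pt.tensor("y0", shape=(7, 3, 1), dtype="float32")
    x_out, y_out = Elemwise(op)(n_steps, x0, y0)

    fg = FunctionGraph([n_steps, x0, y0], [x_out, y_out])
    compare_pytorch_and_py(
        fg,
        [
            np.array(5, dtype="int32"),
            np.arange(3, dtype="float32"),
            np.random.normal(size=(7, 3, 1)).astype("float32"),
        ],
    )
```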

torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise")


@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]])
ricardoV94 (Member) commented Nov 12, 2024

I don't like this test. Just use these shapes in the test above and let the numerical checks do their job.

Ch0ronomato (Contributor Author)

Okay, sounds good. I made this to try to lock down the implementation a bit. I also added it for understanding; does the method make sense now?

Comment on lines +394 to +395
n_steps = pt.scalar("n_steps", dtype="int32")
x0 = pt.vector("x0", dtype="float32")
ricardoV94 (Member) commented Nov 12, 2024

Add a second carry, say of type tensor(shape=(7, 3, 1)), so it broadcasts with the vector x0.

This will make sure multiple carries are working and that we get outputs of the right shape.

ricardoV94 (Member)

Or just use the shapes you had in the test below; that's fine.

ricardoV94 (Member) commented Nov 12, 2024

How is unbind(0) different from [x[i] for i in range(x.size()[0])]?

Ch0ronomato (Contributor Author)

> How is unbind(0) different from [x[i] for i in range(x.size()[0])]?

https://discuss.pytorch.org/t/the-purpose-of-unbind/98648

It's essentially the same, maybe faster.
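
i.e. both produce the same row views; a quick check:

```python
import torch

x = torch.arange(6).reshape(3, 2)
a = x.unbind(0)                       # tuple of 3 views along dim 0
b = [x[i] for i in range(x.size(0))]  # same slices, one index at a time
assert all(torch.equal(u, v) for u, v in zip(a, b))
```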

ricardoV94 (Member) commented Nov 12, 2024

> How is unbind(0) different from [x[i] for i in range(x.size()[0])]?
>
> https://discuss.pytorch.org/t/the-purpose-of-unbind/98648
>
> It's essentially the same, maybe faster.

But if we index in the loop after raveling, we don't need all the slices in memory. This is looking like a custom Elemwise with explicit broadcasting:

bcasted_inputs = torch.broadcast_tensors(*inputs)
raveled_inputs = [inp.ravel() for inp in bcasted_inputs]

out_shape = bcasted_inputs[0].size()
out_size = out_shape.numel()
n_outputs = len(node.outputs)
# out.dtype is a PyTensor dtype string here; it would still need mapping to a torch dtype
raveled_outputs = [torch.empty(out_size, dtype=out.dtype) for out in node.outputs]

for i in range(out_size):
    core_outs = core_func(*(inp[i] for inp in raveled_inputs))
    if n_outputs == 1:
        raveled_outputs[0][i] = core_outs
    else:
        for o in range(n_outputs):
            raveled_outputs[o][i] = core_outs[o]

outputs = tuple(out.view(out_shape) for out in raveled_outputs)
if n_outputs == 1:
    return outputs[0]
else:
    return outputs

Also note that nothing here is specific to ScalarLoop, so it can serve as a (non-performant) fallback for all sorts of Elemwise Ops.

Ch0ronomato (Contributor Author) commented Nov 12, 2024

That looks great. I think we'll still need some dispatch logic to know what can't be vmap'd; do we want to keep the current method? How does your approach merge with #1032?

ricardoV94 (Member)

> That looks great. I think we'll still need some dispatch logic to know what can't be vmap'd; do we want to keep the current method?

Yes, this can be a fallback only for registered Ops (specifically only ScalarLoop for the time being).
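
A minimal sketch of that gating, assuming the backend's `pytorch_funcify` singledispatch lives in `pytensor.link.pytorch.dispatch.basic` (by analogy with the other backends) and using this PR's `elemwise_scalar_loop` helper; the real Elemwise dispatch does more than shown here:

```python
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.link.pytorch.dispatch.elemwise import elemwise_scalar_loop
from pytensor.scalar.loop import ScalarLoop
from pytensor.tensor.elemwise import Elemwise


@pytorch_funcify.register(Elemwise)
def pytorch_funcify_Elemwise(op, node, **kwargs):
    scalar_op = op.scalar_op
    base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)

    if isinstance(scalar_op, ScalarLoop):
        # vmap cannot trace the data-dependent inner loop, so fall back to
        # the explicit ravel-and-loop implementation for this Op only.
        return elemwise_scalar_loop(base_fn, op, node, **kwargs)

    # Every other scalar op keeps the usual broadcasting path.
    def elemwise_fn(*inputs):
        Elemwise._check_runtime_broadcast(node, inputs)
        return base_fn(*inputs)

    return elemwise_fn
```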

ricardoV94 (Member)

If my suggestion works, it should be better than the nested unbind, unless torch is really weird.
