Add einsum #722

Merged · 2 commits merged into pymc-devs:main from the einsum branch on Aug 4, 2024
Conversation

jessegrabowski
Member

@jessegrabowski jessegrabowski commented Apr 19, 2024

Description

TODO:

  • Support ellipsis (related to dgasmith/opt_einsum#235: contract_path with ellipsis fails when shapes=True)
  • Exclude broadcastable dims for better performance (JAX does the same)
  • Handle missing static shape information (default to left-to-right contraction?); see the contract_path sketch after this list
    • Add rewrite for optimizing Einsum Ops when all inputs have known static shapes
    • Add rewrite for inlining optimized Einsum
  • Get rid of Blockwise Reshape
    • Fix lingering infinite rewriting bug
  • Decide on providing optimize kwarg
  • Appease Mypy
  • Better docstrings (@jessegrabowski self-assigned)
  • Fix failing tests (@ricardoV94 self-assigned)
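For context on the optimization TODOs above, a minimal sketch of the contraction-path machinery this builds on, using opt_einsum's public contract_path (the shapes=True mode is the one hit by the linked ellipsis bug); the subscripts and shapes here are illustrative:

import opt_einsum as oe

# A contraction path can be computed from static shapes alone, no data needed;
# shapes=True lets us pass shape tuples instead of concrete arrays.
path, info = oe.contract_path(
    "bij,bjk->bik",
    (5, 2, 3),
    (5, 3, 4),
    shapes=True,
    optimize="optimal",
)
print(path)  # pairwise contraction order, e.g. [(0, 1)]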

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@zaxtax
Contributor

zaxtax commented Apr 20, 2024

Are the currently failing tests supposed to fail?

@jessegrabowski
Member Author

jessegrabowski commented Apr 20, 2024

Looks like it's related to the changes to the __len__ method in variables.py. I'd suggest just reverting the change unless we really need it; it's a pretty low-level thing, and it would take a bit of work to figure out everything it affects.


codecov bot commented Apr 21, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 80.82%. Comparing base (28d9d4d) to head (e262999).
Report is 1 commit behind head on main.

Current head e262999 differs from pull request most recent head 3fe3257

Please upload reports for the commit 3fe3257 to get more accurate results.

Additional details and impacted files


@@            Coverage Diff             @@
##             main     #722      +/-   ##
==========================================
- Coverage   80.89%   80.82%   -0.07%     
==========================================
  Files         169      164       -5     
  Lines       46977    46844     -133     
  Branches    11478    11457      -21     
==========================================
- Hits        38000    37862     -138     
+ Misses       6767     6734      -33     
- Partials     2210     2248      +38     
Files Coverage Δ
pytensor/compile/builders.py 77.45% <100.00%> (-10.98%) ⬇️
pytensor/link/jax/dispatch/__init__.py 100.00% <100.00%> (ø)
pytensor/link/jax/dispatch/einsum.py 100.00% <100.00%> (ø)
pytensor/tensor/einsum.py 100.00% <100.00%> (ø)

... and 55 files with indirect coverage changes

@ricardoV94 ricardoV94 force-pushed the einsum branch 3 times, most recently from dd08faa to 180ef9d on May 7, 2024
@ricardoV94
Member

ricardoV94 commented May 7, 2024

All cases are passing except those requiring tensordot with batch dims not on the left.

We may need more tests soon enough

@zaxtax
Contributor

zaxtax commented May 7, 2024 via email

@ricardoV94
Member

ricardoV94 commented May 7, 2024

> Can we reuse how numpy implements tensordot?

We already do, but numpy doesn't have a batched tensordot (except, of course, through einsum). This PR already has batched tensordot working, just not with arbitrary batch axes. It should just need some extra transposes to get the job done.
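A minimal numpy sketch of the "extra transposes" idea, assuming a single batch axis per operand; batched_tensordot, b1, and b2 are hypothetical names, and axes indexes into the per-batch slices:

import numpy as np

def batched_tensordot(a, b, axes, b1, b2):
    # Move each operand's batch axis to the front...
    a = np.moveaxis(a, b1, 0)
    b = np.moveaxis(b, b2, 0)
    # ...then contract slice by slice and restack along the batch axis.
    return np.stack([np.tensordot(x, y, axes=axes) for x, y in zip(a, b)])

A = np.random.rand(4, 3, 2)  # batch axis of size 3 at position 1
B = np.random.rand(2, 3, 5)  # batch axis of size 3 at position 1
print(batched_tensordot(A, B, axes=([1], [0]), b1=1, b2=1).shape)  # (3, 4, 5)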

@ricardoV94 ricardoV94 changed the title from Add pt.einsum to Add einsum on May 8, 2024
@ricardoV94
Member

Huhu convolutions via einsum work :D

@ricardoV94
Member

ricardoV94 commented Jun 23, 2024

Einsum now also works with inputs with unknown static shapes (unoptimized, of course). We can add a rewrite for when such an Op is found with inputs that by then have static shapes (this can be quite relevant in PyMC, when users use freeze_rv_and_dims on a model with mutable coords).
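A user-facing illustration of the two regimes, written against the pt.einsum API this PR adds (shapes are illustrative):

import pytensor.tensor as pt

# Fully static shapes: the contraction path can be optimized at define time
a = pt.tensor(shape=(5, 2, 3))
b = pt.tensor(shape=(5, 3, 4))
static_out = pt.einsum("bij,bjk->bik", a, b)

# Unknown batch dim: falls back to an unoptimized contraction until a
# rewrite (or freezing the shapes) makes the static shapes available
c = pt.tensor(shape=(None, 2, 3))
d = pt.tensor(shape=(None, 3, 4))
dynamic_out = pt.einsum("bij,bjk->bik", c, d)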

@ricardoV94 ricardoV94 force-pushed the einsum branch 2 times, most recently from 2e95049 to ca8bf54 on July 4, 2024
@ricardoV94
Member

The ellipsis case is failing due to a bug in opt_einsum: dgasmith/opt_einsum#235

@ricardoV94
Member

Not sure what you mean @zaxtax. I'm talking about allowing the optimize kwarg like numpy's, which defines what kind of optimization to do (optimize: {bool, list, tuple, 'greedy', 'optimal'}); users can also pass their own contraction path.

If users pass a contraction path, we don't need to know static shapes. If they set it to greedy/optimal (optimal should be the default), we do need to know them, though we may only find them later. And if they don't want any optimization, then obviously we don't need them.
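For reference, the numpy interface being mirrored; np.einsum_path exposes the path that optimize would compute, and that path can be passed back in to skip the search:

import numpy as np

A = np.random.rand(10, 20)
B = np.random.rand(20, 30)
C = np.random.rand(30, 5)

# Search for a good pairwise contraction order
path, report = np.einsum_path("ij,jk,kl->il", A, B, C, optimize="optimal")
print(path)  # e.g. ['einsum_path', (1, 2), (0, 1)]

# Reuse the precomputed path, skipping the search entirely
out = np.einsum("ij,jk,kl->il", A, B, C, optimize=path)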

@zaxtax
Contributor

zaxtax commented Jul 11, 2024 via email

@ricardoV94
Member

An unrelated JAX test is failing, probably due to something that changed in a recent release? https://github.com/pymc-devs/pytensor/actions/runs/10161294793/job/28099514375?pr=722#step:6:778

@jessegrabowski
Member Author

jessegrabowski commented Jul 30, 2024

Yes, I am also looking at this now. It's a JAX bug that can be recreated easily:

import jax

# tri works eagerly, but fails when wrapped in jit
jax.jit(jax.numpy.tri)(3, 3, 0)

We can ignore it. It looks like their _canonicalize_axis function is underflowing (something like np.uint32(-1)).
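The suspected wraparound is easy to show in isolation; casting a negative axis to an unsigned integer wraps instead of counting from the end:

import numpy as np

# -1 should mean "last axis", but as uint32 it wraps around
print(np.int64(-1).astype(np.uint32))  # 4294967295, not a valid axis index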

@ricardoV94 ricardoV94 force-pushed the einsum branch 3 times, most recently from 459bb77 to 48c663a on July 30, 2024
@ricardoV94
Member

I think I fixed the tests (not the JAX one) and appeased mypy. @jessegrabowski, the docstring extensions are left to you.

@ricardoV94
Member

ricardoV94 commented Jul 30, 2024

I've stopped force-pushing, in case you want to take over.

@jessegrabowski
Member Author

Opened an issue here: jax-ml/jax#22751

I'll hit the docstrings ASAP if that's all that's holding this up.

@ricardoV94
Member

ricardoV94 commented Jul 30, 2024

> Opened an issue here: jax-ml/jax#22751

Great, let's just mark it as xfail then.
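A sketch of the xfail marker in question (the test name and reason string are illustrative, not the ones in the PR):

import pytest

@pytest.mark.xfail(reason="jax-ml/jax#22751: tri fails under jit")
def test_tri_under_jit():
    import jax
    jax.jit(jax.numpy.tri)(3, 3, 0)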

@ricardoV94 ricardoV94 marked this pull request as ready for review July 30, 2024 17:14
@jessegrabowski
Member Author

jessegrabowski commented Jul 31, 2024

First pass on docstrings. Working on the doctests revealed two things:

  1. Our implementation of _delta does not agree with that of JAX:
from jax._src.lax.lax import _delta as jax_delta
from pytensor.tensor.einsum import _delta as pt_delta
jax_delta(int, (3, 3, 3), (0,1))

Array([[[1, 1, 1],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [1, 1, 1],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [1, 1, 1]]], dtype=int32)


pt_delta((3,3,3), (0,1)).astype(int).eval()
array([[[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]],

       [[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]],

       [[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]]])

We seem to always output the axes=(1,2) case, regardless of what the requested axes were.

  2. Our _general_dot function is discarding shape information somewhere, which doesn't seem right:
        import pytensor.tensor as pt
        from pytensor.tensor.einsum import _general_dot
        A = pt.tensor(shape=(3, 4, 5))
        B = pt.tensor(shape=(3, 5, 2))

        result = _general_dot((A, B), axes=[[2], [1]], batch_axes=[[0], [0]])
        print(result.type.shape)

       (3, None, None)

@ricardoV94
Member

ricardoV94 commented Jul 31, 2024

Static shape is optional, not a requirement. In our case it probably has to do with the reshape introduced by tensordot, and/or with Blockwise, which doesn't do any special shape inference (static or at rewrite time) for core shapes.

That's something we probably want to address for Blockwise in the Numba backend

@jessegrabowski
Member Author

I understand it's optional, but it also shouldn't be discarded if available, no?

@ricardoV94
Member

We are not discarding anything on purpose, but an intermediate Op (or Blockwise) doesn't know how to provide a more precise output shape.

There can also be a tradeoff: quite some effort may be needed to figure out a static shape, and that may not be worth it at define time. Anyway, the main point is that it shouldn't be a blocker.

We can open an issue for whatever Op is losing the static shape and then assess whether it's worth the cost or not.

@ricardoV94
Member

@jessegrabowski I think I fixed _delta; it was missing the equivalent of a call to broadcast_in_dim.
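For reference, a minimal numpy sketch of the intended _delta semantics (a Kronecker delta over `axes`, broadcast along every other dim; the final broadcast is the step that was missing). The function name and body here are illustrative, not PyTensor's internals:

import numpy as np

def delta_sketch(dtype, shape, axes):
    # Build the Kronecker delta over the requested axes only
    axes = sorted(axes)
    n = min(shape[a] for a in axes)
    base = np.zeros([shape[a] for a in axes], dtype=dtype)
    base[(np.arange(n),) * len(axes)] = 1
    # Insert length-1 dims at the non-delta positions, then broadcast
    expanded = [shape[a] if a in axes else 1 for a in range(len(shape))]
    return np.broadcast_to(base.reshape(expanded), shape)

print(delta_sketch(int, (3, 3, 3), (0, 1)))  # matches the JAX output above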

@ricardoV94 ricardoV94 force-pushed the einsum branch 2 times, most recently from 70cc6a3 to a902974 on August 4, 2024
Co-authored-by: Adrian Seyboldt <[email protected]>
Co-authored-by: Jesse Grabowski <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Rob Zinkov <[email protected]>
@ricardoV94 ricardoV94 merged commit b65d08c into pymc-devs:main Aug 4, 2024
57 checks passed
@zaxtax
Contributor

zaxtax commented Aug 4, 2024

🎆

Merging this pull request may close: Implement einsum equivalent