-
Notifications
You must be signed in to change notification settings - Fork 108
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Adrian Seyboldt <[email protected]> Co-authored-by: Jesse Grabowski <[email protected]> Co-authored-by: Rob Zinkov <[email protected]>
- Loading branch information
1 parent
981688c
commit 52575e8
Showing
18 changed files
with
1,361 additions
and
144 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import jax.numpy as jnp | ||
|
||
from pytensor.link.jax.dispatch import jax_funcify | ||
from pytensor.tensor.einsum import Einsum | ||
|
||
|
||
@jax_funcify.register(Einsum) | ||
def jax_funcify_Einsum(op, **kwargs): | ||
"""Dispatch einsum to JAX. | ||
This dispatch is triggered only when we couldn't optimize einsum at the PyTensor level. | ||
This happens when some of the dimension lengths are unknown. This is never a problem in JAX, | ||
as it always compiles a function per runtime input shape. | ||
""" | ||
subscripts = op.subscripts | ||
|
||
def einsum(*operands): | ||
return jnp.einsum(subscripts, *operands, optimize="optimal") | ||
|
||
return einsum |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.