-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support to AutoGraph for #769
Conversation
4e164d7
to
6add6c5
Compare
Hi @Spencer-Comin, welcome to the Catalyst project and thank you for opening the PR! :) I'll give it a review it shortly, and also approve the CI runs. If you have any questions or discussion points about the PR feel free to leave them here. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work, this looks great 🚀 🚀
return result | ||
|
||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice example! Maybe we could also include the function output to show what it does:
return result | |
``` | |
return result | |
>>> f(jnp.array([1, 2, 3])) | |
Array([2, 4, 6], dtype=int64) | |
``` |
from malt.pyct import templates | ||
|
||
|
||
# The methods from this class should be migrated to the SliceTransformer class in DiastaticMalt |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# The methods from this class should be migrated to the SliceTransformer class in DiastaticMalt | |
# TODO: The methods from this class should be migrated to the SliceTransformer class in DiastaticMalt |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great 👍
@@ -279,6 +279,41 @@ def f(x): | |||
Under the hood, Catalyst converts anything coming in the latter notation into the | |||
former one. | |||
|
|||
Similarly, to update array values with an operation when using JAX, the JAX syntax for array |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea to also update the qjit
docstring, although it is getting quite long. Maybe you can find a way to combine this example with one above to save some space?
@@ -600,6 +605,131 @@ def set_item(target, i, x): | |||
return target | |||
|
|||
|
|||
def update_item_with_add(target, i, x): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if it would be more maintainable to have a single function which dispatches to each operator, say in case we wanted to expand the functionality a bit more, or to have each operator have a separate function. What is your opinion on it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did try to use a helper to generate the different functions, similar to _logical_op
, but I couldn't figure out a nice way to put +=
and friends in a lambda, so this seemed like the next best option.
A possibly more maintainable alternative would be to have a single update_item_with_op
function that takes an extra parameter and uses that at runtime to determine which JAX function / operator to dispatch to, something like
def update_item_with_op(op, target, i, x):
if op == '+':
if isinstance(target, DynamicJaxprTracer):
target = target.at[i].add(x)
else:
target[i] += x
elif op == '*':
if isinstance(target, DynamicJaxprTracer):
target = target.at[i].multiply(x)
else:
target[i] *= x
elif op == '/':
...
When generating the call in operator_update.transform
we would just stick in the appropriate constant. I do see two cons to this approach: first, it relies on a set of magic constants, and second, it adds extra compares and branches in the output that could negatively impact performance. Those may not be big concerns though; the calls are only generated from one place and this function isn't exposed to the user (right?) so magic constants aren't that problematic IMO, and if the JAX compiler is smart enough with inlining / constant propagation / dead code elimination it will get rid of the extra compares and branches.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Spencer, that's a good analysis. The single update_item_with_op
is what I had in mind, but your reasoning is sound so I'm happy to leave the functions as they are :)
if the JAX compiler is smart enough with inlining / constant propagation / dead code elimination it will get rid of the extra compares and branches
Something that might be interesting to you is that JAX is not capable of compiling most Python code. Instead, it will compile an instruction if either:
- it is function from the jax library (like e.g.
jax.numpy.abs(x)
) - it is a Python operator acting on a jax data (array) object (like the mult in
jax.numpy.array(2) * 3
)
Consequently, the comparisons and branches all happen before the JAX compiler kicks in (when Python executes the code). There are however special functions that allow you compile a conditional (if statement) for example.
…es (#1143) **Context:** #717 added support for converting in-place array updates (`arr[i] = x`) into the equivalent JAX traceable code (`arr.at[i].set(x)`). This PR extends that support to operator assignment array updates. **Description of the Change:** - Add new Autograph converter to map `AugAssign` ast nodes assigning to a single index or a slice subscript to calls to `update_item_with_op` - Implement `update_item_with_op` method that map to the corresponding `jax.numpy.ndarray.at` equivalent methods for JAX arrays and the normal Python operator assignment otherwise - Overload `transform_ast` in `CatalystTransformer` to invoke the new converter **Benefits:** We can use `arr[i] += x` instead of `arr.at[i].add(x)`. **Possible Drawbacks:** It would be cleaner to have the new converter live in the DiastaticMalt project. **Related GitHub Issues:** #757 **Based on the solution presented in this PR:** #769 Note that this PR was originally implemented externally by #769. This PR aims to revisit that PR. --------- Co-authored-by: Spencer Comin <[email protected]>
Superseded by #1143 |
Context: #717 added support for converting in-place array updates (
arr[i] = x
) into the equivalent JAX traceable code (arr.at[i].set(x)
). This change extends that support to operator assignment array updates.Description of the Change:
AugAssign
ast nodes assigning to a single index subscript to calls toupdate_item_with_{add|sub|mult|div|pow}
update_item_with_{add|sub|mult|div|pow}
methods that map to the correspondingjax.numpy.ndarray.at
equivalent methods for JAX arrays and the normal Python operator assignment otherwisetransform_ast
inCatalystTransformer
to invoke the new converterBenefits: We can use
arr[i] += x
instead ofarr.at[i].add(x)
.Possible Drawbacks: It would be cleaner to have the new converter live in the DiastaticMalt project.
Related GitHub Issues: #757
Based on the solution presented in this PR: #717