Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to AutoGraph for #769

Closed
wants to merge 4 commits into from

Conversation

Spencer-Comin
Copy link
Contributor

Context: #717 added support for converting in-place array updates (arr[i] = x) into the equivalent JAX traceable code (arr.at[i].set(x)). This change extends that support to operator assignment array updates.

Description of the Change:

  • Add new Autograph converter to map AugAssign ast nodes assigning to a single index subscript to calls to update_item_with_{add|sub|mult|div|pow}
  • Implement update_item_with_{add|sub|mult|div|pow} methods that map to the corresponding jax.numpy.ndarray.at equivalent methods for JAX arrays and the normal Python operator assignment otherwise
  • Overload transform_ast in CatalystTransformer to invoke the new converter

Benefits: We can use arr[i] += x instead of arr.at[i].add(x).

Possible Drawbacks: It would be cleaner to have the new converter live in the DiastaticMalt project.

Related GitHub Issues: #757

Based on the solution presented in this PR: #717

@dime10
Copy link
Contributor

dime10 commented May 28, 2024

Hi @Spencer-Comin, welcome to the Catalyst project and thank you for opening the PR! :) I'll give it a review it shortly, and also approve the CI runs. If you have any questions or discussion points about the PR feel free to leave them here.

Copy link
Contributor

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work, this looks great 🚀 🚀

Comment on lines +104 to +106
return result

```
Copy link
Contributor

@dime10 dime10 May 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice example! Maybe we could also include the function output to show what it does:

Suggested change
return result
```
return result
>>> f(jnp.array([1, 2, 3]))
Array([2, 4, 6], dtype=int64)
```

from malt.pyct import templates


# The methods from this class should be migrated to the SliceTransformer class in DiastaticMalt
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# The methods from this class should be migrated to the SliceTransformer class in DiastaticMalt
# TODO: The methods from this class should be migrated to the SliceTransformer class in DiastaticMalt

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great 👍

@@ -279,6 +279,41 @@ def f(x):
Under the hood, Catalyst converts anything coming in the latter notation into the
former one.

Similarly, to update array values with an operation when using JAX, the JAX syntax for array
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea to also update the qjit docstring, although it is getting quite long. Maybe you can find a way to combine this example with one above to save some space?

@@ -600,6 +605,131 @@ def set_item(target, i, x):
return target


def update_item_with_add(target, i, x):
Copy link
Contributor

@dime10 dime10 May 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it would be more maintainable to have a single function which dispatches to each operator, say in case we wanted to expand the functionality a bit more, or to have each operator have a separate function. What is your opinion on it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did try to use a helper to generate the different functions, similar to _logical_op, but I couldn't figure out a nice way to put += and friends in a lambda, so this seemed like the next best option.

A possibly more maintainable alternative would be to have a single update_item_with_op function that takes an extra parameter and uses that at runtime to determine which JAX function / operator to dispatch to, something like

def update_item_with_op(op, target, i, x):
    if op == '+':
        if isinstance(target, DynamicJaxprTracer):
            target = target.at[i].add(x)
        else:
            target[i] += x
    elif op == '*':
        if isinstance(target, DynamicJaxprTracer):
            target = target.at[i].multiply(x)
        else:
            target[i] *= x
    elif op == '/':
        ...

When generating the call in operator_update.transform we would just stick in the appropriate constant. I do see two cons to this approach: first, it relies on a set of magic constants, and second, it adds extra compares and branches in the output that could negatively impact performance. Those may not be big concerns though; the calls are only generated from one place and this function isn't exposed to the user (right?) so magic constants aren't that problematic IMO, and if the JAX compiler is smart enough with inlining / constant propagation / dead code elimination it will get rid of the extra compares and branches.

Copy link
Contributor

@dime10 dime10 May 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Spencer, that's a good analysis. The single update_item_with_op is what I had in mind, but your reasoning is sound so I'm happy to leave the functions as they are :)

if the JAX compiler is smart enough with inlining / constant propagation / dead code elimination it will get rid of the extra compares and branches

Something that might be interesting to you is that JAX is not capable of compiling most Python code. Instead, it will compile an instruction if either:

  • it is function from the jax library (like e.g. jax.numpy.abs(x))
  • it is a Python operator acting on a jax data (array) object (like the mult in jax.numpy.array(2) * 3)

Consequently, the comparisons and branches all happen before the JAX compiler kicks in (when Python executes the code). There are however special functions that allow you compile a conditional (if statement) for example.

mehrdad2m added a commit that referenced this pull request Sep 20, 2024
…es (#1143)

**Context:** 

#717 added support for
converting in-place array updates (`arr[i] = x`) into the equivalent JAX
traceable code (`arr.at[i].set(x)`). This PR extends that support to
operator assignment array updates.

**Description of the Change:**

- Add new Autograph converter to map `AugAssign` ast nodes assigning to
a single index or a slice subscript to calls to `update_item_with_op`
- Implement `update_item_with_op` method that map to the corresponding
`jax.numpy.ndarray.at` equivalent methods for JAX arrays and the normal
Python operator assignment otherwise
- Overload `transform_ast` in `CatalystTransformer` to invoke the new
converter

**Benefits:** We can use `arr[i] += x` instead of `arr.at[i].add(x)`.

**Possible Drawbacks:** It would be cleaner to have the new converter
live in the DiastaticMalt project.

**Related GitHub Issues:**
#757

**Based on the solution presented in this PR:**
#769
Note that this PR was originally implemented externally by
#769. This PR aims to
revisit that PR.

---------

Co-authored-by: Spencer Comin <[email protected]>
@erick-xanadu
Copy link
Contributor

Superseded by #1143

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants