Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to AutoGraph for #769

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
[(#725)](https://github.com/PennyLaneAI/catalyst/pull/725)

Although library code is not meant to be targeted by Autograph conversion,
it sometimes make sense to enable it for specific submodules that might
dime10 marked this conversation as resolved.
Show resolved Hide resolved
it sometimes make sense to enable it for specific submodules that might
benefit from such conversion:

```py
Expand All @@ -81,6 +81,30 @@

```

* Support for usage of single index JAX array operator update
inside Autograph annotated functions.
[(#769)](https://github.com/PennyLaneAI/catalyst/pull/769)

Using operator assignment syntax in favor of at...operation expressions is now possible for the following operations:
* `x[i] += y` in favor of `x.at[i].add(y)`
* `x[i] -= y` in favor of `x.at[i].add(-y)`
* `x[i] *= y` in favor of `x.at[i].multiply(y)`
* `x[i] /= y` in favor of `x.at[i].divide(y)`
* `x[i] **= y` in favor of `x.at[i].power(y)`

```py
@qjit(autograph=True)
def f(x):
first_dim = x.shape[0]
result = jnp.copy(x)

for i in range(first_dim):
result[i] *= 2 # This is now supported

return result

```
Comment on lines +104 to +106
Copy link
Contributor

@dime10 dime10 May 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice example! Maybe we could also include the function output to show what it does:

Suggested change
return result
```
return result
>>> f(jnp.array([1, 2, 3]))
Array([2, 4, 6], dtype=int64)
```


<h3>Improvements</h3>

* Catalyst now has support for `qml.sample(m)` where `m` is the result of a mid-circuit
Expand All @@ -103,7 +127,7 @@
[(#751)](https://github.com/PennyLaneAI/catalyst/pull/751)

* Refactored `vmap`,`qjit`, `mitigate_with_zne` and gradient decorators in order to follow
a unified pattern that uses a callable class implementing the decorator's logic. This
a unified pattern that uses a callable class implementing the decorator's logic. This
prevents having to excessively define functions in a nested fashion.
[(#758)](https://github.com/PennyLaneAI/catalyst/pull/758)
[(#761)](https://github.com/PennyLaneAI/catalyst/pull/761)
Expand All @@ -127,7 +151,7 @@

* Correctly linking openblas routines necessary for `jax.scipy.linalg.expm`.
In this bug fix, four openblas routines were newly linked and are now discoverable by `stablehlo.custom_call@<blas_routine>`. They are `blas_dtrsm`, `blas_ztrsm`, `lapack_dgetrf`, `lapack_zgetrf`.
[(#752)](https://github.com/PennyLaneAI/catalyst/pull/752)
[(#752)](https://github.com/PennyLaneAI/catalyst/pull/752)

<h3>Internal changes</h3>

Expand Down
34 changes: 32 additions & 2 deletions doc/dev/autograph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -921,8 +921,8 @@ Notice that ``autograph=True`` must be set in order to process the
``autograph_include`` list. Otherwise an error will be reported.


In-place JAX array assignments
------------------------------
In-place JAX array updates
--------------------------

To update array values when using JAX, the `JAX syntax for array assignment
<https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#array-updates-x-at-idx-set-y>`__
Expand Down Expand Up @@ -952,3 +952,33 @@ standard Python array assignment syntax:
... return result

Under the hood, Catalyst converts anything coming in the latter notation into the former one.

Similarly, to update array values with an operation when using JAX, the JAX syntax for array
update (which uses the array `at` and the `add`, `multiply`, etc. methods) must be used:

>>> @qjit(autograph=True)
... def f(x):
... first_dim = x.shape[0]
... result = jnp.copy(x)
...
... for i in range(first_dim):
... result = result.at[i].multiply(2)
...
... return result

Again, if updating a single index of the array, Autograph supports conversion of
standard Python array operator assignment syntax for the equivalent in-place expressions
listed in the `JAX documentation for jax.numpy.ndarray.at
<https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at>`__:

>>> @qjit(autograph=True)
... def f(x):
... first_dim = x.shape[0]
... result = jnp.copy(x)
...
... for i in range(first_dim):
... result[i] *= 2
...
... return result

Under the hood, Catalyst converts anything coming in the latter notation into the former one.
130 changes: 130 additions & 0 deletions frontend/catalyst/autograph/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@
"or_",
"not_",
"set_item",
"update_item_with_add",
"update_item_with_sub",
"update_item_with_mult",
"update_item_with_div",
"update_item_with_pow",
]


Expand Down Expand Up @@ -600,6 +605,131 @@ def set_item(target, i, x):
return target


def update_item_with_add(target, i, x):
Copy link
Contributor

@dime10 dime10 May 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it would be more maintainable to have a single function which dispatches to each operator, say in case we wanted to expand the functionality a bit more, or to have each operator have a separate function. What is your opinion on it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did try to use a helper to generate the different functions, similar to _logical_op, but I couldn't figure out a nice way to put += and friends in a lambda, so this seemed like the next best option.

A possibly more maintainable alternative would be to have a single update_item_with_op function that takes an extra parameter and uses that at runtime to determine which JAX function / operator to dispatch to, something like

def update_item_with_op(op, target, i, x):
    if op == '+':
        if isinstance(target, DynamicJaxprTracer):
            target = target.at[i].add(x)
        else:
            target[i] += x
    elif op == '*':
        if isinstance(target, DynamicJaxprTracer):
            target = target.at[i].multiply(x)
        else:
            target[i] *= x
    elif op == '/':
        ...

When generating the call in operator_update.transform we would just stick in the appropriate constant. I do see two cons to this approach: first, it relies on a set of magic constants, and second, it adds extra compares and branches in the output that could negatively impact performance. Those may not be big concerns though; the calls are only generated from one place and this function isn't exposed to the user (right?) so magic constants aren't that problematic IMO, and if the JAX compiler is smart enough with inlining / constant propagation / dead code elimination it will get rid of the extra compares and branches.

Copy link
Contributor

@dime10 dime10 May 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Spencer, that's a good analysis. The single update_item_with_op is what I had in mind, but your reasoning is sound so I'm happy to leave the functions as they are :)

if the JAX compiler is smart enough with inlining / constant propagation / dead code elimination it will get rid of the extra compares and branches

Something that might be interesting to you is that JAX is not capable of compiling most Python code. Instead, it will compile an instruction if either:

  • it is function from the jax library (like e.g. jax.numpy.abs(x))
  • it is a Python operator acting on a jax data (array) object (like the mult in jax.numpy.array(2) * 3)

Consequently, the comparisons and branches all happen before the JAX compiler kicks in (when Python executes the code). There are however special functions that allow you compile a conditional (if statement) for example.

"""An implementation of the 'update_item_with_add' function from operator_update. The interface
is defined in operator_update.SingleIndexArrayOperatorUpdateTransformer, here we provide an
implementation in terms of Catalyst primitives. The idea is to accept the simpler single index
operator assignment syntax for Jax arrays, to subsequently transform it under the hood into the
set of 'at' and 'add' calls that Autograph supports. E.g.:
target[i] += x -> target = target.at[i].add(x)

.. note::
For this feature to work, 'converter.Feature.LISTS' had to be added to the
TOP_LEVEL_OPTIONS and NESTED_LEVEL_OPTIONS conversion options of our own Catalyst
Autograph transformer. If you create a new transformer and want to support this feature,
make sure you enable such option there as well.
"""

# Apply the 'at...add' transformation only to Jax arrays.
# Otherwise, fallback to Python's default syntax.
if isinstance(target, DynamicJaxprTracer):
target = target.at[i].add(x)
else:
target[i] += x

return target


def update_item_with_sub(target, i, x):
"""An implementation of the 'update_item_with_sub' function from operator_update. The interface
is defined in operator_update.SingleIndexArrayOperatorUpdateTransformer, here we provide an
implementation in terms of Catalyst primitives. The idea is to accept the simpler single index
operator assignment syntax for Jax arrays, to subsequently transform it under the hood into the
set of 'at' and 'add' calls that Autograph supports. E.g.:
target[i] -= x -> target = target.at[i].add(-x)

.. note::
For this feature to work, 'converter.Feature.LISTS' had to be added to the
TOP_LEVEL_OPTIONS and NESTED_LEVEL_OPTIONS conversion options of our own Catalyst
Autograph transformer. If you create a new transformer and want to support this feature,
make sure you enable such option there as well.
"""

# Apply the 'at...add' transformation only to Jax arrays.
# Otherwise, fallback to Python's default syntax.
if isinstance(target, DynamicJaxprTracer):
target = target.at[i].add(-x)
else:
target[i] -= x

return target


def update_item_with_mult(target, i, x):
"""An implementation of the 'update_item_with_mult' function from operator_update. The interface
is defined in operator_update.SingleIndexArrayOperatorUpdateTransformer, here we provide an
implementation in terms of Catalyst primitives. The idea is to accept the simpler single index
operator assignment syntax for Jax arrays, to subsequently transform it under the hood into the
set of 'at' and 'multiply' calls that Autograph supports. E.g.:
target[i] *= x -> target = target.at[i].multiply(x)

.. note::
For this feature to work, 'converter.Feature.LISTS' had to be added to the
TOP_LEVEL_OPTIONS and NESTED_LEVEL_OPTIONS conversion options of our own Catalyst
Autograph transformer. If you create a new transformer and want to support this feature,
make sure you enable such option there as well.
"""

# Apply the 'at...multiply' transformation only to Jax arrays.
# Otherwise, fallback to Python's default syntax.
if isinstance(target, DynamicJaxprTracer):
target = target.at[i].multiply(x)
else:
target[i] *= x

return target


def update_item_with_div(target, i, x):
"""An implementation of the 'update_item_with_div' function from operator_update. The interface
is defined in operator_update.SingleIndexArrayOperatorUpdateTransformer, here we provide an
implementation in terms of Catalyst primitives. The idea is to accept the simpler single index
operator assignment syntax for Jax arrays, to subsequently transform it under the hood into the
set of 'at' and 'divide' calls that Autograph supports. E.g.:
target[i] /= x -> target = target.at[i].divide(x)

.. note::
For this feature to work, 'converter.Feature.LISTS' had to be added to the
TOP_LEVEL_OPTIONS and NESTED_LEVEL_OPTIONS conversion options of our own Catalyst
Autograph transformer. If you create a new transformer and want to support this feature,
make sure you enable such option there as well.
"""

# Apply the 'at...divide' transformation only to Jax arrays.
# Otherwise, fallback to Python's default syntax.
if isinstance(target, DynamicJaxprTracer):
target = target.at[i].divide(x)
else:
target[i] /= x

return target


def update_item_with_pow(target, i, x):
"""An implementation of the 'update_item_with_pow' function from operator_update. The interface
is defined in operator_update.SingleIndexArrayOperatorUpdateTransformer, here we provide an
implementation in terms of Catalyst primitives. The idea is to accept the simpler single index
operator assignment syntax for Jax arrays, to subsequently transform it under the hood into the
set of 'at' and 'power' calls that Autograph supports. E.g.:
target[i] **= x -> target = target.at[i].power(x)

.. note::
For this feature to work, 'converter.Feature.LISTS' had to be added to the
TOP_LEVEL_OPTIONS and NESTED_LEVEL_OPTIONS conversion options of our own Catalyst
Autograph transformer. If you create a new transformer and want to support this feature,
make sure you enable such option there as well.
"""

# Apply the 'at...power' transformation only to Jax arrays.
# Otherwise, fallback to Python's default syntax.
if isinstance(target, DynamicJaxprTracer):
target = target.at[i].power(x)
else:
target[i] **= x

return target


class CRange:
"""Catalyst range object.

Expand Down
74 changes: 74 additions & 0 deletions frontend/catalyst/autograph/operator_update.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great 👍

Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Converter for array element operator assignment."""

import gast
from malt.core import converter
from malt.pyct import templates


# The methods from this class should be migrated to the SliceTransformer class in DiastaticMalt
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# The methods from this class should be migrated to the SliceTransformer class in DiastaticMalt
# TODO: The methods from this class should be migrated to the SliceTransformer class in DiastaticMalt

class SingleIndexArrayOperatorUpdateTransformer(converter.Base):
"""Converts array element operator assignment statements into calls to update_item_with_{op},
where op is one of the following:

- `add` corresponding to `+=`
- `sub` to `-=`
- `mult` to `*=`
- `div` to `/=`
- `pow` to `**=`
"""

def _process_single_update(self, target, op, value):
if not isinstance(target, gast.Subscript):
return None
s = target.slice
if isinstance(s, (gast.Tuple, gast.Slice)):
return None
if not isinstance(op, (gast.Mult, gast.Add, gast.Sub, gast.Div, gast.Pow)):
return None

template = f"""
target = ag__.update_item_with_{type(op).__name__.lower()}(target, i, x)
"""

return templates.replace(template, target=target.value, i=target.slice, x=value)

def visit_AugAssign(self, node):
"""The AugAssign node is replaced with a call to ag__.update_item_with_{op}
when its target is a single index array subscript and its op is an arithmetic
operator (i.e. Add, Sub, Mult, Div, or Pow), otherwise the node is left as is.

Example:
`x[i] += y` is replaced with `x = ag__.update_item_with(x, i, y)`
`x[i] ^= y` remains unchanged
"""
node = self.generic_visit(node)
replacement = self._process_single_update(node.target, node.op, node.value)
if replacement is not None:
return replacement
return node


def transform(node, ctx):
"""Replace an AugAssign node with a call to ag__.update_item_with_{op}
when the its target is a single index array subscript and its op is an arithmetic
operator (i.e. Add, Sub, Mult, Div, or Pow), otherwise the node is left as is.

Example:
`x[i] += y` is replaced with `x = ag__.update_item_with(x, i, y)`
`x[i] ^= y` remains unchanged
"""
return SingleIndexArrayOperatorUpdateTransformer(ctx).visit(node)
17 changes: 16 additions & 1 deletion frontend/catalyst/autograph/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from malt.impl.api import PyToPy

import catalyst
from catalyst.autograph import ag_primitives
from catalyst.autograph import ag_primitives, operator_update
from catalyst.utils.exceptions import AutoGraphError


Expand Down Expand Up @@ -111,6 +111,21 @@ def get_cached_function(self, fn):

return new_fn

def transform_ast(self, node, ctx):
"""Overload of PyToPy.transform_ast from DiastaticMalt

.. note::
Once the operator_update interface has been migrated to the
DiastaticMalt project, this overload can be deleted."""
# The operator_update transform would be more correct if placed with
# slices.transform in PyToPy.transform_ast in DiastaticMalt rather than
# at the beginning of the transformation. operator_update.transform
# should come after the unsupported features check and intial analysis,
# but it fails if it does not come before variables.transform.
node = operator_update.transform(node, ctx)
node = super().transform_ast(node, ctx)
return node


def run_autograph(fn):
"""Decorator that converts the given function into graph form."""
Expand Down
Loading