From c893f178444aa207f5d4b343aa704c93b3c1e132 Mon Sep 17 00:00:00 2001 From: Spencer Comin Date: Mon, 27 May 2024 22:46:08 -0400 Subject: [PATCH 1/4] [frontend] Add support to AG for Jax single array operator assignment --- frontend/catalyst/autograph/ag_primitives.py | 130 ++++++++++++++++++ .../catalyst/autograph/operator_update.py | 74 ++++++++++ frontend/catalyst/autograph/transformer.py | 17 ++- 3 files changed, 220 insertions(+), 1 deletion(-) create mode 100644 frontend/catalyst/autograph/operator_update.py diff --git a/frontend/catalyst/autograph/ag_primitives.py b/frontend/catalyst/autograph/ag_primitives.py index d132641ce4..fe17303fbc 100644 --- a/frontend/catalyst/autograph/ag_primitives.py +++ b/frontend/catalyst/autograph/ag_primitives.py @@ -47,6 +47,11 @@ "or_", "not_", "set_item", + "update_item_with_add", + "update_item_with_sub", + "update_item_with_mult", + "update_item_with_div", + "update_item_with_pow", ] @@ -600,6 +605,131 @@ def set_item(target, i, x): return target +def update_item_with_add(target, i, x): + """An implementation of the 'update_item_with_add' function from operator_update. The interface + is defined in operator_update.SingleIndexArrayOperatorUpdateTransformer, here we provide an + implementation in terms of Catalyst primitives. The idea is to accept the simpler single index + operator assignment syntax for Jax arrays, to subsequently transform it under the hood into the + set of 'at' and 'add' calls that Autograph supports. E.g.: + target[i] += x -> target = target.at[i].add(x) + + .. note:: + For this feature to work, 'converter.Feature.LISTS' had to be added to the + TOP_LEVEL_OPTIONS and NESTED_LEVEL_OPTIONS conversion options of our own Catalyst + Autograph transformer. If you create a new transformer and want to support this feature, + make sure you enable such option there as well. + """ + + # Apply the 'at...add' transformation only to Jax arrays. + # Otherwise, fallback to Python's default syntax. + if isinstance(target, DynamicJaxprTracer): + target = target.at[i].add(x) + else: + target[i] += x + + return target + + +def update_item_with_sub(target, i, x): + """An implementation of the 'update_item_with_sub' function from operator_update. The interface + is defined in operator_update.SingleIndexArrayOperatorUpdateTransformer, here we provide an + implementation in terms of Catalyst primitives. The idea is to accept the simpler single index + operator assignment syntax for Jax arrays, to subsequently transform it under the hood into the + set of 'at' and 'add' calls that Autograph supports. E.g.: + target[i] -= x -> target = target.at[i].add(-x) + + .. note:: + For this feature to work, 'converter.Feature.LISTS' had to be added to the + TOP_LEVEL_OPTIONS and NESTED_LEVEL_OPTIONS conversion options of our own Catalyst + Autograph transformer. If you create a new transformer and want to support this feature, + make sure you enable such option there as well. + """ + + # Apply the 'at...add' transformation only to Jax arrays. + # Otherwise, fallback to Python's default syntax. + if isinstance(target, DynamicJaxprTracer): + target = target.at[i].add(-x) + else: + target[i] -= x + + return target + + +def update_item_with_mult(target, i, x): + """An implementation of the 'update_item_with_mult' function from operator_update. The interface + is defined in operator_update.SingleIndexArrayOperatorUpdateTransformer, here we provide an + implementation in terms of Catalyst primitives. The idea is to accept the simpler single index + operator assignment syntax for Jax arrays, to subsequently transform it under the hood into the + set of 'at' and 'multiply' calls that Autograph supports. E.g.: + target[i] *= x -> target = target.at[i].multiply(x) + + .. note:: + For this feature to work, 'converter.Feature.LISTS' had to be added to the + TOP_LEVEL_OPTIONS and NESTED_LEVEL_OPTIONS conversion options of our own Catalyst + Autograph transformer. If you create a new transformer and want to support this feature, + make sure you enable such option there as well. + """ + + # Apply the 'at...multiply' transformation only to Jax arrays. + # Otherwise, fallback to Python's default syntax. + if isinstance(target, DynamicJaxprTracer): + target = target.at[i].multiply(x) + else: + target[i] *= x + + return target + + +def update_item_with_div(target, i, x): + """An implementation of the 'update_item_with_div' function from operator_update. The interface + is defined in operator_update.SingleIndexArrayOperatorUpdateTransformer, here we provide an + implementation in terms of Catalyst primitives. The idea is to accept the simpler single index + operator assignment syntax for Jax arrays, to subsequently transform it under the hood into the + set of 'at' and 'divide' calls that Autograph supports. E.g.: + target[i] /= x -> target = target.at[i].divide(x) + + .. note:: + For this feature to work, 'converter.Feature.LISTS' had to be added to the + TOP_LEVEL_OPTIONS and NESTED_LEVEL_OPTIONS conversion options of our own Catalyst + Autograph transformer. If you create a new transformer and want to support this feature, + make sure you enable such option there as well. + """ + + # Apply the 'at...divide' transformation only to Jax arrays. + # Otherwise, fallback to Python's default syntax. + if isinstance(target, DynamicJaxprTracer): + target = target.at[i].divide(x) + else: + target[i] /= x + + return target + + +def update_item_with_pow(target, i, x): + """An implementation of the 'update_item_with_pow' function from operator_update. The interface + is defined in operator_update.SingleIndexArrayOperatorUpdateTransformer, here we provide an + implementation in terms of Catalyst primitives. The idea is to accept the simpler single index + operator assignment syntax for Jax arrays, to subsequently transform it under the hood into the + set of 'at' and 'power' calls that Autograph supports. E.g.: + target[i] **= x -> target = target.at[i].power(x) + + .. note:: + For this feature to work, 'converter.Feature.LISTS' had to be added to the + TOP_LEVEL_OPTIONS and NESTED_LEVEL_OPTIONS conversion options of our own Catalyst + Autograph transformer. If you create a new transformer and want to support this feature, + make sure you enable such option there as well. + """ + + # Apply the 'at...power' transformation only to Jax arrays. + # Otherwise, fallback to Python's default syntax. + if isinstance(target, DynamicJaxprTracer): + target = target.at[i].power(x) + else: + target[i] **= x + + return target + + class CRange: """Catalyst range object. diff --git a/frontend/catalyst/autograph/operator_update.py b/frontend/catalyst/autograph/operator_update.py new file mode 100644 index 0000000000..114b60a149 --- /dev/null +++ b/frontend/catalyst/autograph/operator_update.py @@ -0,0 +1,74 @@ +# Copyright 2024 Spencer Comin + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Converter for array element operator assignment.""" + +import gast +from malt.core import converter +from malt.pyct import templates + + +# The methods from this class should be migrated to the SliceTransformer class in DiastaticMalt +class SingleIndexArrayOperatorUpdateTransformer(converter.Base): + """Converts array element operator assignment statements into calls to update_item_with_{op}, + where op is one of the following: + + - `add` corresponding to `+=` + - `sub` to `-=` + - `mult` to `*=` + - `div` to `/=` + - `pow` to `**=` + """ + + def _process_single_update(self, target, op, value): + if not isinstance(target, gast.Subscript): + return None + s = target.slice + if isinstance(s, (gast.Tuple, gast.Slice)): + return None + if not isinstance(op, (gast.Mult, gast.Add, gast.Sub, gast.Div, gast.Pow)): + return None + + template = f""" + target = ag__.update_item_with_{type(op).__name__.lower()}(target, i, x) + """ + + return templates.replace(template, target=target.value, i=target.slice, x=value) + + def visit_AugAssign(self, node): + """The AugAssign node is replaced with a call to ag__.update_item_with_{op} + when its target is a single index array subscript and its op is an arithmetic + operator (i.e. Add, Sub, Mult, Div, or Pow), otherwise the node is left as is. + + Example: + `x[i] += y` is replaced with `x = ag__.update_item_with(x, i, y)` + `x[i] ^= y` remains unchanged + """ + node = self.generic_visit(node) + replacement = self._process_single_update(node.target, node.op, node.value) + if replacement is not None: + return replacement + return node + + +def transform(node, ctx): + """Replace an AugAssign node with a call to ag__.update_item_with_{op} + when the its target is a single index array subscript and its op is an arithmetic + operator (i.e. Add, Sub, Mult, Div, or Pow), otherwise the node is left as is. + + Example: + `x[i] += y` is replaced with `x = ag__.update_item_with(x, i, y)` + `x[i] ^= y` remains unchanged + """ + return SingleIndexArrayOperatorUpdateTransformer(ctx).visit(node) diff --git a/frontend/catalyst/autograph/transformer.py b/frontend/catalyst/autograph/transformer.py index c5c3324af1..a7ada7f16d 100644 --- a/frontend/catalyst/autograph/transformer.py +++ b/frontend/catalyst/autograph/transformer.py @@ -29,7 +29,7 @@ from malt.impl.api import PyToPy import catalyst -from catalyst.autograph import ag_primitives +from catalyst.autograph import ag_primitives, operator_update from catalyst.utils.exceptions import AutoGraphError @@ -111,6 +111,21 @@ def get_cached_function(self, fn): return new_fn + def transform_ast(self, node, ctx): + """Overload of PyToPy.transform_ast from DiastaticMalt + + .. note:: + Once the operator_update interface has been migrated to the + DiastaticMalt project, this overload can be deleted.""" + # The operator_update transform would be more correct if placed with + # slices.transform in PyToPy.transform_ast in DiastaticMalt rather than + # at the beginning of the transformation. operator_update.transform + # should come after the unsupported features check and intial analysis, + # but it fails if it does not come before variables.transform. + node = operator_update.transform(node, ctx) + node = super().transform_ast(node, ctx) + return node + def run_autograph(fn): """Decorator that converts the given function into graph form.""" From ded9e81dec3cfd9bba59591a6e7e34bc5ab8bc9f Mon Sep 17 00:00:00 2001 From: Spencer Comin Date: Mon, 27 May 2024 22:49:43 -0400 Subject: [PATCH 2/4] [frontend] Add tests --- frontend/test/pytest/test_autograph.py | 245 +++++++++++++++++++++++++ 1 file changed, 245 insertions(+) diff --git a/frontend/test/pytest/test_autograph.py b/frontend/test/pytest/test_autograph.py index 041d115bc5..22c730ebb4 100644 --- a/frontend/test/pytest/test_autograph.py +++ b/frontend/test/pytest/test_autograph.py @@ -1883,5 +1883,250 @@ def zero_last_element_python_array(x): ) +class TestJaxIndexOperatorUpdate: + """Test Jax index operator update""" + + def test_single_static_index_operator_update_one_item(self): + """Test single index operator update for Jax arrays for one array item.""" + + @qjit(autograph=True) + def double_first_element_single_operator_assignment_syntax(x): + """Double the first element of x using single index assignment""" + + x[0] *= 2 + return x + + @qjit(autograph=True) + def double_first_element_at_multiply_syntax(x): + """Double the first element of x using at and multiply""" + + x = x.at[0].multiply(2) + return x + + result_assignment_syntax = double_first_element_single_operator_assignment_syntax( + jnp.array([5, 3, 4]) + ) + + assert jnp.allclose(result_assignment_syntax, jnp.array([10, 3, 4])) + assert jnp.allclose( + result_assignment_syntax, + double_first_element_at_multiply_syntax(jnp.array([5, 3, 4])), + ) + + def test_single_index_operator_update_one_item(self): + """Test single index operator update for Jax arrays for one array item.""" + + @qjit(autograph=True) + def double_last_element_single_operator_assignment_syntax(x): + """Double the last element of x using single index assignment""" + + last_element = x.shape[0] - 1 + x[last_element] *= 2 + return x + + @qjit(autograph=True) + def double_last_element_at_multiply_syntax(x): + """Double the last element of x using at and multiply""" + + last_element = x.shape[0] - 1 + x = x.at[last_element].multiply(2) + return x + + result_assignment_syntax = double_last_element_single_operator_assignment_syntax( + jnp.array([5, 3, 4]) + ) + + assert jnp.allclose(result_assignment_syntax, jnp.array([5, 3, 8])) + assert jnp.allclose( + result_assignment_syntax, + double_last_element_at_multiply_syntax(jnp.array([5, 3, 4])), + ) + + def test_single_index_mult_update_all_items(self): + """Test single index mult update for Jax arrays for all array items.""" + + @qjit(autograph=True) + def double_all_operator_update_syntax(x): + """Create a new array that is equal to 2 * x using single index mult update""" + + first_dim = x.shape[0] + + for i in range(first_dim): + x[i] *= 2 + + return x + + @qjit(autograph=True) + def double_all_at_multiply_syntax(x): + """Create a new array that is equal to 2 * x using at and multiply""" + + first_dim = x.shape[0] + result = jnp.copy(x) + + for i in range(first_dim): + result = result.at[i].multiply(2) + + return result + + result_assignment_syntax = double_all_operator_update_syntax(jnp.array([5, 3, 4])) + + assert jnp.allclose(result_assignment_syntax, jnp.array([10, 6, 8])) + assert jnp.allclose( + result_assignment_syntax, + double_all_at_multiply_syntax(jnp.array([5, 3, 4])), + ) + + def test_single_index_add_update_all_items(self): + """Test single index add update for Jax arrays for all array items.""" + + @qjit(autograph=True) + def inc_all_operator_update_syntax(x): + """Create a new array that is equal to x + 1 using single index add update""" + + first_dim = x.shape[0] + + for i in range(first_dim): + x[i] += 1 + + return x + + @qjit(autograph=True) + def inc_all_at_multiply_syntax(x): + """Create a new array that is equal to x + 1 using at and multiply""" + + first_dim = x.shape[0] + result = jnp.copy(x) + + for i in range(first_dim): + result = result.at[i].add(1) + + return result + + result_assignment_syntax = inc_all_operator_update_syntax(jnp.array([5, 3, 4])) + + assert jnp.allclose(result_assignment_syntax, jnp.array([6, 4, 5])) + assert jnp.allclose( + result_assignment_syntax, + inc_all_at_multiply_syntax(jnp.array([5, 3, 4])), + ) + + def test_single_index_sub_update_all_items(self): + """Test single index sub update for Jax arrays for all array items.""" + + @qjit(autograph=True) + def dec_all_operator_update_syntax(x): + """Create a new array that is equal to x - 1 using single index sub update""" + + first_dim = x.shape[0] + + for i in range(first_dim): + x[i] -= 1 + + return x + + @qjit(autograph=True) + def dec_all_at_multiply_syntax(x): + """Create a new array that is equal to x - 1 using at and add""" + + first_dim = x.shape[0] + result = jnp.copy(x) + + for i in range(first_dim): + result = result.at[i].add(-1) + + return result + + result_assignment_syntax = dec_all_operator_update_syntax(jnp.array([5, 3, 4])) + + assert jnp.allclose(result_assignment_syntax, jnp.array([4, 2, 3])) + assert jnp.allclose( + result_assignment_syntax, + dec_all_at_multiply_syntax(jnp.array([5, 3, 4])), + ) + + def test_single_index_div_update_all_items(self): + """Test single index div update for Jax arrays for all array items.""" + + @qjit(autograph=True) + def half_all_operator_update_syntax(x): + """Create a new array that is equal to x / 2 using single index div update""" + + first_dim = x.shape[0] + + for i in range(first_dim): + x[i] /= 2 + + return x + + @qjit(autograph=True) + def half_all_at_multiply_syntax(x): + """Create a new array that is equal to x / 2 using at and divide""" + + first_dim = x.shape[0] + result = jnp.copy(x) + + for i in range(first_dim): + result = result.at[i].divide(2) + + return result + + result_assignment_syntax = half_all_operator_update_syntax(jnp.array([5, 3, 4])) + + assert jnp.allclose(result_assignment_syntax, jnp.array([2.5, 1.5, 2])) + assert jnp.allclose( + result_assignment_syntax, + half_all_at_multiply_syntax(jnp.array([5, 3, 4])), + ) + + def test_single_index_pow_update_all_items(self): + """Test single index pow update for Jax arrays for all array items.""" + + @qjit(autograph=True) + def square_all_operator_update_syntax(x): + """Create a new array that is equal to x ** 2 using single index sub update""" + + first_dim = x.shape[0] + + for i in range(first_dim): + x[i] **= 2 + + return x + + @qjit(autograph=True) + def square_all_at_multiply_syntax(x): + """Create a new array that is equal to x ** 2 using at and pow""" + + first_dim = x.shape[0] + result = jnp.copy(x) + + for i in range(first_dim): + result = result.at[i].power(2) + + return result + + result_assignment_syntax = square_all_operator_update_syntax(jnp.array([5, 3, 4])) + + assert jnp.allclose(result_assignment_syntax, jnp.array([25, 9, 16])) + assert jnp.allclose( + result_assignment_syntax, + square_all_at_multiply_syntax(jnp.array([5, 3, 4])), + ) + + def test_single_index_operator_update_python_array(self): + """Test single index operator update for Non-Jax arrays for one array item.""" + + @qjit(autograph=True) + def double_last_element_python_array(x): + """Double the last element of a python array""" + + last_element = len(x) - 1 + x[last_element] *= 2 + return x + + assert jnp.allclose( + jnp.array(double_last_element_python_array([5, 3, 4])), jnp.array([5, 3, 8]) + ) + + if __name__ == "__main__": pytest.main(["-x", __file__]) From 6add6c5cf7c239ddc55969b38a75358f2aa85414 Mon Sep 17 00:00:00 2001 From: Spencer Comin Date: Mon, 27 May 2024 22:50:57 -0400 Subject: [PATCH 3/4] [NFC] Update documentation --- doc/changelog.md | 30 +++++++++++++++++++++++++++--- doc/dev/autograph.rst | 34 ++++++++++++++++++++++++++++++++-- frontend/catalyst/jit.py | 37 ++++++++++++++++++++++++++++++++++++- 3 files changed, 95 insertions(+), 6 deletions(-) diff --git a/doc/changelog.md b/doc/changelog.md index aa16d39613..bb5ccc3f43 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -71,7 +71,7 @@ [(#725)](https://github.com/PennyLaneAI/catalyst/pull/725) Although library code is not meant to be targeted by Autograph conversion, - it sometimes make sense to enable it for specific submodules that might + it sometimes make sense to enable it for specific submodules that might benefit from such conversion: ```py @@ -81,6 +81,30 @@ ``` +* Support for usage of single index JAX array operator update + inside Autograph annotated functions. + [(#769)](https://github.com/PennyLaneAI/catalyst/pull/769) + + Using operator assignment syntax in favor of at...operation expressions is now possible for the following operations: + * `x[i] += y` in favor of `x.at[i].add(y)` + * `x[i] -= y` in favor of `x.at[i].add(-y)` + * `x[i] *= y` in favor of `x.at[i].multiply(y)` + * `x[i] /= y` in favor of `x.at[i].divide(y)` + * `x[i] **= y` in favor of `x.at[i].power(y)` + + ```py + @qjit(autograph=True) + def f(x): + first_dim = x.shape[0] + result = jnp.copy(x) + + for i in range(first_dim): + result[i] *= 2 # This is now supported + + return result + + ``` +

Improvements

* Catalyst now has support for `qml.sample(m)` where `m` is the result of a mid-circuit @@ -103,7 +127,7 @@ [(#751)](https://github.com/PennyLaneAI/catalyst/pull/751) * Refactored `vmap`,`qjit`, `mitigate_with_zne` and gradient decorators in order to follow - a unified pattern that uses a callable class implementing the decorator's logic. This + a unified pattern that uses a callable class implementing the decorator's logic. This prevents having to excessively define functions in a nested fashion. [(#758)](https://github.com/PennyLaneAI/catalyst/pull/758) [(#761)](https://github.com/PennyLaneAI/catalyst/pull/761) @@ -127,7 +151,7 @@ * Correctly linking openblas routines necessary for `jax.scipy.linalg.expm`. In this bug fix, four openblas routines were newly linked and are now discoverable by `stablehlo.custom_call@`. They are `blas_dtrsm`, `blas_ztrsm`, `lapack_dgetrf`, `lapack_zgetrf`. - [(#752)](https://github.com/PennyLaneAI/catalyst/pull/752) + [(#752)](https://github.com/PennyLaneAI/catalyst/pull/752)

Internal changes

diff --git a/doc/dev/autograph.rst b/doc/dev/autograph.rst index 47b43f3a32..c0ac52e75f 100644 --- a/doc/dev/autograph.rst +++ b/doc/dev/autograph.rst @@ -921,8 +921,8 @@ Notice that ``autograph=True`` must be set in order to process the ``autograph_include`` list. Otherwise an error will be reported. -In-place JAX array assignments ------------------------------- +In-place JAX array updates +-------------------------- To update array values when using JAX, the `JAX syntax for array assignment `__ @@ -952,3 +952,33 @@ standard Python array assignment syntax: ... return result Under the hood, Catalyst converts anything coming in the latter notation into the former one. + +Similarly, to update array values with an operation when using JAX, the JAX syntax for array +update (which uses the array `at` and the `add`, `multiply`, etc. methods) must be used: + +>>> @qjit(autograph=True) +... def f(x): +... first_dim = x.shape[0] +... result = jnp.copy(x) +... +... for i in range(first_dim): +... result = result.at[i].multiply(2) +... +... return result + +Again, if updating a single index of the array, Autograph supports conversion of +standard Python array operator assignment syntax for the equivalent in-place expressions +listed in the `JAX documentation for jax.numpy.ndarray.at +`__: + +>>> @qjit(autograph=True) +... def f(x): +... first_dim = x.shape[0] +... result = jnp.copy(x) +... +... for i in range(first_dim): +... result[i] *= 2 +... +... return result + +Under the hood, Catalyst converts anything coming in the latter notation into the former one. diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 9564c915aa..b4b7ee0c92 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -244,7 +244,7 @@ def g(x: int): .. details:: - :title: In-place JAX array assignments with Autograph + :title: In-place JAX array updates with Autograph To update array values when using JAX, the JAX syntax for array assignment (which uses the array ``at`` and ``set`` methods) must be used: @@ -279,6 +279,41 @@ def f(x): Under the hood, Catalyst converts anything coming in the latter notation into the former one. + Similarly, to update array values with an operation when using JAX, the JAX syntax for array + update (which uses the array ``at`` and the ``add``, ``sub``, etc. methods) must be used: + + .. code-block:: python + + @qjit(autograph=True) + def f(x): + first_dim = x.shape[0] + result = jnp.copy(x) + + for i in range(first_dim): + result = result.at[i].multiply(2) + + return result + + Again, if updating a single index of the array, Autograph supports conversion of + standard Python array operator assignment syntax for the equivalent in-place expressions + listed in the JAX documentation for ``jax.numpy.ndarray.at``: + + .. code-block:: python + + @qjit(autograph=True) + def f(x): + first_dim = x.shape[0] + result = jnp.copy(x) + + for i in range(first_dim): + result[i] *= 2 + + return result + + Under the hood, Catalyst converts anything coming in the latter notation into the + former one. + + .. details:: :title: Static arguments From 7981c3c6a5198e6c3127300a54f729725209ebf2 Mon Sep 17 00:00:00 2001 From: Spencer Comin Date: Mon, 12 Aug 2024 16:51:45 -0600 Subject: [PATCH 4/4] Give operator_update.py copyright to Xanadu --- frontend/catalyst/autograph/operator_update.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/autograph/operator_update.py b/frontend/catalyst/autograph/operator_update.py index 114b60a149..4b31077f23 100644 --- a/frontend/catalyst/autograph/operator_update.py +++ b/frontend/catalyst/autograph/operator_update.py @@ -1,4 +1,4 @@ -# Copyright 2024 Spencer Comin +# Copyright 2024 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.