Skip to content

Commit

Permalink
Create a frontend UI for users to specify quantum compilation pipelin…
Browse files Browse the repository at this point in the history
…es (#1131)

**Context:** 
We wish to create a frontend `pipeline` function for users to specify
what circuit transformation passes to run in the dictionary form:
https://app.shortcut.com/xanaduai/epic/67492/p1-a-ui-is-available-for-users-to-specify-quantum-compilation-pipelines-from-python


**Description of the Change:**
A new API decorator function `catalyst.passes.pipeline` can be applied
to qnodes. The decorator takes in a dictionary specifying what circuit
transformation passes to run on the qnode with what options, and adds
the passes to the `transform_named_sequence`.

In addition, `qjit` now takes in a kwarg `circuit_transform_pipeline`
which expects the same pipeline dictionary. This will apply the pipeline
to all qnodes in the qjit.

To test the pipeline, we added the boilerplate for a merge_rotation pass
with a pass option. The pass is currently empty and does nothing. The
merge_rotation boilerplate pass is not exposed as a user-facing API, and
only exists for the purpose of testing the pipeline.


**Benefits:**
User can specify pass pipelines.


**Possible Drawbacks:**
There are two items of improvements possible:
1. The target qnode of the pass is recorded by name. This is not
optimal. The quantum scope work will likely put each qnode into a module
instead of a `func.func ... attributes {qnode}` in mlir. When that is in place,
the qnode's module can have a proper attribute (as opposed to
discardable) that records its transform schedule, i.e.
```
module_with_transform @name_of_module {
  // transform schedule
} {
  // contents of the module
}
```
This eliminates the need for matching target functions by name.

2. The number of `qjit` kwargs is maybe too many. As of now we implement
the syntax specified by the epic, but there could be an alternate
design.


[sc-67520]
  • Loading branch information
paul0403 authored Sep 19, 2024
1 parent b3810be commit 5cf6bd4
Show file tree
Hide file tree
Showing 12 changed files with 664 additions and 32 deletions.
61 changes: 61 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,67 @@
Array([[1], [0], [1], [1], [0], [1],[0]], dtype=int64))
```

* A new function `catalyst.passes.pipeline` allows the quantum circuit transformation pass pipeline for QNodes within a qjit-compiled workflow to be configured.
[(#1131)](https://github.com/PennyLaneAI/catalyst/pull/1131)

```python
my_passes = {
"cancel_inverses": {},
"my_circuit_transformation_pass": {"my-option" : "my-option-value"},
}
dev = qml.device("lightning.qubit", wires=2)

@pipeline(my_passes)
@qnode(dev)
def circuit(x):
qml.RX(x, wires=0)
return qml.expval(qml.PauliZ(0))

@qjit
def fn(x):
return jnp.sin(circuit(x ** 2))
```

`pipeline` can also be used to specify different pass pipelines for different parts of the
same qjit-compiled workflow:

```python
my_pipeline = {
"cancel_inverses": {},
"my_circuit_transformation_pass": {"my-option" : "my-option-value"},
}

my_other_pipeline = {"cancel_inverses": {}}

@qjit
def fn(x):
circuit_pipeline = pipeline(my_pipeline)(circuit)
circuit_other = pipeline(my_other_pipeline)(circuit)
return jnp.abs(circuit_pipeline(x) - circuit_other(x))
```

For a list of available passes, please see the [catalyst.passes module documentation](https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes).

The pass pipeline order and options can be configured *globally* for a
qjit-compiled function, by using the `circuit_transform_pipeline` argument of the :func:`~.qjit` decorator.

```python
my_passes = {
"cancel_inverses": {},
"my_circuit_transformation_pass": {"my-option" : "my-option-value"},
}

@qjit(circuit_transform_pipeline=my_passes)
def fn(x):
return jnp.sin(circuit(x ** 2))
```

Global and local (via `@pipeline`) configurations can coexist, however local pass pipelines
will always take precedence over global pass pipelines.

Available MLIR passes are now documented and available within the
[catalyst.passes module documentation](https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes).

<h3>Improvements</h3>

* Bufferization of `gradient.ForwardOp` and `gradient.ReverseOp` now requires 3 steps: `gradient-preprocessing`,
Expand Down
5 changes: 5 additions & 0 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ class CompileOptions:
experimental_capture (bool): If set to ``True``,
use PennyLane's experimental program capture capabilities
to capture the function for compilation.
circuit_transform_pipeline (Optional[dict[str, dict[str, str]]]):
A dictionary that specifies the quantum circuit transformation pass pipeline order,
and optionally arguments for each pass in the pipeline.
Default is None.
"""

verbose: Optional[bool] = False
Expand All @@ -94,6 +98,7 @@ class CompileOptions:
disable_assertions: Optional[bool] = False
seed: Optional[int] = None
experimental_capture: Optional[bool] = False
circuit_transform_pipeline: Optional[dict[str, dict[str, str]]] = None

def __post_init__(self):
# Check that async runs must not be seeded
Expand Down
20 changes: 18 additions & 2 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from catalyst.from_plxpr import trace_from_pennylane
from catalyst.jax_tracer import lower_jaxpr_to_mlir, trace_to_jaxpr
from catalyst.logging import debug_logger, debug_logger_init
from catalyst.passes import _inject_transform_named_sequence
from catalyst.passes import PipelineNameUniquer, _inject_transform_named_sequence
from catalyst.qfunc import QFunc
from catalyst.tracing.contexts import EvaluationContext
from catalyst.tracing.type_signatures import (
Expand Down Expand Up @@ -85,6 +85,7 @@ def qjit(
disable_assertions=False,
seed=None,
experimental_capture=False,
circuit_transform_pipeline=None,
): # pylint: disable=too-many-arguments,unused-argument
"""A just-in-time decorator for PennyLane and JAX programs using Catalyst.
Expand Down Expand Up @@ -141,6 +142,15 @@ def qjit(
experimental_capture (bool): If set to ``True``, the qjit decorator
will use PennyLane's experimental program capture capabilities
to capture the decorated function for compilation.
circuit_transform_pipeline (Optional[dict[str, dict[str, str]]]):
A dictionary that specifies the quantum circuit transformation pass pipeline order,
and optionally arguments for each pass in the pipeline. Keys of this dictionary
should correspond to names of passes found in the `catalyst.passes <https://docs.
pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes>`_
module, values should either be empty dictionaries (for default pass options) or
dictionaries of valid keyword arguments and values for the specific pass.
The order of keys in this dictionary will determine the pass pipeline.
If not specified, the default pass pipeline will be applied.
Returns:
QJIT object.
Expand Down Expand Up @@ -609,7 +619,12 @@ def closure(qnode, *args, **kwargs):
params = {}
params["static_argnums"] = kwargs.pop("static_argnums", static_argnums)
params["_out_tree_expected"] = []
return QFunc.__call__(qnode, *args, **dict(params, **kwargs))
return QFunc.__call__(
qnode,
pass_pipeline=self.compile_options.circuit_transform_pipeline,
*args,
**dict(params, **kwargs),
)

with Patcher(
(qml.QNode, "__call__", closure),
Expand All @@ -623,6 +638,7 @@ def closure(qnode, *args, **kwargs):
kwargs,
)

PipelineNameUniquer.reset()
return jaxpr, out_type, treedef, dynamic_sig

@instrument(size_from=0, has_finegrained=True)
Expand Down
188 changes: 173 additions & 15 deletions frontend/catalyst/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,155 @@
"""

import copy
import functools
from typing import Optional

import pennylane as qml

from catalyst.jax_primitives import apply_registered_pass_p, transform_named_sequence_p
from catalyst.tracing.contexts import EvaluationContext


## API ##
# pylint: disable=line-too-long
def pipeline(fn=None, *, pass_pipeline: Optional[dict[str, dict[str, str]]] = None):
"""Configures the Catalyst MLIR pass pipeline for quantum circuit transformations for a QNode within a qjit-compiled program.
Args:
fn (QNode): The QNode to run the pass pipeline on.
pass_pipeline (dict[str, dict[str, str]]): A dictionary that specifies the pass pipeline order, and optionally
arguments for each pass in the pipeline. Keys of this dictionary should correspond to names of passes
found in the `catalyst.passes <https://docs.pennylane.ai/projects/catalyst/en/stable/code
/__init__.html#module-catalyst.passes>`_ module, values should either be empty dictionaries
(for default pass options) or dictionaries of valid keyword arguments and values for the specific pass.
The order of keys in this dictionary will determine the pass pipeline.
If not specified, the default pass pipeline will be applied.
Returns:
~.QNode:
For a list of available passes, please see the :doc:`catalyst.passes module <code/passes>`.
The default pass pipeline when used with Catalyst is currently empty.
**Example**
``pipeline`` can be used to configure the pass pipeline order and options
of a QNode within a qjit-compiled function.
Configuration options are passed to specific passes via dictionaries:
.. code-block:: python
my_pass_pipeline = {
"cancel_inverses": {},
"my_circuit_transformation_pass": {"my-option" : "my-option-value"},
}
@pipeline(my_pass_pipeline)
@qnode(dev)
def circuit(x):
qml.RX(x, wires=0)
return qml.expval(qml.PauliZ(0))
@qjit
def fn(x):
return jnp.sin(circuit(x ** 2))
``pipeline`` can also be used to specify different pass pipelines for different parts of the
same qjit-compiled workflow:
.. code-block:: python
my_pipeline = {
"cancel_inverses": {},
"my_circuit_transformation_pass": {"my-option" : "my-option-value"},
}
my_other_pipeline = {"cancel_inverses": {}}
@qjit
def fn(x):
circuit_pipeline = pipeline(my_pipeline)(circuit)
circuit_other = pipeline(my_other_pipeline)(circuit)
return jnp.abs(circuit_pipeline(x) - circuit_other(x))
.. note::
As of Python 3.7, the CPython dictionary implementation orders dictionaries based on
insertion order. However, for an API gaurantee of dictionary order, ``collections.OrderedDict``
may also be used.
Note that the pass pipeline order and options can be configured *globally* for a
qjit-compiled function, by using the ``circuit_transform_pipeline`` argument of
the :func:`~.qjit` decorator.
.. code-block:: python
my_pass_pipeline = {
"cancel_inverses": {},
"my_circuit_transformation_pass": {"my-option" : "my-option-value"},
}
@qjit(circuit_transform_pipeline=my_pass_pipeline)
def fn(x):
return jnp.sin(circuit(x ** 2))
Global and local (via ``@pipeline``) configurations can coexist, however local pass pipelines
will always take precedence over global pass pipelines.
"""

kwargs = copy.copy(locals())
kwargs.pop("fn")

if fn is None:
return functools.partial(pipeline, **kwargs)

if not isinstance(fn, qml.QNode):
raise TypeError(f"A QNode is expected, got the classical function {fn}")

if pass_pipeline is None:
# TODO: design a default peephole pipeline
return fn

fn_original_name = fn.__name__
wrapped_qnode_function = fn.func
fn_clone = copy.copy(fn)
uniquer = str(_rename_to_unique())
fn_clone.__name__ = fn_original_name + "_transformed" + uniquer

pass_names = _API_name_to_pass_name()

def wrapper(*args, **kwrags):
# TODO: we should not match pass targets by function name.
# The quantum scope work will likely put each qnode into a module
# instead of a `func.func ... attributes {qnode}`.
# When that is in place, the qnode's module can have a proper attribute
# (as opposed to discardable) that records its transform schedule, i.e.
# module_with_transform @name_of_module {
# // transform schedule
# } {
# // contents of the module
# }
# This eliminates the need for matching target functions by name.

if EvaluationContext.is_tracing():
for API_name, pass_options in pass_pipeline.items():
opt = ""
for option, option_value in pass_options.items():
opt += " " + str(option) + "=" + str(option_value)
apply_registered_pass_p.bind(
pass_name=pass_names[API_name],
options=f"func-name={fn_original_name}" + "_transformed" + uniquer + opt,
)
return wrapped_qnode_function(*args, **kwrags)

fn_clone.func = wrapper
fn_clone._peephole_transformed = True # pylint: disable=protected-access

return fn_clone


def cancel_inverses(fn=None):
"""
Specify that the ``-removed-chained-self-inverse`` MLIR compiler pass
Expand Down Expand Up @@ -150,33 +291,50 @@ def circuit(x: float):
if not isinstance(fn, qml.QNode):
raise TypeError(f"A QNode is expected, got the classical function {fn}")

wrapped_qnode_function = fn.func
funcname = fn.__name__
wrapped_qnode_function = fn.func
uniquer = str(_rename_to_unique())

def wrapper(*args, **kwrags):
# TODO: hint the compiler which qnodes to run the pass on via an func attribute,
# instead of the qnode name. That way the clone can have this attribute and
# the original can just not have it.
# We are not doing this right now and passing by name because this would
# be a discardable attribute (i.e. a user/developer wouldn't know that this
# attribute exists just by looking at qnode's documentation)
# But when we add the full peephole pipeline in the future, the attribute
# could get properly documented.

apply_registered_pass_p.bind(
pass_name="remove-chained-self-inverse",
options=f"func-name={funcname}" + "_cancel_inverses",
)
if EvaluationContext.is_tracing():
apply_registered_pass_p.bind(
pass_name="remove-chained-self-inverse",
options=f"func-name={funcname}" + "_cancel_inverses" + uniquer,
)
return wrapped_qnode_function(*args, **kwrags)

fn_clone = copy.copy(fn)
fn_clone.func = wrapper
fn_clone.__name__ = funcname + "_cancel_inverses"
fn_clone.__name__ = funcname + "_cancel_inverses" + uniquer

return fn_clone


## IMPL and helpers ##
# pylint: disable=missing-function-docstring
class _PipelineNameUniquer:
def __init__(self, i):
self.i = i

def get(self):
self.i += 1
return self.i

def reset(self):
self.i = -1


PipelineNameUniquer = _PipelineNameUniquer(-1)


def _rename_to_unique():
return PipelineNameUniquer.get()


def _API_name_to_pass_name():
return {"cancel_inverses": "remove-chained-self-inverse", "merge_rotations": "merge-rotation"}


def _inject_transform_named_sequence():
"""
Inject a transform_named_sequence jax primitive.
Expand Down
9 changes: 9 additions & 0 deletions frontend/catalyst/qfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from catalyst.jax_primitives import func_p
from catalyst.jax_tracer import trace_quantum_function
from catalyst.logging import debug_logger
from catalyst.passes import pipeline
from catalyst.tracing.type_signatures import filter_static_args

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -92,10 +93,18 @@ def __new__(cls):
raise NotImplementedError() # pragma: no-cover

# pylint: disable=no-member
# pylint: disable=self-cls-assignment
@debug_logger
def __call__(self, *args, **kwargs):
assert isinstance(self, qml.QNode)

# Update the qnode with peephole pipeline
if "pass_pipeline" in kwargs.keys():
pass_pipeline = kwargs["pass_pipeline"]
if not hasattr(self, "_peephole_transformed"):
self = pipeline(pass_pipeline=pass_pipeline)(self)
kwargs.pop("pass_pipeline")

# Mid-circuit measurement configuration/execution
dynamic_one_shot_called = getattr(self, "_dynamic_one_shot_called", False)
if not dynamic_one_shot_called:
Expand Down
Loading

0 comments on commit 5cf6bd4

Please sign in to comment.