Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a frontend UI for users to specify quantum compilation pipelines #1131

Merged
merged 45 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
69c7933
pulling ZNE mitigation (`-lower-mitigation`) into the transform_named…
paul0403 Sep 11, 2024
6974981
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 11, 2024
4761364
update passes.mitigate_with_zne test to include new `scale_factors`
paul0403 Sep 11, 2024
b5ec1b5
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 11, 2024
13899c1
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 12, 2024
2a38c02
init the local pipeline decorator
paul0403 Sep 12, 2024
986cc94
lit tests
paul0403 Sep 13, 2024
d7b98ae
add global peephole pipeline option in qjit
paul0403 Sep 13, 2024
00b28d7
format
paul0403 Sep 13, 2024
4efb372
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 16, 2024
300c61c
reverting zne changes; this PR will leave zne untouched
paul0403 Sep 16, 2024
30b9f41
create merge rotation pass boilerplate
paul0403 Sep 16, 2024
99a3f13
reverting zne to main
paul0403 Sep 16, 2024
c23f52b
rewriting tests to exclude zne
paul0403 Sep 16, 2024
6f5dd59
put back lower-mitigation in default pipeline
paul0403 Sep 16, 2024
7dc40cd
codefactor
paul0403 Sep 16, 2024
1be95b5
local pipelines override global pipelines
paul0403 Sep 16, 2024
50f2585
make sure cudajit (which will never have the pass_pipeline) does not …
paul0403 Sep 16, 2024
d0d29c2
format
paul0403 Sep 16, 2024
60769e3
codefactor
paul0403 Sep 17, 2024
7648f88
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 17, 2024
bc521c9
format
paul0403 Sep 17, 2024
8ed20c5
add support for pass options
paul0403 Sep 17, 2024
652d96a
format
paul0403 Sep 17, 2024
46f640f
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 17, 2024
8d24767
type hint pipeline
paul0403 Sep 17, 2024
d410afb
no documentation for helpers in passes.py
paul0403 Sep 17, 2024
db3bd48
format
paul0403 Sep 17, 2024
551545d
add pytest for pass option effect
paul0403 Sep 17, 2024
6b20fc5
documentation
paul0403 Sep 17, 2024
6618e97
changelog
paul0403 Sep 17, 2024
72e6fc0
codefactor line too long in documentation
paul0403 Sep 17, 2024
a3b5425
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 18, 2024
be02cad
typo
paul0403 Sep 18, 2024
9195851
add quantum scope TODO
paul0403 Sep 18, 2024
0850516
double ticks in documentation instead of single tick for code words
paul0403 Sep 18, 2024
2eb6f2f
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 18, 2024
d2e5abb
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 18, 2024
2ddf0af
rename variables in changelog
paul0403 Sep 18, 2024
8ae031f
change name uniquer to an import
paul0403 Sep 18, 2024
536d651
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 19, 2024
f0a5bb6
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 19, 2024
1511e17
remove "merge_rotations" from public documentation
paul0403 Sep 19, 2024
7c2142e
typo
paul0403 Sep 19, 2024
81b5006
add web link to `catalyst.passes` in documentation
paul0403 Sep 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,67 @@
Array([[1], [0], [1], [1], [0], [1],[0]], dtype=int64))
```

* A new function `catalyst.passes.pipeline` allows the quantum circuit transformation pass pipeline for QNodes within a qjit-compiled workflow to be configured.
[(#1131)](https://github.com/PennyLaneAI/catalyst/pull/1131)

```python
my_passes = {
"cancel_inverses": {},
"my_circuit_transformation_pass": {"my-option" : "my-option-value"},
}
dev = qml.device("lightning.qubit", wires=2)

@pipeline(my_passes)
@qnode(dev)
def circuit(x):
qml.RX(x, wires=0)
return qml.expval(qml.PauliZ(0))

@qjit
def fn(x):
return jnp.sin(circuit(x ** 2))
```

`pipeline` can also be used to specify different pass pipelines for different parts of the
same qjit-compiled workflow:

```python
my_pipeline = {
"cancel_inverses": {},
"my_circuit_transformation_pass": {"my-option" : "my-option-value"},
}

my_other_pipeline = {"cancel_inverses": {}}

@qjit
def fn(x):
circuit_pipeline = pipeline(my_pipeline)(circuit)
circuit_other = pipeline(my_other_pipeline)(circuit)
return jnp.abs(circuit_pipeline(x) - circuit_other(x))
```

For a list of available passes, please see the [catalyst.passes module documentation](https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes).

The pass pipeline order and options can be configured *globally* for a
qjit-compiled function, by using the `circuit_transform_pipeline` argument of the :func:`~.qjit` decorator.

```python
my_passes = {
"cancel_inverses": {},
"my_circuit_transformation_pass": {"my-option" : "my-option-value"},
}

@qjit(circuit_transform_pipeline=my_passes)
def fn(x):
return jnp.sin(circuit(x ** 2))
```

Global and local (via `@pipeline`) configurations can coexist, however local pass pipelines
will always take precedence over global pass pipelines.

Available MLIR passes are now documented and available within the
[catalyst.passes module documentation](https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes).

<h3>Improvements</h3>

* Bufferization of `gradient.ForwardOp` and `gradient.ReverseOp` now requires 3 steps: `gradient-preprocessing`,
Expand Down
5 changes: 5 additions & 0 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ class CompileOptions:
experimental_capture (bool): If set to ``True``,
use PennyLane's experimental program capture capabilities
to capture the function for compilation.
circuit_transform_pipeline (Optional[dict[str, dict[str, str]]]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No action here, but @josh146 @isaacdevlugt this will soon beat the number of qnode's arguments.

A dictionary that specifies the quantum circuit transformation pass pipeline order,
and optionally arguments for each pass in the pipeline.
Default is None.
"""

verbose: Optional[bool] = False
Expand All @@ -94,6 +98,7 @@ class CompileOptions:
disable_assertions: Optional[bool] = False
seed: Optional[int] = None
experimental_capture: Optional[bool] = False
circuit_transform_pipeline: Optional[dict[str, dict[str, str]]] = None

def __post_init__(self):
# Check that async runs must not be seeded
Expand Down
20 changes: 18 additions & 2 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from catalyst.from_plxpr import trace_from_pennylane
from catalyst.jax_tracer import lower_jaxpr_to_mlir, trace_to_jaxpr
from catalyst.logging import debug_logger, debug_logger_init
from catalyst.passes import _inject_transform_named_sequence
from catalyst.passes import PipelineNameUniquer, _inject_transform_named_sequence
from catalyst.qfunc import QFunc
from catalyst.tracing.contexts import EvaluationContext
from catalyst.tracing.type_signatures import (
Expand Down Expand Up @@ -85,6 +85,7 @@ def qjit(
disable_assertions=False,
seed=None,
experimental_capture=False,
circuit_transform_pipeline=None,
): # pylint: disable=too-many-arguments,unused-argument
"""A just-in-time decorator for PennyLane and JAX programs using Catalyst.
Expand Down Expand Up @@ -141,6 +142,15 @@ def qjit(
experimental_capture (bool): If set to ``True``, the qjit decorator
will use PennyLane's experimental program capture capabilities
to capture the decorated function for compilation.
circuit_transform_pipeline (Optional[dict[str, dict[str, str]]]):
A dictionary that specifies the quantum circuit transformation pass pipeline order,
and optionally arguments for each pass in the pipeline. Keys of this dictionary
should correspond to names of passes found in the `catalyst.passes <https://docs.
pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes>`_
module, values should either be empty dictionaries (for default pass options) or
dictionaries of valid keyword arguments and values for the specific pass.
The order of keys in this dictionary will determine the pass pipeline.
If not specified, the default pass pipeline will be applied.
Returns:
QJIT object.
Expand Down Expand Up @@ -609,7 +619,12 @@ def closure(qnode, *args, **kwargs):
params = {}
params["static_argnums"] = kwargs.pop("static_argnums", static_argnums)
params["_out_tree_expected"] = []
return QFunc.__call__(qnode, *args, **dict(params, **kwargs))
return QFunc.__call__(
qnode,
pass_pipeline=self.compile_options.circuit_transform_pipeline,
*args,
**dict(params, **kwargs),
)

with Patcher(
(qml.QNode, "__call__", closure),
Expand All @@ -623,6 +638,7 @@ def closure(qnode, *args, **kwargs):
kwargs,
)

PipelineNameUniquer.reset()
return jaxpr, out_type, treedef, dynamic_sig

@instrument(size_from=0, has_finegrained=True)
Expand Down
188 changes: 173 additions & 15 deletions frontend/catalyst/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,155 @@
"""

import copy
import functools
from typing import Optional

import pennylane as qml

from catalyst.jax_primitives import apply_registered_pass_p, transform_named_sequence_p
from catalyst.tracing.contexts import EvaluationContext


## API ##
# pylint: disable=line-too-long
def pipeline(fn=None, *, pass_pipeline: Optional[dict[str, dict[str, str]]] = None):
"""Configures the Catalyst MLIR pass pipeline for quantum circuit transformations for a QNode within a qjit-compiled program.
Args:
fn (QNode): The QNode to run the pass pipeline on.
pass_pipeline (dict[str, dict[str, str]]): A dictionary that specifies the pass pipeline order, and optionally
arguments for each pass in the pipeline. Keys of this dictionary should correspond to names of passes
found in the `catalyst.passes <https://docs.pennylane.ai/projects/catalyst/en/stable/code
/__init__.html#module-catalyst.passes>`_ module, values should either be empty dictionaries
(for default pass options) or dictionaries of valid keyword arguments and values for the specific pass.
The order of keys in this dictionary will determine the pass pipeline.
If not specified, the default pass pipeline will be applied.
Returns:
~.QNode:
For a list of available passes, please see the :doc:`catalyst.passes module <code/passes>`.
The default pass pipeline when used with Catalyst is currently empty.
**Example**
``pipeline`` can be used to configure the pass pipeline order and options
of a QNode within a qjit-compiled function.
Configuration options are passed to specific passes via dictionaries:
.. code-block:: python
my_pass_pipeline = {
"cancel_inverses": {},
"my_circuit_transformation_pass": {"my-option" : "my-option-value"},
}
@pipeline(my_pass_pipeline)
@qnode(dev)
def circuit(x):
qml.RX(x, wires=0)
return qml.expval(qml.PauliZ(0))
@qjit
def fn(x):
return jnp.sin(circuit(x ** 2))
``pipeline`` can also be used to specify different pass pipelines for different parts of the
same qjit-compiled workflow:
.. code-block:: python
my_pipeline = {
"cancel_inverses": {},
"my_circuit_transformation_pass": {"my-option" : "my-option-value"},
}
my_other_pipeline = {"cancel_inverses": {}}
@qjit
def fn(x):
circuit_pipeline = pipeline(my_pipeline)(circuit)
circuit_other = pipeline(my_other_pipeline)(circuit)
return jnp.abs(circuit_pipeline(x) - circuit_other(x))
.. note::
As of Python 3.7, the CPython dictionary implementation orders dictionaries based on
insertion order. However, for an API gaurantee of dictionary order, ``collections.OrderedDict``
may also be used.
Note that the pass pipeline order and options can be configured *globally* for a
qjit-compiled function, by using the ``circuit_transform_pipeline`` argument of
the :func:`~.qjit` decorator.
.. code-block:: python
my_pass_pipeline = {
"cancel_inverses": {},
"my_circuit_transformation_pass": {"my-option" : "my-option-value"},
}
@qjit(circuit_transform_pipeline=my_pass_pipeline)
def fn(x):
return jnp.sin(circuit(x ** 2))
Global and local (via ``@pipeline``) configurations can coexist, however local pass pipelines
mehrdad2m marked this conversation as resolved.
Show resolved Hide resolved
will always take precedence over global pass pipelines.
"""

kwargs = copy.copy(locals())
kwargs.pop("fn")

if fn is None:
return functools.partial(pipeline, **kwargs)

if not isinstance(fn, qml.QNode):
raise TypeError(f"A QNode is expected, got the classical function {fn}")

if pass_pipeline is None:
# TODO: design a default peephole pipeline
return fn

fn_original_name = fn.__name__
wrapped_qnode_function = fn.func
fn_clone = copy.copy(fn)
uniquer = str(_rename_to_unique())
fn_clone.__name__ = fn_original_name + "_transformed" + uniquer

pass_names = _API_name_to_pass_name()

def wrapper(*args, **kwrags):
# TODO: we should not match pass targets by function name.
# The quantum scope work will likely put each qnode into a module
# instead of a `func.func ... attributes {qnode}`.
# When that is in place, the qnode's module can have a proper attribute
# (as opposed to discardable) that records its transform schedule, i.e.
# module_with_transform @name_of_module {
# // transform schedule
# } {
# // contents of the module
# }
# This eliminates the need for matching target functions by name.

if EvaluationContext.is_tracing():
for API_name, pass_options in pass_pipeline.items():
opt = ""
for option, option_value in pass_options.items():
opt += " " + str(option) + "=" + str(option_value)
apply_registered_pass_p.bind(
pass_name=pass_names[API_name],
options=f"func-name={fn_original_name}" + "_transformed" + uniquer + opt,
)
return wrapped_qnode_function(*args, **kwrags)

fn_clone.func = wrapper
fn_clone._peephole_transformed = True # pylint: disable=protected-access
tzunghanjuang marked this conversation as resolved.
Show resolved Hide resolved

return fn_clone


def cancel_inverses(fn=None):
"""
Specify that the ``-removed-chained-self-inverse`` MLIR compiler pass
Expand Down Expand Up @@ -150,33 +291,50 @@ def circuit(x: float):
if not isinstance(fn, qml.QNode):
raise TypeError(f"A QNode is expected, got the classical function {fn}")

wrapped_qnode_function = fn.func
funcname = fn.__name__
wrapped_qnode_function = fn.func
uniquer = str(_rename_to_unique())

def wrapper(*args, **kwrags):
# TODO: hint the compiler which qnodes to run the pass on via an func attribute,
# instead of the qnode name. That way the clone can have this attribute and
# the original can just not have it.
# We are not doing this right now and passing by name because this would
# be a discardable attribute (i.e. a user/developer wouldn't know that this
# attribute exists just by looking at qnode's documentation)
# But when we add the full peephole pipeline in the future, the attribute
# could get properly documented.

apply_registered_pass_p.bind(
pass_name="remove-chained-self-inverse",
options=f"func-name={funcname}" + "_cancel_inverses",
)
if EvaluationContext.is_tracing():
apply_registered_pass_p.bind(
pass_name="remove-chained-self-inverse",
options=f"func-name={funcname}" + "_cancel_inverses" + uniquer,
)
return wrapped_qnode_function(*args, **kwrags)

fn_clone = copy.copy(fn)
fn_clone.func = wrapper
fn_clone.__name__ = funcname + "_cancel_inverses"
fn_clone.__name__ = funcname + "_cancel_inverses" + uniquer

return fn_clone


## IMPL and helpers ##
# pylint: disable=missing-function-docstring
class _PipelineNameUniquer:
def __init__(self, i):
self.i = i

def get(self):
self.i += 1
return self.i

def reset(self):
self.i = -1


PipelineNameUniquer = _PipelineNameUniquer(-1)


def _rename_to_unique():
return PipelineNameUniquer.get()


def _API_name_to_pass_name():
return {"cancel_inverses": "remove-chained-self-inverse", "merge_rotations": "merge-rotation"}


def _inject_transform_named_sequence():
"""
Inject a transform_named_sequence jax primitive.
Expand Down
9 changes: 9 additions & 0 deletions frontend/catalyst/qfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from catalyst.jax_primitives import func_p
from catalyst.jax_tracer import trace_quantum_function
from catalyst.logging import debug_logger
from catalyst.passes import pipeline
from catalyst.tracing.type_signatures import filter_static_args

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -92,10 +93,18 @@ def __new__(cls):
raise NotImplementedError() # pragma: no-cover

# pylint: disable=no-member
# pylint: disable=self-cls-assignment
@debug_logger
def __call__(self, *args, **kwargs):
assert isinstance(self, qml.QNode)

# Update the qnode with peephole pipeline
if "pass_pipeline" in kwargs.keys():
pass_pipeline = kwargs["pass_pipeline"]
if not hasattr(self, "_peephole_transformed"):
self = pipeline(pass_pipeline=pass_pipeline)(self)
kwargs.pop("pass_pipeline")

# Mid-circuit measurement configuration/execution
dynamic_one_shot_called = getattr(self, "_dynamic_one_shot_called", False)
if not dynamic_one_shot_called:
Expand Down
Loading