From 5cf6bd46422653fe1dde798926bf556532516d2b Mon Sep 17 00:00:00 2001 From: paul0403 <79805239+paul0403@users.noreply.github.com> Date: Thu, 19 Sep 2024 13:48:25 -0400 Subject: [PATCH] Create a frontend UI for users to specify quantum compilation pipelines (#1131) **Context:** We wish to create a frontend `pipeline` function for users to specify what circuit transformation passes to run in the dictionary form: https://app.shortcut.com/xanaduai/epic/67492/p1-a-ui-is-available-for-users-to-specify-quantum-compilation-pipelines-from-python **Description of the Change:** A new API decorator function `catalyst.passes.pipeline` can be applied to qnodes. The decorator takes in a dictionary specifying what circuit transformation passes to run on the qnode with what options, and adds the passes to the `transform_named_sequence`. In addition, `qjit` now takes in a kwarg `circuit_transform_pipeline` which expects the same pipeline dictionary. This will apply the pipeline to all qnodes in the qjit. To test the pipeline, we added the boilerplate for a merge_rotation pass with a pass option. The pass is currently empty and does nothing. The merge_rotation boilerplate pass is not exposed as a user-facing API, and only exists for the purpose of testing the pipeline. **Benefits:** User can specify pass pipelines. **Possible Drawbacks:** There are two items of improvements possible: 1. The target qnode of the pass is recorded by name. This is not optimal. The quantum scope work will likely put each qnode into a module instead of a `func.func ... attributes {qnode}` in mlir. When that is in place, the qnode's module can have a proper attribute (as opposed to discardable) that records its transform schedule, i.e. ``` module_with_transform @name_of_module { // transform schedule } { // contents of the module } ``` This eliminates the need for matching target functions by name. 2. The number of `qjit` kwargs is maybe too many. As of now we implement the syntax specified by the epic, but there could be an alternate design. [sc-67520] --- doc/releases/changelog-dev.md | 61 ++++ frontend/catalyst/compiler.py | 5 + frontend/catalyst/jit.py | 20 +- frontend/catalyst/passes.py | 188 +++++++++++- frontend/catalyst/qfunc.py | 9 + .../test/lit/test_peephole_optimizations.py | 285 +++++++++++++++++- .../pytest/test_peephole_optimizations.py | 44 ++- mlir/include/Quantum/Transforms/Passes.h | 1 + mlir/include/Quantum/Transforms/Passes.td | 23 +- .../Catalyst/Transforms/RegisterAllPasses.cpp | 1 + mlir/lib/Quantum/Transforms/CMakeLists.txt | 1 + .../lib/Quantum/Transforms/merge_rotation.cpp | 58 ++++ 12 files changed, 664 insertions(+), 32 deletions(-) create mode 100644 mlir/lib/Quantum/Transforms/merge_rotation.cpp diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 2d4a8c329e..b5b904d0bb 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -53,6 +53,67 @@ Array([[1], [0], [1], [1], [0], [1],[0]], dtype=int64)) ``` +* A new function `catalyst.passes.pipeline` allows the quantum circuit transformation pass pipeline for QNodes within a qjit-compiled workflow to be configured. + [(#1131)](https://github.com/PennyLaneAI/catalyst/pull/1131) + + ```python + my_passes = { + "cancel_inverses": {}, + "my_circuit_transformation_pass": {"my-option" : "my-option-value"}, + } + dev = qml.device("lightning.qubit", wires=2) + + @pipeline(my_passes) + @qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return qml.expval(qml.PauliZ(0)) + + @qjit + def fn(x): + return jnp.sin(circuit(x ** 2)) + ``` + + `pipeline` can also be used to specify different pass pipelines for different parts of the + same qjit-compiled workflow: + + ```python + my_pipeline = { + "cancel_inverses": {}, + "my_circuit_transformation_pass": {"my-option" : "my-option-value"}, + } + + my_other_pipeline = {"cancel_inverses": {}} + + @qjit + def fn(x): + circuit_pipeline = pipeline(my_pipeline)(circuit) + circuit_other = pipeline(my_other_pipeline)(circuit) + return jnp.abs(circuit_pipeline(x) - circuit_other(x)) + ``` + + For a list of available passes, please see the [catalyst.passes module documentation](https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes). + + The pass pipeline order and options can be configured *globally* for a + qjit-compiled function, by using the `circuit_transform_pipeline` argument of the :func:`~.qjit` decorator. + + ```python + my_passes = { + "cancel_inverses": {}, + "my_circuit_transformation_pass": {"my-option" : "my-option-value"}, + } + + @qjit(circuit_transform_pipeline=my_passes) + def fn(x): + return jnp.sin(circuit(x ** 2)) + ``` + + Global and local (via `@pipeline`) configurations can coexist, however local pass pipelines + will always take precedence over global pass pipelines. + + Available MLIR passes are now documented and available within the + [catalyst.passes module documentation](https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes). +
`.
+
+ The default pass pipeline when used with Catalyst is currently empty.
+
+ **Example**
+
+ ``pipeline`` can be used to configure the pass pipeline order and options
+ of a QNode within a qjit-compiled function.
+
+ Configuration options are passed to specific passes via dictionaries:
+
+ .. code-block:: python
+
+ my_pass_pipeline = {
+ "cancel_inverses": {},
+ "my_circuit_transformation_pass": {"my-option" : "my-option-value"},
+ }
+
+ @pipeline(my_pass_pipeline)
+ @qnode(dev)
+ def circuit(x):
+ qml.RX(x, wires=0)
+ return qml.expval(qml.PauliZ(0))
+
+ @qjit
+ def fn(x):
+ return jnp.sin(circuit(x ** 2))
+
+ ``pipeline`` can also be used to specify different pass pipelines for different parts of the
+ same qjit-compiled workflow:
+
+ .. code-block:: python
+
+ my_pipeline = {
+ "cancel_inverses": {},
+ "my_circuit_transformation_pass": {"my-option" : "my-option-value"},
+ }
+
+ my_other_pipeline = {"cancel_inverses": {}}
+
+ @qjit
+ def fn(x):
+ circuit_pipeline = pipeline(my_pipeline)(circuit)
+ circuit_other = pipeline(my_other_pipeline)(circuit)
+ return jnp.abs(circuit_pipeline(x) - circuit_other(x))
+
+ .. note::
+
+ As of Python 3.7, the CPython dictionary implementation orders dictionaries based on
+ insertion order. However, for an API gaurantee of dictionary order, ``collections.OrderedDict``
+ may also be used.
+
+ Note that the pass pipeline order and options can be configured *globally* for a
+ qjit-compiled function, by using the ``circuit_transform_pipeline`` argument of
+ the :func:`~.qjit` decorator.
+
+ .. code-block:: python
+
+ my_pass_pipeline = {
+ "cancel_inverses": {},
+ "my_circuit_transformation_pass": {"my-option" : "my-option-value"},
+ }
+
+ @qjit(circuit_transform_pipeline=my_pass_pipeline)
+ def fn(x):
+ return jnp.sin(circuit(x ** 2))
+
+ Global and local (via ``@pipeline``) configurations can coexist, however local pass pipelines
+ will always take precedence over global pass pipelines.
+ """
+
+ kwargs = copy.copy(locals())
+ kwargs.pop("fn")
+
+ if fn is None:
+ return functools.partial(pipeline, **kwargs)
+
+ if not isinstance(fn, qml.QNode):
+ raise TypeError(f"A QNode is expected, got the classical function {fn}")
+
+ if pass_pipeline is None:
+ # TODO: design a default peephole pipeline
+ return fn
+
+ fn_original_name = fn.__name__
+ wrapped_qnode_function = fn.func
+ fn_clone = copy.copy(fn)
+ uniquer = str(_rename_to_unique())
+ fn_clone.__name__ = fn_original_name + "_transformed" + uniquer
+
+ pass_names = _API_name_to_pass_name()
+
+ def wrapper(*args, **kwrags):
+ # TODO: we should not match pass targets by function name.
+ # The quantum scope work will likely put each qnode into a module
+ # instead of a `func.func ... attributes {qnode}`.
+ # When that is in place, the qnode's module can have a proper attribute
+ # (as opposed to discardable) that records its transform schedule, i.e.
+ # module_with_transform @name_of_module {
+ # // transform schedule
+ # } {
+ # // contents of the module
+ # }
+ # This eliminates the need for matching target functions by name.
+
+ if EvaluationContext.is_tracing():
+ for API_name, pass_options in pass_pipeline.items():
+ opt = ""
+ for option, option_value in pass_options.items():
+ opt += " " + str(option) + "=" + str(option_value)
+ apply_registered_pass_p.bind(
+ pass_name=pass_names[API_name],
+ options=f"func-name={fn_original_name}" + "_transformed" + uniquer + opt,
+ )
+ return wrapped_qnode_function(*args, **kwrags)
+
+ fn_clone.func = wrapper
+ fn_clone._peephole_transformed = True # pylint: disable=protected-access
+
+ return fn_clone
+
+
def cancel_inverses(fn=None):
"""
Specify that the ``-removed-chained-self-inverse`` MLIR compiler pass
@@ -150,33 +291,50 @@ def circuit(x: float):
if not isinstance(fn, qml.QNode):
raise TypeError(f"A QNode is expected, got the classical function {fn}")
- wrapped_qnode_function = fn.func
funcname = fn.__name__
+ wrapped_qnode_function = fn.func
+ uniquer = str(_rename_to_unique())
def wrapper(*args, **kwrags):
- # TODO: hint the compiler which qnodes to run the pass on via an func attribute,
- # instead of the qnode name. That way the clone can have this attribute and
- # the original can just not have it.
- # We are not doing this right now and passing by name because this would
- # be a discardable attribute (i.e. a user/developer wouldn't know that this
- # attribute exists just by looking at qnode's documentation)
- # But when we add the full peephole pipeline in the future, the attribute
- # could get properly documented.
-
- apply_registered_pass_p.bind(
- pass_name="remove-chained-self-inverse",
- options=f"func-name={funcname}" + "_cancel_inverses",
- )
+ if EvaluationContext.is_tracing():
+ apply_registered_pass_p.bind(
+ pass_name="remove-chained-self-inverse",
+ options=f"func-name={funcname}" + "_cancel_inverses" + uniquer,
+ )
return wrapped_qnode_function(*args, **kwrags)
fn_clone = copy.copy(fn)
fn_clone.func = wrapper
- fn_clone.__name__ = funcname + "_cancel_inverses"
+ fn_clone.__name__ = funcname + "_cancel_inverses" + uniquer
return fn_clone
## IMPL and helpers ##
+# pylint: disable=missing-function-docstring
+class _PipelineNameUniquer:
+ def __init__(self, i):
+ self.i = i
+
+ def get(self):
+ self.i += 1
+ return self.i
+
+ def reset(self):
+ self.i = -1
+
+
+PipelineNameUniquer = _PipelineNameUniquer(-1)
+
+
+def _rename_to_unique():
+ return PipelineNameUniquer.get()
+
+
+def _API_name_to_pass_name():
+ return {"cancel_inverses": "remove-chained-self-inverse", "merge_rotations": "merge-rotation"}
+
+
def _inject_transform_named_sequence():
"""
Inject a transform_named_sequence jax primitive.
diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py
index d137106d74..bf33a745a9 100644
--- a/frontend/catalyst/qfunc.py
+++ b/frontend/catalyst/qfunc.py
@@ -49,6 +49,7 @@
from catalyst.jax_primitives import func_p
from catalyst.jax_tracer import trace_quantum_function
from catalyst.logging import debug_logger
+from catalyst.passes import pipeline
from catalyst.tracing.type_signatures import filter_static_args
logger = logging.getLogger(__name__)
@@ -92,10 +93,18 @@ def __new__(cls):
raise NotImplementedError() # pragma: no-cover
# pylint: disable=no-member
+ # pylint: disable=self-cls-assignment
@debug_logger
def __call__(self, *args, **kwargs):
assert isinstance(self, qml.QNode)
+ # Update the qnode with peephole pipeline
+ if "pass_pipeline" in kwargs.keys():
+ pass_pipeline = kwargs["pass_pipeline"]
+ if not hasattr(self, "_peephole_transformed"):
+ self = pipeline(pass_pipeline=pass_pipeline)(self)
+ kwargs.pop("pass_pipeline")
+
# Mid-circuit measurement configuration/execution
dynamic_one_shot_called = getattr(self, "_dynamic_one_shot_called", False)
if not dynamic_one_shot_called:
diff --git a/frontend/test/lit/test_peephole_optimizations.py b/frontend/test/lit/test_peephole_optimizations.py
index c9dab6de0a..dd5dfb833d 100644
--- a/frontend/test/lit/test_peephole_optimizations.py
+++ b/frontend/test/lit/test_peephole_optimizations.py
@@ -31,7 +31,7 @@
from catalyst import qjit
from catalyst.debug import get_compilation_stage
-from catalyst.passes import cancel_inverses
+from catalyst.passes import cancel_inverses, pipeline
def flush_peephole_opted_mlir_to_iostream(QJIT):
@@ -75,6 +75,267 @@ def func():
test_transform_named_sequence_injection()
+#
+# pipeline
+#
+
+
+def test_pipeline_lowering():
+ """
+ Basic pipeline lowering on one qnode.
+ """
+ my_pipeline = {
+ "cancel_inverses": {},
+ "merge_rotations": {"my-option": "aloha"},
+ }
+
+ @qjit(keep_intermediate=True)
+ @pipeline(pass_pipeline=my_pipeline)
+ @qml.qnode(qml.device("lightning.qubit", wires=2))
+ def test_pipeline_lowering_workflow(x):
+ qml.RX(x, wires=[0])
+ qml.Hadamard(wires=[1])
+ qml.Hadamard(wires=[1])
+ return qml.expval(qml.PauliY(wires=0))
+
+ # CHECK: transform_named_sequence
+ # CHECK: _:AbstractTransformMod() = apply_registered_pass[
+ # CHECK: options=func-name=test_pipeline_lowering_workflow_transformed0
+ # CHECK: pass_name=remove-chained-self-inverse
+ # CHECK: ]
+ # CHECK: _:AbstractTransformMod() = apply_registered_pass[
+ # CHECK: options=func-name=test_pipeline_lowering_workflow_transformed0 my-option=aloha
+ # CHECK: pass_name=merge-rotation
+ # CHECK: ]
+ print_jaxpr(test_pipeline_lowering_workflow, 1.2)
+
+ # CHECK: transform.named_sequence @__transform_main
+ # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=test_pipeline_lowering_workflow_transformed0"}
+ # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=test_pipeline_lowering_workflow_transformed0 my-option=aloha"}
+ # CHECK-NEXT: transform.yield
+ print_mlir(test_pipeline_lowering_workflow, 1.2)
+
+ # CHECK: {{%.+}} = call @test_pipeline_lowering_workflow_transformed0(
+ # CHECK: func.func private @test_pipeline_lowering_workflow_transformed0
+ # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
+ # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
+ # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
+ test_pipeline_lowering_workflow(42.42)
+ flush_peephole_opted_mlir_to_iostream(test_pipeline_lowering_workflow)
+
+
+test_pipeline_lowering()
+
+
+def test_pipeline_lowering_keep_original():
+ """
+ Test when the pipelined qnode and the original qnode are both used,
+ and the original is correctly kept and untransformed.
+ """
+ my_pipeline = {
+ "cancel_inverses": {},
+ "merge_rotations": {},
+ }
+
+ @qml.qnode(qml.device("lightning.qubit", wires=2))
+ def f(x):
+ qml.RX(x, wires=[0])
+ qml.Hadamard(wires=[1])
+ qml.Hadamard(wires=[1])
+ return qml.expval(qml.PauliY(wires=0))
+
+ f_pipeline = pipeline(pass_pipeline=my_pipeline)(f)
+
+ @qjit(keep_intermediate=True)
+ def test_pipeline_lowering_keep_original_workflow(x):
+ return f(x), f_pipeline(x)
+
+ # CHECK: transform_named_sequence
+ # CHECK: call_jaxpr=
+ # CHECK-NOT: _:AbstractTransformMod() = apply_registered_pass[
+ # CHECK: call_jaxpr=
+ # CHECK: _:AbstractTransformMod() = apply_registered_pass[
+ # CHECK: options=func-name=f_transformed0
+ # CHECK: pass_name=remove-chained-self-inverse
+ # CHECK: ]
+ # CHECK: _:AbstractTransformMod() = apply_registered_pass[
+ # CHECK: options=func-name=f_transformed0
+ # CHECK: pass_name=merge-rotation
+ # CHECK: ]
+ print_jaxpr(test_pipeline_lowering_keep_original_workflow, 1.2)
+
+ # CHECK: transform.named_sequence @__transform_main
+ # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_transformed0"}
+ # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=f_transformed0"}
+ # CHECK-NEXT: transform.yield
+ print_mlir(test_pipeline_lowering_keep_original_workflow, 1.2)
+
+ # CHECK: func.func public @jit_test_pipeline_lowering_keep_original_workflow
+ # CHECK: {{%.+}} = call @f(
+ # CHECK: {{%.+}} = call @f_transformed0(
+ # CHECK: func.func private @f(
+ # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
+ # CHECK: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
+ # CHECK: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
+ # CHECK: func.func private @f_transformed0
+ # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
+ # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
+ # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
+ test_pipeline_lowering_keep_original_workflow(42.42)
+ flush_peephole_opted_mlir_to_iostream(test_pipeline_lowering_keep_original_workflow)
+
+
+test_pipeline_lowering_keep_original()
+
+
+def test_pipeline_lowering_global():
+ """
+ Test that the global qjit circuit_transform_pipeline option
+ transforms all qnodes in the qjit.
+ """
+ my_pipeline = {
+ "cancel_inverses": {},
+ "merge_rotations": {},
+ }
+
+ @qjit(keep_intermediate=True, circuit_transform_pipeline=my_pipeline)
+ def global_wf():
+ @qml.qnode(qml.device("lightning.qubit", wires=2))
+ def g(x):
+ qml.RX(x, wires=[0])
+ qml.Hadamard(wires=[1])
+ qml.Hadamard(wires=[1])
+ return qml.expval(qml.PauliY(wires=0))
+
+ @qml.qnode(qml.device("lightning.qubit", wires=2))
+ def h(x):
+ qml.RX(x, wires=[0])
+ qml.Hadamard(wires=[1])
+ qml.Hadamard(wires=[1])
+ return qml.expval(qml.PauliY(wires=0))
+
+ return g(1.2), h(1.2)
+
+ # CHECK: transform_named_sequence
+ # CHECK: _:AbstractTransformMod() = apply_registered_pass[
+ # CHECK: options=func-name=g_transformed0
+ # CHECK: pass_name=remove-chained-self-inverse
+ # CHECK: ]
+ # CHECK: _:AbstractTransformMod() = apply_registered_pass[
+ # CHECK: options=func-name=g_transformed0
+ # CHECK: pass_name=merge-rotation
+ # CHECK: ]
+ # CHECK: _:AbstractTransformMod() = apply_registered_pass[
+ # CHECK: options=func-name=h_transformed1
+ # CHECK: pass_name=remove-chained-self-inverse
+ # CHECK: ]
+ # CHECK: _:AbstractTransformMod() = apply_registered_pass[
+ # CHECK: options=func-name=h_transformed1
+ # CHECK: pass_name=merge-rotation
+ # CHECK: ]
+ print_jaxpr(global_wf)
+
+ # CHECK: transform.named_sequence @__transform_main
+ # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=g_transformed0"}
+ # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=g_transformed0"}
+ # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=h_transformed1"}
+ # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=h_transformed1"}
+ # CHECK-NEXT: transform.yield
+ print_mlir(global_wf)
+
+ # CHECK: func.func public @jit_global_wf()
+ # CHECK {{%.+}} = call @g_transformed0(
+ # CHECK {{%.+}} = call @h_transformed1(
+ # CHECK: func.func private @g_transformed0
+ # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
+ # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
+ # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
+ # CHECK: func.func private @h_transformed1
+ # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
+ # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
+ # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
+ global_wf()
+ flush_peephole_opted_mlir_to_iostream(global_wf)
+
+
+test_pipeline_lowering_global()
+
+
+def test_pipeline_lowering_globloc_override():
+ """
+ Test that local qnode pipelines correctly overrides the global
+ pipeline specified by the qjit's option.
+ """
+ global_pipeline = {
+ "cancel_inverses": {},
+ "merge_rotations": {},
+ }
+
+ local_pipeline = {
+ "merge_rotations": {},
+ }
+
+ @qjit(keep_intermediate=True, circuit_transform_pipeline=global_pipeline)
+ def global_wf():
+ @qml.qnode(qml.device("lightning.qubit", wires=2))
+ def g(x):
+ qml.RX(x, wires=[0])
+ qml.Hadamard(wires=[1])
+ qml.Hadamard(wires=[1])
+ return qml.expval(qml.PauliY(wires=0))
+
+ @pipeline(pass_pipeline=local_pipeline)
+ @qml.qnode(qml.device("lightning.qubit", wires=2))
+ def h(x):
+ qml.RX(x, wires=[0])
+ qml.Hadamard(wires=[1])
+ qml.Hadamard(wires=[1])
+ return qml.expval(qml.PauliY(wires=0))
+
+ return g(1.2), h(1.2)
+
+ # CHECK: transform_named_sequence
+ # CHECK: _:AbstractTransformMod() = apply_registered_pass[
+ # CHECK: options=func-name=g_transformed1
+ # CHECK: pass_name=remove-chained-self-inverse
+ # CHECK: ]
+ # CHECK: _:AbstractTransformMod() = apply_registered_pass[
+ # CHECK: options=func-name=g_transformed1
+ # CHECK: pass_name=merge-rotation
+ # CHECK: ]
+ # CHECK: _:AbstractTransformMod() = apply_registered_pass[
+ # CHECK: options=func-name=h_transformed0
+ # CHECK-NOT: pass_name=remove-chained-self-inverse
+ # CHECK: pass_name=merge-rotation
+ # CHECK: ]
+ print_jaxpr(global_wf)
+
+ # CHECK: transform.named_sequence @__transform_main
+ # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=g_transformed1"}
+ # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=g_transformed1"}
+ # CHECK-NOT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=h_transformed0"}
+ # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=h_transformed0"}
+ # CHECK-NEXT: transform.yield
+ print_mlir(global_wf)
+
+ # CHECK: func.func public @jit_global_wf()
+ # CHECK {{%.+}} = call @g_transformed1(
+ # CHECK {{%.+}} = call @h_transformed0(
+ # CHECK: func.func private @g_transformed1
+ # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
+ # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
+ # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
+ # CHECK: func.func private @h_transformed0
+ # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
+ # CHECK: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
+ # CHECK: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
+ global_wf()
+ flush_peephole_opted_mlir_to_iostream(global_wf)
+
+
+test_pipeline_lowering_globloc_override()
+
+
#
# cancel_inverses
#
@@ -118,11 +379,11 @@ def h(x: float):
# CHECK: transform_named_sequence
# CHECK: _:AbstractTransformMod() = apply_registered_pass[
- # CHECK: options=func-name=f_cancel_inverses
+ # CHECK: options=func-name=f_cancel_inverses0
# CHECK: pass_name=remove-chained-self-inverse
# CHECK: ]
# CHECK: _:AbstractTransformMod() = apply_registered_pass[
- # CHECK: options=func-name=g_cancel_inverses
+ # CHECK: options=func-name=g_cancel_inverses1
# CHECK: pass_name=remove-chained-self-inverse
# CHECK: ]
# CHECK-NOT: _:AbstractTransformMod() = apply_registered_pass[
@@ -132,8 +393,8 @@ def h(x: float):
# CHECK: module @test_cancel_inverses_tracing_and_lowering_workflow
# CHECK: transform.named_sequence @__transform_main
- # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_cancel_inverses"}
- # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=g_cancel_inverses"}
+ # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_cancel_inverses0"}
+ # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=g_cancel_inverses1"}
# CHECK-NOT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=h_cancel_inverses"}
# CHECK-NEXT: transform.yield
print_mlir(test_cancel_inverses_tracing_and_lowering_workflow, 1.1)
@@ -162,14 +423,14 @@ def test_cancel_inverses_tracing_and_lowering_outside_qjit_workflow(xx: float):
# CHECK: transform_named_sequence
# CHECK: _:AbstractTransformMod() = apply_registered_pass[
- # CHECK: options=func-name=f_cancel_inverses
+ # CHECK: options=func-name=f_cancel_inverses0
# CHECK: pass_name=remove-chained-self-inverse
# CHECK: ]
print_jaxpr(test_cancel_inverses_tracing_and_lowering_outside_qjit_workflow, 1.1)
# CHECK: module @test_cancel_inverses_tracing_and_lowering_outside_qjit_workflow
# CHECK: transform.named_sequence @__transform_main
- # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_cancel_inverses"}
+ # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_cancel_inverses0"}
# CHECK-NEXT: transform.yield
print_mlir(test_cancel_inverses_tracing_and_lowering_outside_qjit_workflow, 1.1)
@@ -254,7 +515,7 @@ def f(x: float):
# CHECK-LABEL: public @jit_test_cancel_inverses_keep_original_workflow0
# CHECK: {{%.+}} = call @f({{%.+}})
- # CHECK-NOT: {{%.+}} = call @f_cancel_inverses({{%.+}})
+ # CHECK-NOT: {{%.+}} = call @f_cancel_inverses0({{%.+}})
# CHECK-LABEL: private @f({{%.+}})
# CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
# CHECK-NEXT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
@@ -268,9 +529,9 @@ def test_cancel_inverses_keep_original_workflow0():
flush_peephole_opted_mlir_to_iostream(test_cancel_inverses_keep_original_workflow0)
# CHECK-LABEL: public @jit_test_cancel_inverses_keep_original_workflow1
- # CHECK: {{%.+}} = call @f_cancel_inverses({{%.+}})
+ # CHECK: {{%.+}} = call @f_cancel_inverses0({{%.+}})
# CHECK-NOT: {{%.+}} = call @f({{%.+}})
- # CHECK-LABEL: private @f_cancel_inverses({{%.+}})
+ # CHECK-LABEL: private @f_cancel_inverses0({{%.+}})
# CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
# CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
# CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
@@ -284,12 +545,12 @@ def test_cancel_inverses_keep_original_workflow1():
# CHECK-LABEL: public @jit_test_cancel_inverses_keep_original_workflow2
# CHECK: {{%.+}} = call @f({{%.+}})
- # CHECK: {{%.+}} = call @f_cancel_inverses({{%.+}})
+ # CHECK: {{%.+}} = call @f_cancel_inverses0({{%.+}})
# CHECK-LABEL: private @f({{%.+}})
# CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
# CHECK-NEXT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
# CHECK-NEXT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
- # CHECK-LABEL: private @f_cancel_inverses({{%.+}})
+ # CHECK-LABEL: private @f_cancel_inverses0({{%.+}})
# CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
# CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
# CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
diff --git a/frontend/test/pytest/test_peephole_optimizations.py b/frontend/test/pytest/test_peephole_optimizations.py
index 810f254c11..40bbf0fb28 100644
--- a/frontend/test/pytest/test_peephole_optimizations.py
+++ b/frontend/test/pytest/test_peephole_optimizations.py
@@ -19,7 +19,7 @@
import pytest
from catalyst import qjit
-from catalyst.passes import cancel_inverses
+from catalyst.passes import cancel_inverses, pipeline
# pylint: disable=missing-function-docstring
@@ -91,6 +91,42 @@ def g(x):
assert np.allclose(workflow()[0], workflow()[1])
+@pytest.mark.parametrize("theta", [42.42])
+def test_pipeline_functionality(capfd, theta, backend):
+ """
+ Test that the @pipeline decorator does not change functionality
+ when all the passes in the pipeline does not change functionality.
+ """
+ my_pipeline = {
+ "cancel_inverses": {},
+ "merge_rotations": {"my-option": "aloha"},
+ }
+
+ @qjit
+ def workflow():
+ @qml.qnode(qml.device(backend, wires=2))
+ def f(x):
+ qml.RX(x, wires=[0])
+ qml.Hadamard(wires=[1])
+ qml.Hadamard(wires=[1])
+ return qml.expval(qml.PauliY(wires=0))
+
+ no_pipeline_result = f(theta)
+ pipeline_result = pipeline(pass_pipeline=my_pipeline)(f)(theta)
+
+ return no_pipeline_result, pipeline_result
+
+ res = workflow()
+ assert np.allclose(res[0], res[1])
+
+ # TODO: the boilerplate merge rotation pass prints out different messages based on
+ # the pass option.
+ # The purpose is to test the integration of pass options with pipeline decorator.
+ # Remove the string check when merge rotation becomes the actual merge rotation pass.
+ output_message = capfd.readouterr().err
+ assert output_message == "merge rotation pass, aloha!\n"
+
+
### Test bad usages of pass decorators ###
def test_cancel_inverses_bad_usages():
"""
@@ -101,6 +137,12 @@ def test_cancel_inverses_not_on_qnode():
def classical_func():
return 42.42
+ with pytest.raises(
+ TypeError,
+ match="A QNode is expected, got the classical function",
+ ):
+ pipeline(classical_func)
+
with pytest.raises(
TypeError,
match="A QNode is expected, got the classical function",
diff --git a/mlir/include/Quantum/Transforms/Passes.h b/mlir/include/Quantum/Transforms/Passes.h
index 96c748bd49..2e241bcd7d 100644
--- a/mlir/include/Quantum/Transforms/Passes.h
+++ b/mlir/include/Quantum/Transforms/Passes.h
@@ -28,5 +28,6 @@ std::unique_ptr createAdjointLoweringPass();
std::unique_ptr createRemoveChainedSelfInversePass();
std::unique_ptr createAnnotateFunctionPass();
std::unique_ptr createSplitMultipleTapesPass();
+std::unique_ptr createMergeRotationPass();
} // namespace catalyst
diff --git a/mlir/include/Quantum/Transforms/Passes.td b/mlir/include/Quantum/Transforms/Passes.td
index 0e720a947e..202347c150 100644
--- a/mlir/include/Quantum/Transforms/Passes.td
+++ b/mlir/include/Quantum/Transforms/Passes.td
@@ -91,7 +91,11 @@ def SplitMultipleTapesPass : Pass<"split-multiple-tapes"> {
// ----- Quantum circuit transformation passes begin ----- //
// For example, automatic compiler peephole opts, etc.
-class QuantumCircuitTransformationPass : Pass {
+class QuantumCircuitTransformationPassBase {
+ list