From 5cf6bd46422653fe1dde798926bf556532516d2b Mon Sep 17 00:00:00 2001 From: paul0403 <79805239+paul0403@users.noreply.github.com> Date: Thu, 19 Sep 2024 13:48:25 -0400 Subject: [PATCH] Create a frontend UI for users to specify quantum compilation pipelines (#1131) **Context:** We wish to create a frontend `pipeline` function for users to specify what circuit transformation passes to run in the dictionary form: https://app.shortcut.com/xanaduai/epic/67492/p1-a-ui-is-available-for-users-to-specify-quantum-compilation-pipelines-from-python **Description of the Change:** A new API decorator function `catalyst.passes.pipeline` can be applied to qnodes. The decorator takes in a dictionary specifying what circuit transformation passes to run on the qnode with what options, and adds the passes to the `transform_named_sequence`. In addition, `qjit` now takes in a kwarg `circuit_transform_pipeline` which expects the same pipeline dictionary. This will apply the pipeline to all qnodes in the qjit. To test the pipeline, we added the boilerplate for a merge_rotation pass with a pass option. The pass is currently empty and does nothing. The merge_rotation boilerplate pass is not exposed as a user-facing API, and only exists for the purpose of testing the pipeline. **Benefits:** User can specify pass pipelines. **Possible Drawbacks:** There are two items of improvements possible: 1. The target qnode of the pass is recorded by name. This is not optimal. The quantum scope work will likely put each qnode into a module instead of a `func.func ... attributes {qnode}` in mlir. When that is in place, the qnode's module can have a proper attribute (as opposed to discardable) that records its transform schedule, i.e. ``` module_with_transform @name_of_module { // transform schedule } { // contents of the module } ``` This eliminates the need for matching target functions by name. 2. The number of `qjit` kwargs is maybe too many. As of now we implement the syntax specified by the epic, but there could be an alternate design. [sc-67520] --- doc/releases/changelog-dev.md | 61 ++++ frontend/catalyst/compiler.py | 5 + frontend/catalyst/jit.py | 20 +- frontend/catalyst/passes.py | 188 +++++++++++- frontend/catalyst/qfunc.py | 9 + .../test/lit/test_peephole_optimizations.py | 285 +++++++++++++++++- .../pytest/test_peephole_optimizations.py | 44 ++- mlir/include/Quantum/Transforms/Passes.h | 1 + mlir/include/Quantum/Transforms/Passes.td | 23 +- .../Catalyst/Transforms/RegisterAllPasses.cpp | 1 + mlir/lib/Quantum/Transforms/CMakeLists.txt | 1 + .../lib/Quantum/Transforms/merge_rotation.cpp | 58 ++++ 12 files changed, 664 insertions(+), 32 deletions(-) create mode 100644 mlir/lib/Quantum/Transforms/merge_rotation.cpp diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 2d4a8c329e..b5b904d0bb 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -53,6 +53,67 @@ Array([[1], [0], [1], [1], [0], [1],[0]], dtype=int64)) ``` +* A new function `catalyst.passes.pipeline` allows the quantum circuit transformation pass pipeline for QNodes within a qjit-compiled workflow to be configured. + [(#1131)](https://github.com/PennyLaneAI/catalyst/pull/1131) + + ```python + my_passes = { + "cancel_inverses": {}, + "my_circuit_transformation_pass": {"my-option" : "my-option-value"}, + } + dev = qml.device("lightning.qubit", wires=2) + + @pipeline(my_passes) + @qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return qml.expval(qml.PauliZ(0)) + + @qjit + def fn(x): + return jnp.sin(circuit(x ** 2)) + ``` + + `pipeline` can also be used to specify different pass pipelines for different parts of the + same qjit-compiled workflow: + + ```python + my_pipeline = { + "cancel_inverses": {}, + "my_circuit_transformation_pass": {"my-option" : "my-option-value"}, + } + + my_other_pipeline = {"cancel_inverses": {}} + + @qjit + def fn(x): + circuit_pipeline = pipeline(my_pipeline)(circuit) + circuit_other = pipeline(my_other_pipeline)(circuit) + return jnp.abs(circuit_pipeline(x) - circuit_other(x)) + ``` + + For a list of available passes, please see the [catalyst.passes module documentation](https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes). + + The pass pipeline order and options can be configured *globally* for a + qjit-compiled function, by using the `circuit_transform_pipeline` argument of the :func:`~.qjit` decorator. + + ```python + my_passes = { + "cancel_inverses": {}, + "my_circuit_transformation_pass": {"my-option" : "my-option-value"}, + } + + @qjit(circuit_transform_pipeline=my_passes) + def fn(x): + return jnp.sin(circuit(x ** 2)) + ``` + + Global and local (via `@pipeline`) configurations can coexist, however local pass pipelines + will always take precedence over global pass pipelines. + + Available MLIR passes are now documented and available within the + [catalyst.passes module documentation](https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes). +

Improvements

* Bufferization of `gradient.ForwardOp` and `gradient.ReverseOp` now requires 3 steps: `gradient-preprocessing`, diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index 31f68cfc66..67278f7c3c 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -77,6 +77,10 @@ class CompileOptions: experimental_capture (bool): If set to ``True``, use PennyLane's experimental program capture capabilities to capture the function for compilation. + circuit_transform_pipeline (Optional[dict[str, dict[str, str]]]): + A dictionary that specifies the quantum circuit transformation pass pipeline order, + and optionally arguments for each pass in the pipeline. + Default is None. """ verbose: Optional[bool] = False @@ -94,6 +98,7 @@ class CompileOptions: disable_assertions: Optional[bool] = False seed: Optional[int] = None experimental_capture: Optional[bool] = False + circuit_transform_pipeline: Optional[dict[str, dict[str, str]]] = None def __post_init__(self): # Check that async runs must not be seeded diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 70225a971b..df208228c6 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -38,7 +38,7 @@ from catalyst.from_plxpr import trace_from_pennylane from catalyst.jax_tracer import lower_jaxpr_to_mlir, trace_to_jaxpr from catalyst.logging import debug_logger, debug_logger_init -from catalyst.passes import _inject_transform_named_sequence +from catalyst.passes import PipelineNameUniquer, _inject_transform_named_sequence from catalyst.qfunc import QFunc from catalyst.tracing.contexts import EvaluationContext from catalyst.tracing.type_signatures import ( @@ -85,6 +85,7 @@ def qjit( disable_assertions=False, seed=None, experimental_capture=False, + circuit_transform_pipeline=None, ): # pylint: disable=too-many-arguments,unused-argument """A just-in-time decorator for PennyLane and JAX programs using Catalyst. @@ -141,6 +142,15 @@ def qjit( experimental_capture (bool): If set to ``True``, the qjit decorator will use PennyLane's experimental program capture capabilities to capture the decorated function for compilation. + circuit_transform_pipeline (Optional[dict[str, dict[str, str]]]): + A dictionary that specifies the quantum circuit transformation pass pipeline order, + and optionally arguments for each pass in the pipeline. Keys of this dictionary + should correspond to names of passes found in the `catalyst.passes `_ + module, values should either be empty dictionaries (for default pass options) or + dictionaries of valid keyword arguments and values for the specific pass. + The order of keys in this dictionary will determine the pass pipeline. + If not specified, the default pass pipeline will be applied. Returns: QJIT object. @@ -609,7 +619,12 @@ def closure(qnode, *args, **kwargs): params = {} params["static_argnums"] = kwargs.pop("static_argnums", static_argnums) params["_out_tree_expected"] = [] - return QFunc.__call__(qnode, *args, **dict(params, **kwargs)) + return QFunc.__call__( + qnode, + pass_pipeline=self.compile_options.circuit_transform_pipeline, + *args, + **dict(params, **kwargs), + ) with Patcher( (qml.QNode, "__call__", closure), @@ -623,6 +638,7 @@ def closure(qnode, *args, **kwargs): kwargs, ) + PipelineNameUniquer.reset() return jaxpr, out_type, treedef, dynamic_sig @instrument(size_from=0, has_finegrained=True) diff --git a/frontend/catalyst/passes.py b/frontend/catalyst/passes.py index 6d3b2cd5a0..86dc5e7caf 100644 --- a/frontend/catalyst/passes.py +++ b/frontend/catalyst/passes.py @@ -33,14 +33,155 @@ """ import copy +import functools +from typing import Optional import pennylane as qml from catalyst.jax_primitives import apply_registered_pass_p, transform_named_sequence_p +from catalyst.tracing.contexts import EvaluationContext ## API ## # pylint: disable=line-too-long +def pipeline(fn=None, *, pass_pipeline: Optional[dict[str, dict[str, str]]] = None): + """Configures the Catalyst MLIR pass pipeline for quantum circuit transformations for a QNode within a qjit-compiled program. + + Args: + fn (QNode): The QNode to run the pass pipeline on. + pass_pipeline (dict[str, dict[str, str]]): A dictionary that specifies the pass pipeline order, and optionally + arguments for each pass in the pipeline. Keys of this dictionary should correspond to names of passes + found in the `catalyst.passes `_ module, values should either be empty dictionaries + (for default pass options) or dictionaries of valid keyword arguments and values for the specific pass. + The order of keys in this dictionary will determine the pass pipeline. + If not specified, the default pass pipeline will be applied. + + Returns: + ~.QNode: + + For a list of available passes, please see the :doc:`catalyst.passes module `. + + The default pass pipeline when used with Catalyst is currently empty. + + **Example** + + ``pipeline`` can be used to configure the pass pipeline order and options + of a QNode within a qjit-compiled function. + + Configuration options are passed to specific passes via dictionaries: + + .. code-block:: python + + my_pass_pipeline = { + "cancel_inverses": {}, + "my_circuit_transformation_pass": {"my-option" : "my-option-value"}, + } + + @pipeline(my_pass_pipeline) + @qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return qml.expval(qml.PauliZ(0)) + + @qjit + def fn(x): + return jnp.sin(circuit(x ** 2)) + + ``pipeline`` can also be used to specify different pass pipelines for different parts of the + same qjit-compiled workflow: + + .. code-block:: python + + my_pipeline = { + "cancel_inverses": {}, + "my_circuit_transformation_pass": {"my-option" : "my-option-value"}, + } + + my_other_pipeline = {"cancel_inverses": {}} + + @qjit + def fn(x): + circuit_pipeline = pipeline(my_pipeline)(circuit) + circuit_other = pipeline(my_other_pipeline)(circuit) + return jnp.abs(circuit_pipeline(x) - circuit_other(x)) + + .. note:: + + As of Python 3.7, the CPython dictionary implementation orders dictionaries based on + insertion order. However, for an API gaurantee of dictionary order, ``collections.OrderedDict`` + may also be used. + + Note that the pass pipeline order and options can be configured *globally* for a + qjit-compiled function, by using the ``circuit_transform_pipeline`` argument of + the :func:`~.qjit` decorator. + + .. code-block:: python + + my_pass_pipeline = { + "cancel_inverses": {}, + "my_circuit_transformation_pass": {"my-option" : "my-option-value"}, + } + + @qjit(circuit_transform_pipeline=my_pass_pipeline) + def fn(x): + return jnp.sin(circuit(x ** 2)) + + Global and local (via ``@pipeline``) configurations can coexist, however local pass pipelines + will always take precedence over global pass pipelines. + """ + + kwargs = copy.copy(locals()) + kwargs.pop("fn") + + if fn is None: + return functools.partial(pipeline, **kwargs) + + if not isinstance(fn, qml.QNode): + raise TypeError(f"A QNode is expected, got the classical function {fn}") + + if pass_pipeline is None: + # TODO: design a default peephole pipeline + return fn + + fn_original_name = fn.__name__ + wrapped_qnode_function = fn.func + fn_clone = copy.copy(fn) + uniquer = str(_rename_to_unique()) + fn_clone.__name__ = fn_original_name + "_transformed" + uniquer + + pass_names = _API_name_to_pass_name() + + def wrapper(*args, **kwrags): + # TODO: we should not match pass targets by function name. + # The quantum scope work will likely put each qnode into a module + # instead of a `func.func ... attributes {qnode}`. + # When that is in place, the qnode's module can have a proper attribute + # (as opposed to discardable) that records its transform schedule, i.e. + # module_with_transform @name_of_module { + # // transform schedule + # } { + # // contents of the module + # } + # This eliminates the need for matching target functions by name. + + if EvaluationContext.is_tracing(): + for API_name, pass_options in pass_pipeline.items(): + opt = "" + for option, option_value in pass_options.items(): + opt += " " + str(option) + "=" + str(option_value) + apply_registered_pass_p.bind( + pass_name=pass_names[API_name], + options=f"func-name={fn_original_name}" + "_transformed" + uniquer + opt, + ) + return wrapped_qnode_function(*args, **kwrags) + + fn_clone.func = wrapper + fn_clone._peephole_transformed = True # pylint: disable=protected-access + + return fn_clone + + def cancel_inverses(fn=None): """ Specify that the ``-removed-chained-self-inverse`` MLIR compiler pass @@ -150,33 +291,50 @@ def circuit(x: float): if not isinstance(fn, qml.QNode): raise TypeError(f"A QNode is expected, got the classical function {fn}") - wrapped_qnode_function = fn.func funcname = fn.__name__ + wrapped_qnode_function = fn.func + uniquer = str(_rename_to_unique()) def wrapper(*args, **kwrags): - # TODO: hint the compiler which qnodes to run the pass on via an func attribute, - # instead of the qnode name. That way the clone can have this attribute and - # the original can just not have it. - # We are not doing this right now and passing by name because this would - # be a discardable attribute (i.e. a user/developer wouldn't know that this - # attribute exists just by looking at qnode's documentation) - # But when we add the full peephole pipeline in the future, the attribute - # could get properly documented. - - apply_registered_pass_p.bind( - pass_name="remove-chained-self-inverse", - options=f"func-name={funcname}" + "_cancel_inverses", - ) + if EvaluationContext.is_tracing(): + apply_registered_pass_p.bind( + pass_name="remove-chained-self-inverse", + options=f"func-name={funcname}" + "_cancel_inverses" + uniquer, + ) return wrapped_qnode_function(*args, **kwrags) fn_clone = copy.copy(fn) fn_clone.func = wrapper - fn_clone.__name__ = funcname + "_cancel_inverses" + fn_clone.__name__ = funcname + "_cancel_inverses" + uniquer return fn_clone ## IMPL and helpers ## +# pylint: disable=missing-function-docstring +class _PipelineNameUniquer: + def __init__(self, i): + self.i = i + + def get(self): + self.i += 1 + return self.i + + def reset(self): + self.i = -1 + + +PipelineNameUniquer = _PipelineNameUniquer(-1) + + +def _rename_to_unique(): + return PipelineNameUniquer.get() + + +def _API_name_to_pass_name(): + return {"cancel_inverses": "remove-chained-self-inverse", "merge_rotations": "merge-rotation"} + + def _inject_transform_named_sequence(): """ Inject a transform_named_sequence jax primitive. diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index d137106d74..bf33a745a9 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -49,6 +49,7 @@ from catalyst.jax_primitives import func_p from catalyst.jax_tracer import trace_quantum_function from catalyst.logging import debug_logger +from catalyst.passes import pipeline from catalyst.tracing.type_signatures import filter_static_args logger = logging.getLogger(__name__) @@ -92,10 +93,18 @@ def __new__(cls): raise NotImplementedError() # pragma: no-cover # pylint: disable=no-member + # pylint: disable=self-cls-assignment @debug_logger def __call__(self, *args, **kwargs): assert isinstance(self, qml.QNode) + # Update the qnode with peephole pipeline + if "pass_pipeline" in kwargs.keys(): + pass_pipeline = kwargs["pass_pipeline"] + if not hasattr(self, "_peephole_transformed"): + self = pipeline(pass_pipeline=pass_pipeline)(self) + kwargs.pop("pass_pipeline") + # Mid-circuit measurement configuration/execution dynamic_one_shot_called = getattr(self, "_dynamic_one_shot_called", False) if not dynamic_one_shot_called: diff --git a/frontend/test/lit/test_peephole_optimizations.py b/frontend/test/lit/test_peephole_optimizations.py index c9dab6de0a..dd5dfb833d 100644 --- a/frontend/test/lit/test_peephole_optimizations.py +++ b/frontend/test/lit/test_peephole_optimizations.py @@ -31,7 +31,7 @@ from catalyst import qjit from catalyst.debug import get_compilation_stage -from catalyst.passes import cancel_inverses +from catalyst.passes import cancel_inverses, pipeline def flush_peephole_opted_mlir_to_iostream(QJIT): @@ -75,6 +75,267 @@ def func(): test_transform_named_sequence_injection() +# +# pipeline +# + + +def test_pipeline_lowering(): + """ + Basic pipeline lowering on one qnode. + """ + my_pipeline = { + "cancel_inverses": {}, + "merge_rotations": {"my-option": "aloha"}, + } + + @qjit(keep_intermediate=True) + @pipeline(pass_pipeline=my_pipeline) + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def test_pipeline_lowering_workflow(x): + qml.RX(x, wires=[0]) + qml.Hadamard(wires=[1]) + qml.Hadamard(wires=[1]) + return qml.expval(qml.PauliY(wires=0)) + + # CHECK: transform_named_sequence + # CHECK: _:AbstractTransformMod() = apply_registered_pass[ + # CHECK: options=func-name=test_pipeline_lowering_workflow_transformed0 + # CHECK: pass_name=remove-chained-self-inverse + # CHECK: ] + # CHECK: _:AbstractTransformMod() = apply_registered_pass[ + # CHECK: options=func-name=test_pipeline_lowering_workflow_transformed0 my-option=aloha + # CHECK: pass_name=merge-rotation + # CHECK: ] + print_jaxpr(test_pipeline_lowering_workflow, 1.2) + + # CHECK: transform.named_sequence @__transform_main + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=test_pipeline_lowering_workflow_transformed0"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=test_pipeline_lowering_workflow_transformed0 my-option=aloha"} + # CHECK-NEXT: transform.yield + print_mlir(test_pipeline_lowering_workflow, 1.2) + + # CHECK: {{%.+}} = call @test_pipeline_lowering_workflow_transformed0( + # CHECK: func.func private @test_pipeline_lowering_workflow_transformed0 + # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit + # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + test_pipeline_lowering_workflow(42.42) + flush_peephole_opted_mlir_to_iostream(test_pipeline_lowering_workflow) + + +test_pipeline_lowering() + + +def test_pipeline_lowering_keep_original(): + """ + Test when the pipelined qnode and the original qnode are both used, + and the original is correctly kept and untransformed. + """ + my_pipeline = { + "cancel_inverses": {}, + "merge_rotations": {}, + } + + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def f(x): + qml.RX(x, wires=[0]) + qml.Hadamard(wires=[1]) + qml.Hadamard(wires=[1]) + return qml.expval(qml.PauliY(wires=0)) + + f_pipeline = pipeline(pass_pipeline=my_pipeline)(f) + + @qjit(keep_intermediate=True) + def test_pipeline_lowering_keep_original_workflow(x): + return f(x), f_pipeline(x) + + # CHECK: transform_named_sequence + # CHECK: call_jaxpr= + # CHECK-NOT: _:AbstractTransformMod() = apply_registered_pass[ + # CHECK: call_jaxpr= + # CHECK: _:AbstractTransformMod() = apply_registered_pass[ + # CHECK: options=func-name=f_transformed0 + # CHECK: pass_name=remove-chained-self-inverse + # CHECK: ] + # CHECK: _:AbstractTransformMod() = apply_registered_pass[ + # CHECK: options=func-name=f_transformed0 + # CHECK: pass_name=merge-rotation + # CHECK: ] + print_jaxpr(test_pipeline_lowering_keep_original_workflow, 1.2) + + # CHECK: transform.named_sequence @__transform_main + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_transformed0"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=f_transformed0"} + # CHECK-NEXT: transform.yield + print_mlir(test_pipeline_lowering_keep_original_workflow, 1.2) + + # CHECK: func.func public @jit_test_pipeline_lowering_keep_original_workflow + # CHECK: {{%.+}} = call @f( + # CHECK: {{%.+}} = call @f_transformed0( + # CHECK: func.func private @f( + # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit + # CHECK: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + # CHECK: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + # CHECK: func.func private @f_transformed0 + # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit + # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + test_pipeline_lowering_keep_original_workflow(42.42) + flush_peephole_opted_mlir_to_iostream(test_pipeline_lowering_keep_original_workflow) + + +test_pipeline_lowering_keep_original() + + +def test_pipeline_lowering_global(): + """ + Test that the global qjit circuit_transform_pipeline option + transforms all qnodes in the qjit. + """ + my_pipeline = { + "cancel_inverses": {}, + "merge_rotations": {}, + } + + @qjit(keep_intermediate=True, circuit_transform_pipeline=my_pipeline) + def global_wf(): + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def g(x): + qml.RX(x, wires=[0]) + qml.Hadamard(wires=[1]) + qml.Hadamard(wires=[1]) + return qml.expval(qml.PauliY(wires=0)) + + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def h(x): + qml.RX(x, wires=[0]) + qml.Hadamard(wires=[1]) + qml.Hadamard(wires=[1]) + return qml.expval(qml.PauliY(wires=0)) + + return g(1.2), h(1.2) + + # CHECK: transform_named_sequence + # CHECK: _:AbstractTransformMod() = apply_registered_pass[ + # CHECK: options=func-name=g_transformed0 + # CHECK: pass_name=remove-chained-self-inverse + # CHECK: ] + # CHECK: _:AbstractTransformMod() = apply_registered_pass[ + # CHECK: options=func-name=g_transformed0 + # CHECK: pass_name=merge-rotation + # CHECK: ] + # CHECK: _:AbstractTransformMod() = apply_registered_pass[ + # CHECK: options=func-name=h_transformed1 + # CHECK: pass_name=remove-chained-self-inverse + # CHECK: ] + # CHECK: _:AbstractTransformMod() = apply_registered_pass[ + # CHECK: options=func-name=h_transformed1 + # CHECK: pass_name=merge-rotation + # CHECK: ] + print_jaxpr(global_wf) + + # CHECK: transform.named_sequence @__transform_main + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=g_transformed0"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=g_transformed0"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=h_transformed1"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=h_transformed1"} + # CHECK-NEXT: transform.yield + print_mlir(global_wf) + + # CHECK: func.func public @jit_global_wf() + # CHECK {{%.+}} = call @g_transformed0( + # CHECK {{%.+}} = call @h_transformed1( + # CHECK: func.func private @g_transformed0 + # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit + # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + # CHECK: func.func private @h_transformed1 + # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit + # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + global_wf() + flush_peephole_opted_mlir_to_iostream(global_wf) + + +test_pipeline_lowering_global() + + +def test_pipeline_lowering_globloc_override(): + """ + Test that local qnode pipelines correctly overrides the global + pipeline specified by the qjit's option. + """ + global_pipeline = { + "cancel_inverses": {}, + "merge_rotations": {}, + } + + local_pipeline = { + "merge_rotations": {}, + } + + @qjit(keep_intermediate=True, circuit_transform_pipeline=global_pipeline) + def global_wf(): + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def g(x): + qml.RX(x, wires=[0]) + qml.Hadamard(wires=[1]) + qml.Hadamard(wires=[1]) + return qml.expval(qml.PauliY(wires=0)) + + @pipeline(pass_pipeline=local_pipeline) + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def h(x): + qml.RX(x, wires=[0]) + qml.Hadamard(wires=[1]) + qml.Hadamard(wires=[1]) + return qml.expval(qml.PauliY(wires=0)) + + return g(1.2), h(1.2) + + # CHECK: transform_named_sequence + # CHECK: _:AbstractTransformMod() = apply_registered_pass[ + # CHECK: options=func-name=g_transformed1 + # CHECK: pass_name=remove-chained-self-inverse + # CHECK: ] + # CHECK: _:AbstractTransformMod() = apply_registered_pass[ + # CHECK: options=func-name=g_transformed1 + # CHECK: pass_name=merge-rotation + # CHECK: ] + # CHECK: _:AbstractTransformMod() = apply_registered_pass[ + # CHECK: options=func-name=h_transformed0 + # CHECK-NOT: pass_name=remove-chained-self-inverse + # CHECK: pass_name=merge-rotation + # CHECK: ] + print_jaxpr(global_wf) + + # CHECK: transform.named_sequence @__transform_main + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=g_transformed1"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=g_transformed1"} + # CHECK-NOT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=h_transformed0"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=h_transformed0"} + # CHECK-NEXT: transform.yield + print_mlir(global_wf) + + # CHECK: func.func public @jit_global_wf() + # CHECK {{%.+}} = call @g_transformed1( + # CHECK {{%.+}} = call @h_transformed0( + # CHECK: func.func private @g_transformed1 + # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit + # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + # CHECK: func.func private @h_transformed0 + # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit + # CHECK: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + # CHECK: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit + global_wf() + flush_peephole_opted_mlir_to_iostream(global_wf) + + +test_pipeline_lowering_globloc_override() + + # # cancel_inverses # @@ -118,11 +379,11 @@ def h(x: float): # CHECK: transform_named_sequence # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=f_cancel_inverses + # CHECK: options=func-name=f_cancel_inverses0 # CHECK: pass_name=remove-chained-self-inverse # CHECK: ] # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=g_cancel_inverses + # CHECK: options=func-name=g_cancel_inverses1 # CHECK: pass_name=remove-chained-self-inverse # CHECK: ] # CHECK-NOT: _:AbstractTransformMod() = apply_registered_pass[ @@ -132,8 +393,8 @@ def h(x: float): # CHECK: module @test_cancel_inverses_tracing_and_lowering_workflow # CHECK: transform.named_sequence @__transform_main - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_cancel_inverses"} - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=g_cancel_inverses"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_cancel_inverses0"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=g_cancel_inverses1"} # CHECK-NOT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=h_cancel_inverses"} # CHECK-NEXT: transform.yield print_mlir(test_cancel_inverses_tracing_and_lowering_workflow, 1.1) @@ -162,14 +423,14 @@ def test_cancel_inverses_tracing_and_lowering_outside_qjit_workflow(xx: float): # CHECK: transform_named_sequence # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=f_cancel_inverses + # CHECK: options=func-name=f_cancel_inverses0 # CHECK: pass_name=remove-chained-self-inverse # CHECK: ] print_jaxpr(test_cancel_inverses_tracing_and_lowering_outside_qjit_workflow, 1.1) # CHECK: module @test_cancel_inverses_tracing_and_lowering_outside_qjit_workflow # CHECK: transform.named_sequence @__transform_main - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_cancel_inverses"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_cancel_inverses0"} # CHECK-NEXT: transform.yield print_mlir(test_cancel_inverses_tracing_and_lowering_outside_qjit_workflow, 1.1) @@ -254,7 +515,7 @@ def f(x: float): # CHECK-LABEL: public @jit_test_cancel_inverses_keep_original_workflow0 # CHECK: {{%.+}} = call @f({{%.+}}) - # CHECK-NOT: {{%.+}} = call @f_cancel_inverses({{%.+}}) + # CHECK-NOT: {{%.+}} = call @f_cancel_inverses0({{%.+}}) # CHECK-LABEL: private @f({{%.+}}) # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit # CHECK-NEXT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit @@ -268,9 +529,9 @@ def test_cancel_inverses_keep_original_workflow0(): flush_peephole_opted_mlir_to_iostream(test_cancel_inverses_keep_original_workflow0) # CHECK-LABEL: public @jit_test_cancel_inverses_keep_original_workflow1 - # CHECK: {{%.+}} = call @f_cancel_inverses({{%.+}}) + # CHECK: {{%.+}} = call @f_cancel_inverses0({{%.+}}) # CHECK-NOT: {{%.+}} = call @f({{%.+}}) - # CHECK-LABEL: private @f_cancel_inverses({{%.+}}) + # CHECK-LABEL: private @f_cancel_inverses0({{%.+}}) # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit @@ -284,12 +545,12 @@ def test_cancel_inverses_keep_original_workflow1(): # CHECK-LABEL: public @jit_test_cancel_inverses_keep_original_workflow2 # CHECK: {{%.+}} = call @f({{%.+}}) - # CHECK: {{%.+}} = call @f_cancel_inverses({{%.+}}) + # CHECK: {{%.+}} = call @f_cancel_inverses0({{%.+}}) # CHECK-LABEL: private @f({{%.+}}) # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit # CHECK-NEXT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit # CHECK-NEXT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit - # CHECK-LABEL: private @f_cancel_inverses({{%.+}}) + # CHECK-LABEL: private @f_cancel_inverses0({{%.+}}) # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit diff --git a/frontend/test/pytest/test_peephole_optimizations.py b/frontend/test/pytest/test_peephole_optimizations.py index 810f254c11..40bbf0fb28 100644 --- a/frontend/test/pytest/test_peephole_optimizations.py +++ b/frontend/test/pytest/test_peephole_optimizations.py @@ -19,7 +19,7 @@ import pytest from catalyst import qjit -from catalyst.passes import cancel_inverses +from catalyst.passes import cancel_inverses, pipeline # pylint: disable=missing-function-docstring @@ -91,6 +91,42 @@ def g(x): assert np.allclose(workflow()[0], workflow()[1]) +@pytest.mark.parametrize("theta", [42.42]) +def test_pipeline_functionality(capfd, theta, backend): + """ + Test that the @pipeline decorator does not change functionality + when all the passes in the pipeline does not change functionality. + """ + my_pipeline = { + "cancel_inverses": {}, + "merge_rotations": {"my-option": "aloha"}, + } + + @qjit + def workflow(): + @qml.qnode(qml.device(backend, wires=2)) + def f(x): + qml.RX(x, wires=[0]) + qml.Hadamard(wires=[1]) + qml.Hadamard(wires=[1]) + return qml.expval(qml.PauliY(wires=0)) + + no_pipeline_result = f(theta) + pipeline_result = pipeline(pass_pipeline=my_pipeline)(f)(theta) + + return no_pipeline_result, pipeline_result + + res = workflow() + assert np.allclose(res[0], res[1]) + + # TODO: the boilerplate merge rotation pass prints out different messages based on + # the pass option. + # The purpose is to test the integration of pass options with pipeline decorator. + # Remove the string check when merge rotation becomes the actual merge rotation pass. + output_message = capfd.readouterr().err + assert output_message == "merge rotation pass, aloha!\n" + + ### Test bad usages of pass decorators ### def test_cancel_inverses_bad_usages(): """ @@ -101,6 +137,12 @@ def test_cancel_inverses_not_on_qnode(): def classical_func(): return 42.42 + with pytest.raises( + TypeError, + match="A QNode is expected, got the classical function", + ): + pipeline(classical_func) + with pytest.raises( TypeError, match="A QNode is expected, got the classical function", diff --git a/mlir/include/Quantum/Transforms/Passes.h b/mlir/include/Quantum/Transforms/Passes.h index 96c748bd49..2e241bcd7d 100644 --- a/mlir/include/Quantum/Transforms/Passes.h +++ b/mlir/include/Quantum/Transforms/Passes.h @@ -28,5 +28,6 @@ std::unique_ptr createAdjointLoweringPass(); std::unique_ptr createRemoveChainedSelfInversePass(); std::unique_ptr createAnnotateFunctionPass(); std::unique_ptr createSplitMultipleTapesPass(); +std::unique_ptr createMergeRotationPass(); } // namespace catalyst diff --git a/mlir/include/Quantum/Transforms/Passes.td b/mlir/include/Quantum/Transforms/Passes.td index 0e720a947e..202347c150 100644 --- a/mlir/include/Quantum/Transforms/Passes.td +++ b/mlir/include/Quantum/Transforms/Passes.td @@ -91,7 +91,11 @@ def SplitMultipleTapesPass : Pass<"split-multiple-tapes"> { // ----- Quantum circuit transformation passes begin ----- // // For example, automatic compiler peephole opts, etc. -class QuantumCircuitTransformationPass : Pass { +class QuantumCircuitTransformationPassBase { + list