Skip to content

Commit

Permalink
Support SwitchCaseOp that will be released in qiskit-terra 0.24.0 (#…
Browse files Browse the repository at this point in the history
…1778)

* Add support of `SwitchCaseOp`

Qiskit-Terra added a new instruction of control-flow `SwitchCaseOp`.
This commit enables Aer to simulate `SwitchCaseOp` by converting
its conditions and bodies with `AerMark` and `AerJump`.

* add switch_case temporarily until qiskit-terra 0.24.0 is released.
* simplify switch compilation
* use SwitchCaseOp of terra
* fix lint errors
  • Loading branch information
hhorii authored May 16, 2023
1 parent 334c894 commit 11b6f44
Show file tree
Hide file tree
Showing 3 changed files with 334 additions and 3 deletions.
96 changes: 94 additions & 2 deletions qiskit_aer/backends/aer_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Compier to convert Qiskit control-flow to Aer backend.
"""

import collections
import itertools
from copy import copy
from typing import List
Expand All @@ -21,7 +22,15 @@
from qiskit.extensions import Initialize
from qiskit.providers.options import Options
from qiskit.pulse import Schedule, ScheduleBlock
from qiskit.circuit.controlflow import WhileLoopOp, ForLoopOp, IfElseOp, BreakLoopOp, ContinueLoopOp
from qiskit.circuit.controlflow import (
WhileLoopOp,
ForLoopOp,
IfElseOp,
BreakLoopOp,
ContinueLoopOp,
SwitchCaseOp,
CASE_DEFAULT,
)
from qiskit.compiler import transpile
from qiskit.qobj import QobjExperimentHeader
from qiskit_aer.aererror import AerError
Expand Down Expand Up @@ -113,7 +122,14 @@ def _is_dynamic(circuit, optype=None):
if not isinstance(circuit, QuantumCircuit):
return False

controlflow_types = (WhileLoopOp, ForLoopOp, IfElseOp, BreakLoopOp, ContinueLoopOp)
controlflow_types = (
WhileLoopOp,
ForLoopOp,
IfElseOp,
BreakLoopOp,
ContinueLoopOp,
SwitchCaseOp,
)

# Check via optypes
if isinstance(optype, set):
Expand Down Expand Up @@ -158,6 +174,10 @@ def _inline_circuit(self, circ, continue_label, break_label, bit_map=None):
ret.barrier()
self._inline_if_else_op(instruction, continue_label, break_label, ret, bit_map)
ret.barrier()
elif isinstance(instruction.operation, SwitchCaseOp):
ret.barrier()
self._inline_switch_case_op(instruction, continue_label, break_label, ret, bit_map)
ret.barrier()
elif isinstance(instruction.operation, BreakLoopOp):
ret._append(
AerJump(break_label, ret.num_qubits, ret.num_clbits), ret.qubits, ret.clbits
Expand Down Expand Up @@ -328,6 +348,78 @@ def _inline_if_else_op(self, instruction, continue_label, break_label, parent, b

parent.append(AerMark(if_end_label, len(qargs), len(mark_cargs)), qargs, mark_cargs)

def _inline_switch_case_op(self, instruction, continue_label, break_label, parent, bit_map):
"""inline switch cases with jump and mark instructions"""
cases = instruction.operation.cases_specifier()

self._last_flow_id += 1
switch_id = self._last_flow_id
switch_name = f"switch_{switch_id}"

qargs = [bit_map[q] for q in instruction.qubits]
cargs = [bit_map[c] for c in instruction.clbits]
mark_cargs = (
set(cargs + [bit_map[instruction.operation.target]])
if isinstance(instruction.operation.target, Clbit)
else set(cargs + [bit_map[c] for c in instruction.operation.target])
) - set(instruction.clbits)

switch_end_label = f"{switch_name}_end"
case_default_label = None
CaseData = collections.namedtuple("CaseData", ["label", "args_list", "bit_map", "body"])
case_data_list = []
for i, case in enumerate(cases):
if case_default_label is not None:
raise AerError("cases after the default are unreachable")

case_data = CaseData(
label=f"{switch_name}_{i}",
args_list=[
self._convert_c_if_args((instruction.operation.target, switch_val), bit_map)
if switch_val != CASE_DEFAULT
else []
for switch_val in case[0]
],
bit_map={
inner: bit_map[outer]
for inner, outer in itertools.chain(
zip(case[1].qubits, instruction.qubits),
zip(case[1].clbits, instruction.clbits),
)
},
body=case[1],
)
case_data_list.append(case_data)
if CASE_DEFAULT in case[0]:
case_default_label = case_data.label

if case_default_label is None:
case_default_label = switch_end_label

for case_data in case_data_list:
for case_args in case_data.args_list:
if len(case_args) > 0:
parent.append(
AerJump(case_data.label, len(qargs), len(mark_cargs)).c_if(*case_args),
qargs,
mark_cargs,
)

parent.append(AerJump(case_default_label, len(qargs), len(mark_cargs)), qargs, mark_cargs)

for case_data in case_data_list:
parent.append(AerMark(case_data.label, len(qargs), len(mark_cargs)), qargs, mark_cargs)
parent.append(
self._inline_circuit(
case_data.body, continue_label, break_label, case_data.bit_map
),
qargs,
cargs,
)
parent.append(AerJump(switch_end_label, len(qargs), len(mark_cargs)), qargs, mark_cargs)

parent.append(AerMark(switch_end_label, len(qargs), len(mark_cargs)), qargs, mark_cargs)


def compile_circuit(circuits, basis_gates=None, optypes=None):
"""
Expand Down
6 changes: 6 additions & 0 deletions releasenotes/notes/support_switch-41603d87cb8358fb.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Add support of :class:`qiskit.circuit.controlflow.SwitchCaseOp` that was introduced
in qiskit-terra 0.24.0. The instruction is converted to multiple instructions of
:class:`~.AerMark` and :class:`~.AerJump` with `c_if` conditions.
235 changes: 234 additions & 1 deletion test/terra/backends/aer_simulator/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from test.terra.backends.simulator_test_case import SimulatorTestCase, supported_methods
from qiskit_aer import AerSimulator
from qiskit import QuantumCircuit, transpile
from qiskit.circuit import Parameter, Qubit, QuantumRegister, ClassicalRegister
from qiskit.circuit import Parameter, Qubit, Clbit, QuantumRegister, ClassicalRegister
from qiskit.circuit.controlflow import *
from qiskit_aer.library.default_qubits import default_qubits
from qiskit_aer.library.control_flow_instructions import AerMark, AerJump
Expand Down Expand Up @@ -630,3 +630,236 @@ def test_transpile_break_and_continue_loop(self, method):
transpiled = transpile(qc, backend)
result = backend.run(transpiled, method=method, shots=100).result()
self.assertEqual(result.get_counts(), {"1": 100})

@data("statevector", "density_matrix", "matrix_product_state")
def test_switch_clbit(self, method):
"""Test that a switch statement can be constructed with a bit as a condition."""

backend = self.backend(method=method)

qubit = Qubit()
clbit = Clbit()
case0 = QuantumCircuit([qubit, clbit])
case0.x(0)
case1 = QuantumCircuit([qubit, clbit])
case1.h(0)

op = SwitchCaseOp(clbit, [(False, case0), (True, case1)])

qc0 = QuantumCircuit([qubit, clbit])
qc0.measure(qubit, clbit)
qc0.append(op, [qubit], [clbit])
qc0.measure_all()

qc0_expected = QuantumCircuit([qubit, clbit])
qc0_expected.measure(qubit, clbit)
qc0_expected.append(case0, [qubit], [clbit])
qc0_expected.measure_all()
qc0_expected = transpile(qc0_expected, backend)

ret0 = backend.run(qc0, shots=10000, seed_simulator=1).result()
ret0_expected = backend.run(qc0_expected, shots=10000, seed_simulator=1).result()
self.assertSuccess(ret0)
self.assertEqual(ret0.get_counts(), ret0_expected.get_counts())

qc1 = QuantumCircuit([qubit, clbit])
qc1.x(0)
qc1.measure(qubit, clbit)
qc1.append(op, [qubit], [clbit])
qc1.measure_all()

qc1_expected = QuantumCircuit([qubit, clbit])
qc1_expected.x(0)
qc1_expected.measure(qubit, clbit)
qc1_expected.append(case1, [qubit], [clbit])
qc1_expected.measure_all()
qc1_expected = transpile(qc1_expected, backend)

ret1 = backend.run(qc1, shots=10000, seed_simulator=1).result()
ret1_expected = backend.run(qc1_expected, shots=10000, seed_simulator=1).result()
self.assertSuccess(ret1)
self.assertEqual(ret1.get_counts(), ret1_expected.get_counts())

@data("statevector", "density_matrix", "matrix_product_state")
def test_switch_register(self, method):
"""Test that a switch statement can be constructed with a register as a condition."""

backend = self.backend(method=method, seed_simulator=1)

qubit0 = Qubit()
qubit1 = Qubit()
qubit2 = Qubit()
creg = ClassicalRegister(2)
case1 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
case1.x(0)
case2 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
case2.x(1)
case3 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
case3.x(2)

op = SwitchCaseOp(creg, [(0, case1), (1, case2), (2, case3)])

qc0 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
qc0.measure(0, creg[0])
qc0.append(op, [qubit0, qubit1, qubit2], creg)
qc0.measure_all()

ret0 = backend.run(qc0, shots=100).result()
self.assertSuccess(ret0)
self.assertEqual(ret0.get_counts(), {"001 00": 100})

qc1 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
qc1.x(0)
qc1.measure(0, creg[0])
qc1.append(op, [qubit0, qubit1, qubit2], creg)
qc1.measure_all()

ret1 = backend.run(qc1, shots=100).result()
self.assertSuccess(ret1)
self.assertEqual(ret1.get_counts(), {"011 01": 100})

qc2 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
qc2.x(1)
qc2.measure(0, creg[0])
qc2.measure(1, creg[1])
qc2.append(op, [qubit0, qubit1, qubit2], creg)
qc2.measure_all()

ret2 = backend.run(qc2, shots=100).result()
self.assertSuccess(ret2)
self.assertEqual(ret2.get_counts(), {"110 10": 100})

qc3 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
qc3.x(0)
qc3.x(1)
qc3.measure(0, creg[0])
qc3.measure(1, creg[1])
qc3.append(op, [qubit0, qubit1, qubit2], creg)
qc3.measure_all()

ret3 = backend.run(qc3, shots=100).result()
self.assertSuccess(ret3)
self.assertEqual(ret3.get_counts(), {"011 11": 100})

@data("statevector", "density_matrix", "matrix_product_state")
def test_switch_with_default(self, method):
"""Test that a switch statement can be constructed with a default case at the end."""

backend = self.backend(method=method, seed_simulator=1)

qubit0 = Qubit()
qubit1 = Qubit()
qubit2 = Qubit()
creg = ClassicalRegister(2)
case1 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
case1.x(0)
case2 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
case2.x(1)
case3 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
case3.x(2)

op = SwitchCaseOp(creg, [(0, case1), (1, case2), (CASE_DEFAULT, case3)])

qc0 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
qc0.measure(0, creg[0])
qc0.append(op, [qubit0, qubit1, qubit2], creg)
qc0.measure_all()

ret0 = backend.run(qc0, shots=100).result()
self.assertSuccess(ret0)
self.assertEqual(ret0.get_counts(), {"001 00": 100})

qc1 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
qc1.x(0)
qc1.measure(0, creg[0])
qc1.append(op, [qubit0, qubit1, qubit2], creg)
qc1.measure_all()

ret1 = backend.run(qc1, shots=100).result()
self.assertSuccess(ret1)
self.assertEqual(ret1.get_counts(), {"011 01": 100})

qc2 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
qc2.x(1)
qc2.measure(0, creg[0])
qc2.measure(1, creg[1])
qc2.append(op, [qubit0, qubit1, qubit2], creg)
qc2.measure_all()

ret2 = backend.run(qc2, shots=100).result()
self.assertSuccess(ret2)
self.assertEqual(ret2.get_counts(), {"110 10": 100})

qc3 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
qc3.x(0)
qc3.x(1)
qc3.measure(0, creg[0])
qc3.measure(1, creg[1])
qc3.append(op, [qubit0, qubit1, qubit2], creg)
qc3.measure_all()

ret3 = backend.run(qc3, shots=100).result()
self.assertSuccess(ret3)
self.assertEqual(ret3.get_counts(), {"111 11": 100})

@data("statevector", "density_matrix", "matrix_product_state")
def test_switch_multiple_cases_to_same_block(self, method):
"""Test that it is possible to add multiple cases that apply to the same block, if they are
given as a compound value. This is an allowed special case of block fall-through."""

backend = self.backend(method=method, seed_simulator=1)

qubit0 = Qubit()
qubit1 = Qubit()
qubit2 = Qubit()
creg = ClassicalRegister(2)
case1 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
case1.x(0)
case2 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
case2.x(1)

creg = ClassicalRegister(2)

op = SwitchCaseOp(creg, [(0, case1), ((1, 2), case2)])

qc0 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
qc0.measure(0, creg[0])
qc0.append(op, [qubit0, qubit1, qubit2], creg)
qc0.measure_all()

ret0 = backend.run(qc0, shots=100).result()
self.assertSuccess(ret0)
self.assertEqual(ret0.get_counts(), {"001 00": 100})

qc1 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
qc1.x(0)
qc1.measure(0, creg[0])
qc1.append(op, [qubit0, qubit1, qubit2], creg)
qc1.measure_all()

ret1 = backend.run(qc1, shots=100).result()
self.assertSuccess(ret1)
self.assertEqual(ret1.get_counts(), {"011 01": 100})

qc2 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
qc2.x(1)
qc2.measure(0, creg[0])
qc2.measure(1, creg[1])
qc2.append(op, [qubit0, qubit1, qubit2], creg)
qc2.measure_all()

ret2 = backend.run(qc2, shots=100).result()
self.assertSuccess(ret2)
self.assertEqual(ret2.get_counts(), {"000 10": 100})

qc3 = QuantumCircuit([qubit0, qubit1, qubit2], creg)
qc3.x(0)
qc3.x(1)
qc3.measure(0, creg[0])
qc3.measure(1, creg[1])
qc3.append(op, [qubit0, qubit1, qubit2], creg)
qc3.measure_all()

ret3 = backend.run(qc3, shots=100).result()
self.assertSuccess(ret3)
self.assertEqual(ret3.get_counts(), {"011 11": 100})

0 comments on commit 11b6f44

Please sign in to comment.