Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move utility functions _inverse_pattern and _get_ordered_swap to Rust #12327

Merged
merged 14 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/accelerate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub mod isometry;
pub mod nlayout;
pub mod optimize_1q_gates;
pub mod pauli_exp_val;
pub mod permutation;
pub mod results;
pub mod sabre;
pub mod sampled_exp_val;
Expand Down
120 changes: 120 additions & 0 deletions crates/accelerate/src/permutation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// This code is part of Qiskit.
//
// (C) Copyright IBM 2024
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use ndarray::{Array1, ArrayView1};
use numpy::PyArrayLike1;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use std::vec::Vec;

fn validate_permutation(pattern: &ArrayView1<i64>) -> PyResult<()> {
let n = pattern.len();
let mut seen: Vec<bool> = vec![false; n];

for &x in pattern {
if x < 0 {
return Err(PyValueError::new_err(
"Invalid permutation: input contains a negative number.",
));
}

if x as usize >= n {
return Err(PyValueError::new_err(format!(
"Invalid permutation: input has length {} and contains {}.",
n, x
)));
}

if seen[x as usize] {
return Err(PyValueError::new_err(format!(
"Invalid permutation: input contains {} more than once.",
x
)));
}

seen[x as usize] = true;
}

Ok(())
}

fn invert(pattern: &ArrayView1<i64>) -> Array1<usize> {
let mut inverse: Array1<usize> = Array1::zeros(pattern.len());
pattern.iter().enumerate().for_each(|(ii, &jj)| {
inverse[jj as usize] = ii;
});
inverse
}

fn get_ordered_swap(pattern: &ArrayView1<i64>) -> Vec<(i64, i64)> {
let mut permutation: Vec<usize> = pattern.iter().map(|&x| x as usize).collect();
let mut index_map = invert(pattern);

let n = permutation.len();
let mut swaps: Vec<(i64, i64)> = Vec::with_capacity(n);
for ii in 0..n {
let val = permutation[ii];
if val == ii {
continue;
}
let jj = index_map[ii];
swaps.push((ii as i64, jj as i64));
(permutation[ii], permutation[jj]) = (permutation[jj], permutation[ii]);
index_map[val] = jj;
index_map[ii] = ii;
}

swaps[..].reverse();
swaps
}

/// Checks whether an array of size N is a permutation of 0, 1, ..., N - 1.
#[pyfunction]
#[pyo3(signature = (pattern))]
fn _validate_permutation(py: Python, pattern: PyArrayLike1<i64>) -> PyResult<PyObject> {
let view = pattern.as_array();
validate_permutation(&view)?;
Cryoris marked this conversation as resolved.
Show resolved Hide resolved
Ok(py.None())
}

/// Finds inverse of a permutation pattern.
#[pyfunction]
#[pyo3(signature = (pattern))]
fn _inverse_pattern(py: Python, pattern: PyArrayLike1<i64>) -> PyResult<PyObject> {
let view = pattern.as_array();
let inverse_i64: Vec<i64> = invert(&view).iter().map(|&x| x as i64).collect();
Ok(inverse_i64.to_object(py))
}

/// Sorts the input permutation by iterating through the permutation list
/// and putting each element to its correct position via a SWAP (if it's not
/// at the correct position already). If ``n`` is the length of the input
/// permutation, this requires at most ``n`` SWAPs.
///
/// More precisely, if the input permutation is a cycle of length ``m``,
/// then this creates a quantum circuit with ``m-1`` SWAPs (and of depth ``m-1``);
/// if the input permutation consists of several disjoint cycles, then each cycle
/// is essentially treated independently.
#[pyfunction]
#[pyo3(signature = (permutation_in))]
fn _get_ordered_swap(py: Python, permutation_in: PyArrayLike1<i64>) -> PyResult<PyObject> {
let view = permutation_in.as_array();
Ok(get_ordered_swap(&view).to_object(py))
}

#[pymodule]
pub fn permutation(m: &Bound<PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(_validate_permutation, m)?)?;
m.add_function(wrap_pyfunction!(_inverse_pattern, m)?)?;
m.add_function(wrap_pyfunction!(_get_ordered_swap, m)?)?;
Ok(())
}
9 changes: 5 additions & 4 deletions crates/pyext/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ use qiskit_accelerate::{
convert_2q_block_matrix::convert_2q_block_matrix, dense_layout::dense_layout,
error_map::error_map, euler_one_qubit_decomposer::euler_one_qubit_decomposer,
isometry::isometry, nlayout::nlayout, optimize_1q_gates::optimize_1q_gates,
pauli_exp_val::pauli_expval, results::results, sabre::sabre, sampled_exp_val::sampled_exp_val,
sparse_pauli_op::sparse_pauli_op, stochastic_swap::stochastic_swap,
two_qubit_decompose::two_qubit_decompose, uc_gate::uc_gate, utils::utils,
vf2_layout::vf2_layout,
pauli_exp_val::pauli_expval, permutation::permutation, results::results, sabre::sabre,
sampled_exp_val::sampled_exp_val, sparse_pauli_op::sparse_pauli_op,
stochastic_swap::stochastic_swap, two_qubit_decompose::two_qubit_decompose, uc_gate::uc_gate,
utils::utils, vf2_layout::vf2_layout,
};

#[pymodule]
Expand All @@ -36,6 +36,7 @@ fn _accelerate(m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pymodule!(nlayout))?;
m.add_wrapped(wrap_pymodule!(optimize_1q_gates))?;
m.add_wrapped(wrap_pymodule!(pauli_expval))?;
m.add_wrapped(wrap_pymodule!(permutation))?;
m.add_wrapped(wrap_pymodule!(results))?;
m.add_wrapped(wrap_pymodule!(sabre))?;
m.add_wrapped(wrap_pymodule!(sampled_exp_val))?;
Expand Down
1 change: 1 addition & 0 deletions qiskit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
sys.modules["qiskit._accelerate.stochastic_swap"] = qiskit._accelerate.stochastic_swap
sys.modules["qiskit._accelerate.two_qubit_decompose"] = qiskit._accelerate.two_qubit_decompose
sys.modules["qiskit._accelerate.vf2_layout"] = qiskit._accelerate.vf2_layout
sys.modules["qiskit._accelerate.permutation"] = qiskit._accelerate.permutation

from qiskit.exceptions import QiskitError, MissingOptionalLibraryError

Expand Down
36 changes: 6 additions & 30 deletions qiskit/synthesis/permutation/permutation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,12 @@

"""Utility functions for handling permutations."""


def _get_ordered_swap(permutation_in):
"""Sorts the input permutation by iterating through the permutation list
and putting each element to its correct position via a SWAP (if it's not
at the correct position already). If ``n`` is the length of the input
permutation, this requires at most ``n`` SWAPs.

More precisely, if the input permutation is a cycle of length ``m``,
then this creates a quantum circuit with ``m-1`` SWAPs (and of depth ``m-1``);
if the input permutation consists of several disjoint cycles, then each cycle
is essentially treated independently.
"""
permutation = list(permutation_in[:])
swap_list = []
index_map = _inverse_pattern(permutation_in)
for i, val in enumerate(permutation):
if val != i:
j = index_map[i]
swap_list.append((i, j))
permutation[i], permutation[j] = permutation[j], permutation[i]
index_map[val] = j
index_map[i] = i
swap_list.reverse()
return swap_list


def _inverse_pattern(pattern):
"""Finds inverse of a permutation pattern."""
b_map = {pos: idx for idx, pos in enumerate(pattern)}
return [b_map[pos] for pos in range(len(pattern))]
# pylint: disable=unused-import
from qiskit._accelerate.permutation import (
_inverse_pattern,
_get_ordered_swap,
_validate_permutation,
)


def _pattern_to_cycles(pattern):
Expand Down
48 changes: 46 additions & 2 deletions test/python/synthesis/test_permutation_synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,31 @@
synth_permutation_basic,
synth_permutation_reverse_lnn_kms,
)
from qiskit.synthesis.permutation.permutation_utils import _get_ordered_swap
from qiskit.synthesis.permutation.permutation_utils import (
_inverse_pattern,
_get_ordered_swap,
_validate_permutation,
)
from test import QiskitTestCase # pylint: disable=wrong-import-order


@ddt
class TestPermutationSynthesis(QiskitTestCase):
"""Test the permutation synthesis functions."""

@data(4, 5, 10, 15, 20)
def test_inverse_pattern(self, width):
"""Test _inverse_pattern function produces correct index map."""
np.random.seed(1)
for _ in range(5):
pattern = np.random.permutation(width)
inverse = _inverse_pattern(pattern)
for ii, jj in enumerate(pattern):
self.assertTrue(inverse[jj] == ii)

@data(4, 5, 10, 15, 20)
def test_get_ordered_swap(self, width):
"""Test get_ordered_swap function produces correct swap list."""
"""Test _get_ordered_swap function produces correct swap list."""
np.random.seed(1)
for _ in range(5):
pattern = np.random.permutation(width)
Expand All @@ -46,6 +60,36 @@ def test_get_ordered_swap(self, width):
self.assertTrue(np.array_equal(pattern, output))
self.assertLess(len(swap_list), width)

@data(10, 20)
def test_invalid_permutations(self, width):
"""Check that _validate_permutation raises exceptions when the
input is not a permutation."""
np.random.seed(1)
for _ in range(5):
pattern = np.random.permutation(width)

pattern_out_of_range = np.copy(pattern)
pattern_out_of_range[0] = -1
with self.assertRaises(ValueError) as exc:
_validate_permutation(pattern_out_of_range)
self.assertIn("input contains a negative number", str(exc.exception))

pattern_out_of_range = np.copy(pattern)
pattern_out_of_range[0] = width
with self.assertRaises(ValueError) as exc:
_validate_permutation(pattern_out_of_range)
self.assertIn(
"input has length {0} and contains {0}".format(width), str(exc.exception)
)

pattern_duplicate = np.copy(pattern)
pattern_duplicate[-1] = pattern[0]
with self.assertRaises(ValueError) as exc:
_validate_permutation(pattern_duplicate)
self.assertIn(
"input contains {} more than once".format(pattern[0]), str(exc.exception)
)

@data(4, 5, 10, 15, 20)
def test_synth_permutation_basic(self, width):
"""Test synth_permutation_basic function produces the correct
Expand Down
Loading