Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SampleD with a non-spherical Gaussian convolution matrix (with standard base and center 0) #407

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
147 changes: 147 additions & 0 deletions src/integer/mat_z/sample/discrete_gauss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,68 @@ impl MatZ {
MatZ::sample_d(&basis, n, &center, s)
}

/// Samples a (possibly) non-spherical discrete Gaussian using
/// the standard basis and center `0`.
///
/// Parameters:
/// - `n`: specifies the range from which [`MatQ::randomized_rounding`] samples
/// - `sigma_`: specifies the positive definite Gaussian convolution matrix
/// with which the *intermediate* continuous Gaussian is sampled before
/// the randomized rounding is applied. Here `sigma_ = sqrt(sigma^2 - r^2*I)`
/// where sigma is the target convolution matrix. The root can be computed using
/// the [`MatQ::cholesky_decomposition`].
/// - `r`: specifies the rounding parameter for [`MatQ::randomized_rounding`].
///
/// Returns a lattice vector sampled according to the discrete Gaussian distribution.
///
/// # Examples
/// ```
/// use qfall_math::integer::MatZ;
/// use qfall_math::rational::{Q, MatQ};
/// use std::str::FromStr;
/// use crate::qfall_math::traits::Pow;
///
/// let convolution_matrix = MatQ::from_str("[[100,1],[1,17]]").unwrap();
/// let r = Q::from(4);
///
/// let sigma_ = convolution_matrix - r.pow(2).unwrap() * MatQ::identity(2, 2);
///
/// let sample = MatZ::sample_d_common_non_spherical(16, &sigma_.cholesky_decomposition(), r).unwrap();
/// ```
///
/// # Errors and Failures
/// - Returns a [`MathError`] of type [`InvalidIntegerInput`](MathError::InvalidIntegerInput)
/// if the `n <= 1` or `r <= 0`.
///
/// # Panics ...
/// - if `sigma_` is not a square matrix.
///
/// This function implements SampleD according to Algorithm 1. in \[2\].
/// - \[2\] Peikert, Chris.
/// "An efficient and parallel Gaussian sampler for lattices.
/// In Annual Cryptology Conference, pp. 80-97. Berlin, Heidelberg: Springer
/// Berlin Heidelberg, 2010.
/// <https://link.springer.com/chapter/10.1007/978-3-642-14623-7_5>
pub fn sample_d_common_non_spherical(
n: impl Into<Z>,
sigma_: &MatQ,
r: impl Into<Q>,
) -> Result<Self, MathError> {
assert!(sigma_.is_square());
let r = r.into();

// sample a continuous Gaussian centered around `0` in every dimension with
// gaussian parameter `1`.
let d_1 = MatQ::sample_gauss_same_center(sigma_.get_num_columns(), 1, 0, 1)?;

// compute a continuous Gaussian centered around `0` in every dimension with
// convolution matrix `b_2` (the cholesky decomposition we computed)
let x_2 = sigma_ * d_1;

// perform randomized rounding
x_2.randomized_rounding(r, n)
}

/// SampleD samples a discrete Gaussian from the lattice with a provided `basis`.
///
/// We do not check whether `basis` is actually a basis or whether `basis_gso` is
Expand Down Expand Up @@ -306,3 +368,88 @@ mod test_sample_d {
let _ = MatZ::sample_d_common(10, 1024, 1.25f32).unwrap();
}
}

#[cfg(test)]
mod test_sample_d_common_non_spherical {
use crate::{
integer::{MatZ, Z},
rational::{MatQ, Q},
traits::{GetNumRows, Pow},
};
use std::str::FromStr;

/// Checks whether `sample_d_common_non_spherical` is available for all types
/// implementing [`Into<Z>`], i.e. u8, u16, u32, u64, i8, ...
/// or [`Into<Q>`], i.e. u8, i16, f32, Z, Q, ...
/// or [`Into<MatQ>`], i.e. MatQ, MatZ
#[test]
fn availability() {
let r = Q::from(8);
let convolution_matrix = MatQ::from_str("[[100,1],[1,65]]").unwrap();
let convolution_matrix = (convolution_matrix - r.pow(2).unwrap() * MatQ::identity(2, 2))
.cholesky_decomposition();

let _ = MatZ::sample_d_common_non_spherical(16, &convolution_matrix, 8).unwrap();

let _ = MatZ::sample_d_common_non_spherical(16u16, &convolution_matrix, 8_u16).unwrap();
let _ = MatZ::sample_d_common_non_spherical(16u32, &convolution_matrix, 8_u32).unwrap();
let _ = MatZ::sample_d_common_non_spherical(16u64, &convolution_matrix, 8_u64).unwrap();
let _ = MatZ::sample_d_common_non_spherical(16i8, &convolution_matrix, 8_i8).unwrap();
let _ = MatZ::sample_d_common_non_spherical(16i16, &convolution_matrix, 8_i16).unwrap();
let _ = MatZ::sample_d_common_non_spherical(16i32, &convolution_matrix, 8_i32).unwrap();
let _ = MatZ::sample_d_common_non_spherical(16i64, &convolution_matrix, 8_i64).unwrap();
let _ = MatZ::sample_d_common_non_spherical(Z::from(16), &convolution_matrix, Q::from(8))
.unwrap();
let _ = MatZ::sample_d_common_non_spherical(16, &convolution_matrix, Z::from(8)).unwrap();
let _ = MatZ::sample_d_common_non_spherical(16, &convolution_matrix, 8f32).unwrap();
let _ = MatZ::sample_d_common_non_spherical(16, &convolution_matrix, 8f64).unwrap();
}

/// Checks whether the function panics if a non positive-definite matrix is provided.
#[test]
#[should_panic]
fn no_convolution_matrix_1() {
let r = Q::from(4);
let convolution_matrix = MatQ::from_str("[[-1,1],[1,1]]").unwrap();
let convolution_matrix = (convolution_matrix - r.pow(2).unwrap() * MatQ::identity(2, 2))
.cholesky_decomposition();

let _ = MatZ::sample_d_common_non_spherical(16, &convolution_matrix, r).unwrap();
}

/// Checks whether the function panics if a non-square matrix is provided.
/// anymore
#[test]
#[should_panic]
fn not_square() {
let convolution_matrix = MatQ::from_str("[[100,1,1],[1,64,2]]").unwrap();

let _ = MatZ::sample_d_common_non_spherical(16, &convolution_matrix, 8).unwrap();
}

/// Checks whether the function returns an error if `n` or `r` is too small.
#[test]
fn too_small_parameters() {
let convolution_matrix = MatQ::from_str("[[100, 1],[1, 65]]").unwrap();

assert!(MatZ::sample_d_common_non_spherical(16, &convolution_matrix, 0).is_err());
assert!(MatZ::sample_d_common_non_spherical(16, &convolution_matrix, -1).is_err());
assert!(MatZ::sample_d_common_non_spherical(1, &convolution_matrix, 8).is_err());
assert!(MatZ::sample_d_common_non_spherical(-1, &convolution_matrix, 8).is_err());
}

/// Checks whether the dimension of the output matches the provided convolution matrix
#[test]
fn correct_dimensions() {
let convolution_matrix_1 = MatQ::from_str("[[100,1],[1,65]]").unwrap();
let convolution_matrix_2 = MatQ::from_str("[[100,1,0],[1,65,0],[0,0,10000]]").unwrap();

let sample_1 = MatZ::sample_d_common_non_spherical(16, &convolution_matrix_1, 8).unwrap();
let sample_2 = MatZ::sample_d_common_non_spherical(16, &convolution_matrix_2, 8).unwrap();

assert_eq!(2, sample_1.get_num_rows());
assert!(sample_1.is_column_vector());
assert_eq!(3, sample_2.get_num_rows());
assert!(sample_2.is_column_vector());
}
}
Loading