diff --git a/src/integer/mat_z/sample/discrete_gauss.rs b/src/integer/mat_z/sample/discrete_gauss.rs index 69221357..a7fbbc16 100644 --- a/src/integer/mat_z/sample/discrete_gauss.rs +++ b/src/integer/mat_z/sample/discrete_gauss.rs @@ -144,6 +144,68 @@ impl MatZ { MatZ::sample_d(&basis, n, ¢er, s) } + /// Samples a (possibly) non-spherical discrete Gaussian using + /// the standard basis and center `0`. + /// + /// Parameters: + /// - `n`: specifies the range from which [`MatQ::randomized_rounding`] samples + /// - `sigma_`: specifies the positive definite Gaussian convolution matrix + /// with which the *intermediate* continuous Gaussian is sampled before + /// the randomized rounding is applied. Here `sigma_ = sqrt(sigma^2 - r^2*I)` + /// where sigma is the target convolution matrix. The root can be computed using + /// the [`MatQ::cholesky_decomposition`]. + /// - `r`: specifies the rounding parameter for [`MatQ::randomized_rounding`]. + /// + /// Returns a lattice vector sampled according to the discrete Gaussian distribution. + /// + /// # Examples + /// ``` + /// use qfall_math::integer::MatZ; + /// use qfall_math::rational::{Q, MatQ}; + /// use std::str::FromStr; + /// use crate::qfall_math::traits::Pow; + /// + /// let convolution_matrix = MatQ::from_str("[[100,1],[1,17]]").unwrap(); + /// let r = Q::from(4); + /// + /// let sigma_ = convolution_matrix - r.pow(2).unwrap() * MatQ::identity(2, 2); + /// + /// let sample = MatZ::sample_d_common_non_spherical(16, &sigma_.cholesky_decomposition(), r).unwrap(); + /// ``` + /// + /// # Errors and Failures + /// - Returns a [`MathError`] of type [`InvalidIntegerInput`](MathError::InvalidIntegerInput) + /// if the `n <= 1` or `r <= 0`. + /// + /// # Panics ... + /// - if `sigma_` is not a square matrix. + /// + /// This function implements SampleD according to Algorithm 1. in \[2\]. + /// - \[2\] Peikert, Chris. + /// "An efficient and parallel Gaussian sampler for lattices. + /// In Annual Cryptology Conference, pp. 80-97. Berlin, Heidelberg: Springer + /// Berlin Heidelberg, 2010. + /// + pub fn sample_d_common_non_spherical( + n: impl Into, + sigma_: &MatQ, + r: impl Into, + ) -> Result { + assert!(sigma_.is_square()); + let r = r.into(); + + // sample a continuous Gaussian centered around `0` in every dimension with + // gaussian parameter `1`. + let d_1 = MatQ::sample_gauss_same_center(sigma_.get_num_columns(), 1, 0, 1)?; + + // compute a continuous Gaussian centered around `0` in every dimension with + // convolution matrix `b_2` (the cholesky decomposition we computed) + let x_2 = sigma_ * d_1; + + // perform randomized rounding + x_2.randomized_rounding(r, n) + } + /// SampleD samples a discrete Gaussian from the lattice with a provided `basis`. /// /// We do not check whether `basis` is actually a basis or whether `basis_gso` is @@ -306,3 +368,88 @@ mod test_sample_d { let _ = MatZ::sample_d_common(10, 1024, 1.25f32).unwrap(); } } + +#[cfg(test)] +mod test_sample_d_common_non_spherical { + use crate::{ + integer::{MatZ, Z}, + rational::{MatQ, Q}, + traits::{GetNumRows, Pow}, + }; + use std::str::FromStr; + + /// Checks whether `sample_d_common_non_spherical` is available for all types + /// implementing [`Into`], i.e. u8, u16, u32, u64, i8, ... + /// or [`Into`], i.e. u8, i16, f32, Z, Q, ... + /// or [`Into`], i.e. MatQ, MatZ + #[test] + fn availability() { + let r = Q::from(8); + let convolution_matrix = MatQ::from_str("[[100,1],[1,65]]").unwrap(); + let convolution_matrix = (convolution_matrix - r.pow(2).unwrap() * MatQ::identity(2, 2)) + .cholesky_decomposition(); + + let _ = MatZ::sample_d_common_non_spherical(16, &convolution_matrix, 8).unwrap(); + + let _ = MatZ::sample_d_common_non_spherical(16u16, &convolution_matrix, 8_u16).unwrap(); + let _ = MatZ::sample_d_common_non_spherical(16u32, &convolution_matrix, 8_u32).unwrap(); + let _ = MatZ::sample_d_common_non_spherical(16u64, &convolution_matrix, 8_u64).unwrap(); + let _ = MatZ::sample_d_common_non_spherical(16i8, &convolution_matrix, 8_i8).unwrap(); + let _ = MatZ::sample_d_common_non_spherical(16i16, &convolution_matrix, 8_i16).unwrap(); + let _ = MatZ::sample_d_common_non_spherical(16i32, &convolution_matrix, 8_i32).unwrap(); + let _ = MatZ::sample_d_common_non_spherical(16i64, &convolution_matrix, 8_i64).unwrap(); + let _ = MatZ::sample_d_common_non_spherical(Z::from(16), &convolution_matrix, Q::from(8)) + .unwrap(); + let _ = MatZ::sample_d_common_non_spherical(16, &convolution_matrix, Z::from(8)).unwrap(); + let _ = MatZ::sample_d_common_non_spherical(16, &convolution_matrix, 8f32).unwrap(); + let _ = MatZ::sample_d_common_non_spherical(16, &convolution_matrix, 8f64).unwrap(); + } + + /// Checks whether the function panics if a non positive-definite matrix is provided. + #[test] + #[should_panic] + fn no_convolution_matrix_1() { + let r = Q::from(4); + let convolution_matrix = MatQ::from_str("[[-1,1],[1,1]]").unwrap(); + let convolution_matrix = (convolution_matrix - r.pow(2).unwrap() * MatQ::identity(2, 2)) + .cholesky_decomposition(); + + let _ = MatZ::sample_d_common_non_spherical(16, &convolution_matrix, r).unwrap(); + } + + /// Checks whether the function panics if a non-square matrix is provided. + /// anymore + #[test] + #[should_panic] + fn not_square() { + let convolution_matrix = MatQ::from_str("[[100,1,1],[1,64,2]]").unwrap(); + + let _ = MatZ::sample_d_common_non_spherical(16, &convolution_matrix, 8).unwrap(); + } + + /// Checks whether the function returns an error if `n` or `r` is too small. + #[test] + fn too_small_parameters() { + let convolution_matrix = MatQ::from_str("[[100, 1],[1, 65]]").unwrap(); + + assert!(MatZ::sample_d_common_non_spherical(16, &convolution_matrix, 0).is_err()); + assert!(MatZ::sample_d_common_non_spherical(16, &convolution_matrix, -1).is_err()); + assert!(MatZ::sample_d_common_non_spherical(1, &convolution_matrix, 8).is_err()); + assert!(MatZ::sample_d_common_non_spherical(-1, &convolution_matrix, 8).is_err()); + } + + /// Checks whether the dimension of the output matches the provided convolution matrix + #[test] + fn correct_dimensions() { + let convolution_matrix_1 = MatQ::from_str("[[100,1],[1,65]]").unwrap(); + let convolution_matrix_2 = MatQ::from_str("[[100,1,0],[1,65,0],[0,0,10000]]").unwrap(); + + let sample_1 = MatZ::sample_d_common_non_spherical(16, &convolution_matrix_1, 8).unwrap(); + let sample_2 = MatZ::sample_d_common_non_spherical(16, &convolution_matrix_2, 8).unwrap(); + + assert_eq!(2, sample_1.get_num_rows()); + assert!(sample_1.is_column_vector()); + assert_eq!(3, sample_2.get_num_rows()); + assert!(sample_2.is_column_vector()); + } +}