From 2427a275048e598c6d651cce8348a4c55148f235 Mon Sep 17 00:00:00 2001 From: Ethan-000 Date: Thu, 16 Feb 2023 21:34:51 +0000 Subject: [PATCH] chore!: refactor ToRadix to ToRadixLe and ToRadixBe (#58) * refactor ToRadix to ToRadixLe and ToRadixBe * change same code to one function * could be cleaner * change function name sightly * change to one opcode * change let a_dec * small refactor * remove redundant if * remove LE postfix * add `insert_value` method * refactor `insert_value` to usse `.insert` method * refactor `to_radix_outcome` * endianess -> endianness * fix big endian padding * small fix * delete function and add an error * remove semicolons * Merge conflict * fix clippy --------- Co-authored-by: Kevaundray Wedderburn --- acir/src/circuit/directives.rs | 32 +++++++++++++--- acir/src/circuit/opcodes.rs | 10 ++++- acvm/src/lib.rs | 2 + acvm/src/pwg.rs | 23 ++++++++++++ acvm/src/pwg/directives.rs | 69 ++++++++++++++++++++++++++-------- stdlib/src/fallback.rs | 1 + 6 files changed, 114 insertions(+), 23 deletions(-) diff --git a/acir/src/circuit/directives.rs b/acir/src/circuit/directives.rs index f3f5e6aca..f9550f88f 100644 --- a/acir/src/circuit/directives.rs +++ b/acir/src/circuit/directives.rs @@ -42,11 +42,12 @@ pub enum Directive { bit_size: u32, }, - //decomposition of a: a=\sum b[i]*radix^i where b is an array of witnesses < radix + //decomposition of a: a=\sum b[i]*radix^i where b is an array of witnesses < radix in either little endian or big endian form ToRadix { a: Expression, b: Vec, radix: u32, + is_little_endian: bool, }, // Sort directive, using a sorting network @@ -122,13 +123,19 @@ impl Directive { write_u32(&mut writer, r.witness_index())?; write_u32(&mut writer, *bit_size)?; } - Directive::ToRadix { a, b, radix } => { + Directive::ToRadix { + a, + b, + radix, + is_little_endian, + } => { a.write(&mut writer)?; write_u32(&mut writer, b.len() as u32)?; for bit in b { write_u32(&mut writer, bit.witness_index())?; } write_u32(&mut writer, *radix)?; + write_u32(&mut writer, *is_little_endian as u32)?; } Directive::PermutationSort { inputs: a, @@ -222,8 +229,14 @@ impl Directive { } let radix = read_u32(&mut reader)?; + let is_little_endian = read_u32(&mut reader)?; - Ok(Directive::ToRadix { a, b, radix }) + Ok(Directive::ToRadix { + a, + b, + radix, + is_little_endian: is_little_endian == 1, + }) } 6 => { let tuple = read_u32(&mut reader)?; @@ -314,10 +327,18 @@ fn serialization_roundtrip() { bit_size: 32, }; - let to_radix = Directive::ToRadix { + let to_radix_le = Directive::ToRadix { + a: Expression::default(), + b: vec![Witness(1u32), Witness(2u32), Witness(3u32), Witness(4u32)], + radix: 4, + is_little_endian: true, + }; + + let to_radix_be = Directive::ToRadix { a: Expression::default(), b: vec![Witness(1u32), Witness(2u32), Witness(3u32), Witness(4u32)], radix: 4, + is_little_endian: false, }; let directives = vec![ @@ -326,7 +347,8 @@ fn serialization_roundtrip() { quotient_predicate, truncate, odd_range, - to_radix, + to_radix_le, + to_radix_be, ]; for directive in directives { diff --git a/acir/src/circuit/opcodes.rs b/acir/src/circuit/opcodes.rs index 8d6794387..674307909 100644 --- a/acir/src/circuit/opcodes.rs +++ b/acir/src/circuit/opcodes.rs @@ -152,16 +152,22 @@ impl std::fmt::Display for Opcode { ) } Opcode::BlackBoxFuncCall(g) => write!(f, "{g}"), - Opcode::Directive(Directive::ToRadix { a, b, radix: _ }) => { + Opcode::Directive(Directive::ToRadix { + a, + b, + radix: _, + is_little_endian, + }) => { write!(f, "DIR::TORADIX ")?; write!( f, // TODO (Note): this assumes that the decomposed bits have contiguous witness indices // This should be the case, however, we can also have a function which checks this - "(_{}, [_{}..._{}])", + "(_{}, [_{}..._{}], endianness: {})", a, b.first().unwrap().witness_index(), b.last().unwrap().witness_index(), + if *is_little_endian { "little" } else { "big" } ) } Opcode::Directive(Directive::PermutationSort { diff --git a/acvm/src/lib.rs b/acvm/src/lib.rs index e82fa9d4b..4c3ed54ff 100644 --- a/acvm/src/lib.rs +++ b/acvm/src/lib.rs @@ -31,6 +31,8 @@ pub enum OpcodeNotSolvable { MissingAssignment(u32), #[error("expression has too many unknowns {0}")] ExpressionHasTooManyUnknowns(Expression), + #[error("compiler error: unreachable code")] + UnreachableCode, } #[derive(PartialEq, Eq, Debug, Error)] diff --git a/acvm/src/pwg.rs b/acvm/src/pwg.rs index a650e7b89..e31c34442 100644 --- a/acvm/src/pwg.rs +++ b/acvm/src/pwg.rs @@ -64,3 +64,26 @@ pub fn get_value( Ok(result) } + +// Inserts `value` into the initial witness map +// under the key of `witness`. +// Returns an error, if there was already a value in the map +// which does not match the value that one is about to insert +fn insert_value( + witness: &Witness, + value_to_insert: FieldElement, + initial_witness: &mut BTreeMap, +) -> Result<(), OpcodeResolutionError> { + let optional_old_value = initial_witness.insert(*witness, value_to_insert); + + let old_value = match optional_old_value { + Some(old_value) => old_value, + None => return Ok(()), + }; + + if old_value != value_to_insert { + return Err(OpcodeResolutionError::UnsatisfiedConstrain); + } + + Ok(()) +} diff --git a/acvm/src/pwg/directives.rs b/acvm/src/pwg/directives.rs index 95f0b7801..d2b9a33e9 100644 --- a/acvm/src/pwg/directives.rs +++ b/acvm/src/pwg/directives.rs @@ -8,9 +8,9 @@ use acir::{ use num_bigint::BigUint; use num_traits::{One, Zero}; -use crate::OpcodeResolutionError; +use crate::{OpcodeNotSolvable, OpcodeResolutionError}; -use super::{get_value, sorting::route, witness_to_value}; +use super::{get_value, insert_value, sorting::route, witness_to_value}; pub fn solve_directives( initial_witness: &mut BTreeMap, @@ -84,21 +84,58 @@ pub fn solve_directives( Ok(()) } - Directive::ToRadix { a, b, radix } => { - let val_a = get_value(a, initial_witness)?; + Directive::ToRadix { + a, + b, + radix, + is_little_endian, + } => { + let value_a = get_value(a, initial_witness)?; - let a_big = BigUint::from_bytes_be(&val_a.to_be_bytes()); - let a_dec = a_big.to_radix_le(*radix); - if b.len() < a_dec.len() { - return Err(OpcodeResolutionError::UnsatisfiedConstrain); - } - for i in 0..b.len() { - let v = if i < a_dec.len() { - FieldElement::from_be_bytes_reduce(&[a_dec[i]]) - } else { - FieldElement::zero() - }; - insert_witness(b[i], v, initial_witness)?; + let big_integer = BigUint::from_bytes_be(&value_a.to_be_bytes()); + + if *is_little_endian { + // Decompose the integer into its radix digits in little endian form. + let decomposed_integer = big_integer.to_radix_le(*radix); + + if b.len() < decomposed_integer.len() { + return Err(OpcodeResolutionError::UnsatisfiedConstrain); + } + + for (i, witness) in b.iter().enumerate() { + // Fetch the `i'th` digit from the decomposed integer list + // and convert it to a field element. + // If it is not available, which can happen when the decomposed integer + // list is shorter than the witness list, we return 0. + let value = match decomposed_integer.get(i) { + Some(digit) => FieldElement::from_be_bytes_reduce(&[*digit]), + None => FieldElement::zero(), + }; + + insert_value(witness, value, initial_witness)? + } + } else { + // Decompose the integer into its radix digits in big endian form. + let decomposed_integer = big_integer.to_radix_be(*radix); + + // if it is big endian and the decompoased integer list is shorter + // than the witness list, pad the extra part with 0 first then + // add the decompsed interger list to the witness list. + let padding_len = b.len() - decomposed_integer.len(); + let mut value = FieldElement::zero(); + for (i, witness) in b.iter().enumerate() { + if i >= padding_len { + value = match decomposed_integer.get(i - padding_len) { + Some(digit) => FieldElement::from_be_bytes_reduce(&[*digit]), + None => { + return Err(OpcodeResolutionError::OpcodeNotSolvable( + OpcodeNotSolvable::UnreachableCode, + )) + } + }; + } + insert_value(witness, value, initial_witness)? + } } Ok(()) diff --git a/stdlib/src/fallback.rs b/stdlib/src/fallback.rs index cc95fb993..910962a97 100644 --- a/stdlib/src/fallback.rs +++ b/stdlib/src/fallback.rs @@ -42,6 +42,7 @@ pub(crate) fn bit_decomposition( a: gate.clone(), b: bit_vector.clone(), radix: 2, + is_little_endian: true, })); // Now apply constraints to the bits such that they are the bit decomposition