From 800e2d44aa175c57698de739e5839eb7af03498a Mon Sep 17 00:00:00 2001 From: Nam Ngo Date: Thu, 27 Jun 2024 09:48:35 +0700 Subject: [PATCH] fix missing constants --- input/circuit.circom | 59 +++++++++++++++++++++++++++++++++++--------- src/compiler.rs | 2 +- src/main.rs | 6 +++++ 3 files changed, 55 insertions(+), 12 deletions(-) diff --git a/input/circuit.circom b/input/circuit.circom index b9a95f4..60c1681 100644 --- a/input/circuit.circom +++ b/input/circuit.circom @@ -1,20 +1,57 @@ -pragma circom 2.1.0; +// from 0xZKML/zk-mnist -template componentA () { - signal input in[2][2]; - signal output out; +pragma circom 2.0.0; + +template Switcher() { + signal input sel; + signal input L; + signal input R; + signal output outL; + signal output outR; - out <== in[0][0] + in[0][1] + in[1][0] + in[1][1]; + signal aux; + + aux <== (R-L)*sel; // We create aux in order to have only one multiplication + outL <== aux + L; + outR <== -aux + R; } -template componentB() { - signal input a_in[2][2]; +template ArgMax (n) { + signal input in[n]; signal output out; - component a = componentA(); - a.in <== a_in; + // assert (out < n); + signal gts[n]; // store comparators + component switchers[n+1]; // switcher for comparing maxs + component aswitchers[n+1]; // switcher for arg max + + signal maxs[n+1]; + signal amaxs[n+1]; - out <== a.out; + maxs[0] <== in[0]; + amaxs[0] <== 0; + for(var i = 0; i < n; i++) { + gts[i] <== in[i] > maxs[i]; // changed to 252 (maximum) for better compatibility + switchers[i+1] = Switcher(); + aswitchers[i+1] = Switcher(); + + switchers[i+1].sel <== gts[i]; + switchers[i+1].L <== maxs[i]; + switchers[i+1].R <== in[i]; + + aswitchers[i+1].sel <== gts[i]; + aswitchers[i+1].L <== amaxs[i]; + aswitchers[i+1].R <== i; + amaxs[i+1] <== aswitchers[i+1].outL; + maxs[i+1] <== switchers[i+1].outL; + } + + out <== amaxs[n]; } -component main = componentB(); +component main = ArgMax(1); + +/* INPUT = { + "in": ["2","3","1","5","4"], + "out": "3" +} */ \ No newline at end of file diff --git a/src/compiler.rs b/src/compiler.rs index 1edf535..93c39a7 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -404,7 +404,7 @@ impl Compiler { if let Some(value) = signal.value { constant_to_node_id_and_value - .insert(signal.name.clone(), (*node_id, value.to_string())); + .insert(format!("{}_{}",signal.name.clone(),signal_id), (*node_id, value.to_string())); } } } diff --git a/src/main.rs b/src/main.rs index 404bf95..426d8f1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -29,6 +29,12 @@ fn main() -> Result<(), ProgramError> { let output_file_path = build_output(&output_dir, "circuit", "txt"); circuit.write_bristol(&mut File::create(output_file_path)?)?; + // let output_file_path_json = build_output(&output_dir, "circuit", "json"); + // File::create(output_file_path_json)?.write_all(serde_json::to_string_pretty(&circuit)?.as_bytes())?; + + let output_debug_path_json = build_output(&output_dir, "debug", "json"); + File::create(output_debug_path_json)?.write_all(serde_json::to_string_pretty(&compiler)?.as_bytes())?; + let output_file_path = build_output(&output_dir, "circuit_info", "json"); File::create(output_file_path)?.write_all(to_string_pretty(&circuit.info)?.as_bytes())?;