Skip to content

Commit

Permalink
[zk-token-sdk] Remove std::thread from wasm target (#379)
Browse files Browse the repository at this point in the history
  • Loading branch information
samkim-crypto authored Mar 26, 2024
1 parent 1261f1f commit a3bc406
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 29 deletions.
76 changes: 47 additions & 29 deletions zk-token-sdk/src/encryption/discrete_log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#![cfg(not(target_os = "solana"))]

#[cfg(not(target_arch = "wasm32"))]
use std::thread;
use {
crate::RISTRETTO_POINT_LEN,
curve25519_dalek::{
Expand All @@ -26,14 +28,15 @@ use {
},
itertools::Itertools,
serde::{Deserialize, Serialize},
std::{collections::HashMap, thread},
std::collections::HashMap,
thiserror::Error,
};

const TWO16: u64 = 65536; // 2^16
const TWO17: u64 = 131072; // 2^17

/// Maximum number of threads permitted for discrete log computation
#[cfg(not(target_arch = "wasm32"))]
const MAX_THREAD: usize = 65536;

#[derive(Error, Clone, Debug, Eq, PartialEq)]
Expand Down Expand Up @@ -112,6 +115,7 @@ impl DiscreteLog {
}

/// Adjusts number of threads in a discrete log instance.
#[cfg(not(target_arch = "wasm32"))]
pub fn num_threads(&mut self, num_threads: usize) -> Result<(), DiscreteLogError> {
// number of threads must be a positive power-of-two integer
if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 || num_threads > MAX_THREAD {
Expand Down Expand Up @@ -141,35 +145,48 @@ impl DiscreteLog {
/// Solves the discrete log problem under the assumption that the solution
/// is a positive 32-bit number.
pub fn decode_u32(self) -> Option<u64> {
let mut starting_point = self.target;
let handles = (0..self.num_threads)
.map(|i| {
let ristretto_iterator = RistrettoIterator::new(
(starting_point, i as u64),
(-(&self.step_point), self.num_threads as u64),
);

let handle = thread::spawn(move || {
Self::decode_range(
ristretto_iterator,
self.range_bound,
self.compression_batch_size,
)
});

starting_point -= G;
handle
})
.collect::<Vec<_>>();

let mut solution = None;
for handle in handles {
let discrete_log = handle.join().unwrap();
if discrete_log.is_some() {
solution = discrete_log;
}
#[cfg(not(target_arch = "wasm32"))]
{
let mut starting_point = self.target;
let handles = (0..self.num_threads)
.map(|i| {
let ristretto_iterator = RistrettoIterator::new(
(starting_point, i as u64),
(-(&self.step_point), self.num_threads as u64),
);

let handle = thread::spawn(move || {
Self::decode_range(
ristretto_iterator,
self.range_bound,
self.compression_batch_size,
)
});

starting_point -= G;
handle
})
.collect::<Vec<_>>();

handles
.into_iter()
.map_while(|h| h.join().ok())
.find(|x| x.is_some())
.flatten()
}
#[cfg(target_arch = "wasm32")]
{
let ristretto_iterator = RistrettoIterator::new(
(self.target, 0_u64),
(-(&self.step_point), self.num_threads as u64),
);

Self::decode_range(
ristretto_iterator,
self.range_bound,
self.compression_batch_size,
)
}
solution
}

fn decode_range(
Expand Down Expand Up @@ -274,6 +291,7 @@ mod tests {
println!("single thread discrete log computation secs: {computation_secs:?} sec");
}

#[cfg(not(target_arch = "wasm32"))]
#[test]
fn test_decode_correctness_threaded() {
// general case
Expand Down
1 change: 1 addition & 0 deletions zk-token-sdk/src/encryption/elgamal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ mod tests {
assert_eq!(57_u64, secret.decrypt_u32(&ciphertext).unwrap());
}

#[cfg(not(target_arch = "wasm32"))]
#[test]
fn test_encrypt_decrypt_correctness_multithreaded() {
let ElGamalKeypair { public, secret } = ElGamalKeypair::new_rand();
Expand Down

0 comments on commit a3bc406

Please sign in to comment.