diff --git a/gossip/src/weighted_shuffle.rs b/gossip/src/weighted_shuffle.rs index 250d1efb0f6800..7c12debce469e0 100644 --- a/gossip/src/weighted_shuffle.rs +++ b/gossip/src/weighted_shuffle.rs @@ -18,15 +18,14 @@ use { /// non-zero weighted indices. #[derive(Clone)] pub struct WeightedShuffle { - arr: Vec, // Underlying array implementing binary indexed tree. - sum: T, // Current sum of weights, excluding already selected indices. + // Underlying array implementing binary tree. + // tree[i] is the sum of weights in the left sub-tree of node i. + tree: Vec, + // Current sum of all weights, excluding already sampled ones. + weight: T, zeros: Vec, // Indices of zero weighted entries. } -// The implementation uses binary indexed tree: -// https://en.wikipedia.org/wiki/Fenwick_tree -// to maintain cumulative sum of weights excluding already selected indices -// over self.arr. impl WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + CheckedAdd, @@ -34,36 +33,39 @@ where /// If weights are negative or overflow the total sum /// they are treated as zero. pub fn new(name: &'static str, weights: &[T]) -> Self { - let size = weights.len() + 1; let zero = ::default(); - let mut arr = vec![zero; size]; + let mut tree = vec![zero; get_tree_size(weights.len())]; let mut sum = zero; let mut zeros = Vec::default(); let mut num_negative = 0; let mut num_overflow = 0; - for (mut k, &weight) in (1usize..).zip(weights) { + for (k, &weight) in weights.iter().enumerate() { #[allow(clippy::neg_cmp_op_on_partial_ord)] // weight < zero does not work for NaNs. if !(weight >= zero) { - zeros.push(k - 1); + zeros.push(k); num_negative += 1; continue; } if weight == zero { - zeros.push(k - 1); + zeros.push(k); continue; } sum = match sum.checked_add(&weight) { Some(val) => val, None => { - zeros.push(k - 1); + zeros.push(k); num_overflow += 1; continue; } }; - while k < size { - arr[k] += weight; - k += k & k.wrapping_neg(); + let mut index = tree.len() + k; + while index != 0 { + let offset = index & 1; + index = (index - 1) >> 1; + if offset > 0 { + tree[index] += weight; + } } } if num_negative > 0 { @@ -72,7 +74,11 @@ where if num_overflow > 0 { datapoint_error!("weighted-shuffle-overflow", (name, num_overflow, i64)); } - Self { arr, sum, zeros } + Self { + tree, + weight: sum, + zeros, + } } } @@ -80,54 +86,65 @@ impl WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub, { - // Returns cumulative sum of current weights upto index k (inclusive). - fn cumsum(&self, mut k: usize) -> T { - let mut out = ::default(); - while k != 0 { - out += self.arr[k]; - k ^= k & k.wrapping_neg(); - } - out - } - // Removes given weight at index k. - fn remove(&mut self, mut k: usize, weight: T) { - self.sum -= weight; - let size = self.arr.len(); - while k < size { - self.arr[k] -= weight; - k += k & k.wrapping_neg(); + fn remove(&mut self, k: usize, weight: T) { + self.weight -= weight; + let mut index = self.tree.len() + k; + while index != 0 { + let offset = index & 1; + index = (index - 1) >> 1; + if offset > 0 { + self.tree[index] -= weight; + } } } - // Returns smallest index such that self.cumsum(k) > val, + // Returns smallest index such that cumsum of weights[..=k] > val, // along with its respective weight. - fn search(&self, val: T) -> (/*index:*/ usize, /*weight:*/ T) { + fn search(&self, mut val: T) -> (/*index:*/ usize, /*weight:*/ T) { let zero = ::default(); debug_assert!(val >= zero); - debug_assert!(val < self.sum); - let mut lo = (/*index:*/ 0, /*cumsum:*/ zero); - let mut hi = (self.arr.len() - 1, self.sum); - while lo.0 + 1 < hi.0 { - let k = lo.0 + (hi.0 - lo.0) / 2; - let sum = self.cumsum(k); - if sum <= val { - lo = (k, sum); + debug_assert!(val < self.weight); + let mut index = 0; + let mut weight = self.weight; + while index < self.tree.len() { + if val < self.tree[index] { + weight = self.tree[index]; + index = (index << 1) + 1; } else { - hi = (k, sum); + weight -= self.tree[index]; + val -= self.tree[index]; + index = (index << 1) + 2; } } - debug_assert!(lo.1 <= val); - debug_assert!(hi.1 > val); - (hi.0, hi.1 - lo.1) + (index - self.tree.len(), weight) } - pub fn remove_index(&mut self, index: usize) { - let zero = ::default(); - let weight = self.cumsum(index + 1) - self.cumsum(index); - if weight != zero { - self.remove(index + 1, weight); - } else if let Some(index) = self.zeros.iter().position(|ix| *ix == index) { + pub fn remove_index(&mut self, k: usize) { + let mut index = self.tree.len() + k; + let mut weight = ::default(); // zero + while index != 0 { + let offset = index & 1; + index = (index - 1) >> 1; + if offset > 0 { + if self.tree[index] != weight { + self.remove(k, self.tree[index] - weight); + } else { + self.remove_zero(k); + } + return; + } + weight += self.tree[index]; + } + if self.weight != weight { + self.remove(k, self.weight - weight); + } else { + self.remove_zero(k); + } + } + + fn remove_zero(&mut self, k: usize) { + if let Some(index) = self.zeros.iter().position(|&ix| ix == k) { self.zeros.remove(index); } } @@ -140,10 +157,10 @@ where // Equivalent to weighted_shuffle.shuffle(&mut rng).next() pub fn first(&self, rng: &mut R) -> Option { let zero = ::default(); - if self.sum > zero { - let sample = ::Sampler::sample_single(zero, self.sum, rng); + if self.weight > zero { + let sample = ::Sampler::sample_single(zero, self.weight, rng); let (index, _weight) = WeightedShuffle::search(self, sample); - return Some(index - 1); + return Some(index); } if self.zeros.is_empty() { return None; @@ -160,11 +177,11 @@ where pub fn shuffle(mut self, rng: &'a mut R) -> impl Iterator + 'a { std::iter::from_fn(move || { let zero = ::default(); - if self.sum > zero { - let sample = ::Sampler::sample_single(zero, self.sum, rng); + if self.weight > zero { + let sample = ::Sampler::sample_single(zero, self.weight, rng); let (index, weight) = WeightedShuffle::search(&self, sample); self.remove(index, weight); - return Some(index - 1); + return Some(index); } if self.zeros.is_empty() { return None; @@ -176,6 +193,19 @@ where } } +// Maps number of items to the "internal" size of the binary tree "implicitly" +// holding those items on the leaves. +fn get_tree_size(count: usize) -> usize { + let shift = usize::BITS + - count.leading_zeros() + - if count.is_power_of_two() && count != 1 { + 1 + } else { + 0 + }; + (1usize << shift) - 1 +} + #[cfg(test)] mod tests { use { @@ -218,6 +248,30 @@ mod tests { shuffle } + #[test] + fn test_get_tree_size() { + assert_eq!(get_tree_size(0), 0); + assert_eq!(get_tree_size(1), 1); + assert_eq!(get_tree_size(2), 1); + assert_eq!(get_tree_size(3), 3); + assert_eq!(get_tree_size(4), 3); + for count in 5..9 { + assert_eq!(get_tree_size(count), 7); + } + for count in 9..17 { + assert_eq!(get_tree_size(count), 15); + } + for count in 17..33 { + assert_eq!(get_tree_size(count), 31); + } + assert_eq!(get_tree_size((1 << 16) - 1), (1 << 16) - 1); + assert_eq!(get_tree_size(1 << 16), (1 << 16) - 1); + assert_eq!(get_tree_size((1 << 16) + 1), (1 << 17) - 1); + assert_eq!(get_tree_size((1 << 17) - 1), (1 << 17) - 1); + assert_eq!(get_tree_size(1 << 17), (1 << 17) - 1); + assert_eq!(get_tree_size((1 << 17) + 1), (1 << 18) - 1); + } + // Asserts that empty weights will return empty shuffle. #[test] fn test_weighted_shuffle_empty_weights() { @@ -357,4 +411,20 @@ mod tests { assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0])); } } + + #[test] + fn test_weighted_shuffle_paranoid() { + let mut rng = rand::thread_rng(); + for size in 0..1351 { + let weights: Vec<_> = repeat_with(|| rng.gen_range(0..1000)).take(size).collect(); + let seed = rng.gen::<[u8; 32]>(); + let mut rng = ChaChaRng::from_seed(seed); + let shuffle_slow = weighted_shuffle_slow(&mut rng.clone(), weights.clone()); + let shuffle = WeightedShuffle::new("", &weights); + if size > 0 { + assert_eq!(shuffle.first(&mut rng.clone()), Some(shuffle_slow[0])); + } + assert_eq!(shuffle.shuffle(&mut rng).collect::>(), shuffle_slow); + } + } }