From 95b71cc07f37b21ea37474538a0a959600c47eb6 Mon Sep 17 00:00:00 2001 From: Michael Sproul Date: Tue, 15 Sep 2020 12:28:09 +1000 Subject: [PATCH] Add `safe_sum` and use it in state_processing --- consensus/safe_arith/src/iter.rs | 70 +++++++++++++++++++ consensus/safe_arith/src/lib.rs | 5 +- .../per_epoch_processing/process_slashings.rs | 4 +- 3 files changed, 76 insertions(+), 3 deletions(-) create mode 100644 consensus/safe_arith/src/iter.rs diff --git a/consensus/safe_arith/src/iter.rs b/consensus/safe_arith/src/iter.rs new file mode 100644 index 00000000000..1fc3d3a1a7a --- /dev/null +++ b/consensus/safe_arith/src/iter.rs @@ -0,0 +1,70 @@ +use crate::{Result, SafeArith}; + +/// Extension trait for iterators, providing a safe replacement for `sum`. +pub trait SafeArithIter { + fn safe_sum(self) -> Result; +} + +impl SafeArithIter for I +where + I: Iterator + Sized, + T: SafeArith, +{ + fn safe_sum(mut self) -> Result { + self.try_fold(T::ZERO, |acc, x| acc.safe_add(x)) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::ArithError; + + #[test] + fn empty_sum() { + let v: Vec = vec![]; + assert_eq!(v.into_iter().safe_sum(), Ok(0)); + } + + #[test] + fn unsigned_sum_small() { + let v = vec![400u64, 401, 402, 403, 404, 405, 406]; + assert_eq!( + v.iter().copied().safe_sum().unwrap(), + v.iter().copied().sum() + ); + } + + #[test] + fn unsigned_sum_overflow() { + let v = vec![u64::MAX, 1]; + assert_eq!(v.into_iter().safe_sum(), Err(ArithError::Overflow)); + } + + #[test] + fn signed_sum_small() { + let v = vec![-1i64, -2i64, -3i64, 3, 2, 1]; + assert_eq!(v.into_iter().safe_sum(), Ok(0)); + } + + #[test] + fn signed_sum_overflow_above() { + let v = vec![1, 2, 3, 4, i16::MAX, 0, 1, 2, 3]; + assert_eq!(v.into_iter().safe_sum(), Err(ArithError::Overflow)); + } + + #[test] + fn signed_sum_overflow_below() { + let v = vec![i16::MIN, -1]; + assert_eq!(v.into_iter().safe_sum(), Err(ArithError::Overflow)); + } + + #[test] + fn signed_sum_almost_overflow() { + let v = vec![i64::MIN, 1, -1i64, i64::MAX, i64::MAX, 1]; + assert_eq!( + v.iter().copied().safe_sum().unwrap(), + v.iter().copied().sum() + ); + } +} diff --git a/consensus/safe_arith/src/lib.rs b/consensus/safe_arith/src/lib.rs index 90387b22387..2275682109b 100644 --- a/consensus/safe_arith/src/lib.rs +++ b/consensus/safe_arith/src/lib.rs @@ -1,4 +1,7 @@ //! Library for safe arithmetic on integers, avoiding overflow and division by zero. +mod iter; + +pub use iter::SafeArithIter; /// Error representing the failure of an arithmetic operation. #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -7,7 +10,7 @@ pub enum ArithError { DivisionByZero, } -type Result = std::result::Result; +pub type Result = std::result::Result; macro_rules! assign_method { ($name:ident, $op:ident, $doc_op:expr) => { diff --git a/consensus/state_processing/src/per_epoch_processing/process_slashings.rs b/consensus/state_processing/src/per_epoch_processing/process_slashings.rs index 7a2c94b8506..4901d303063 100644 --- a/consensus/state_processing/src/per_epoch_processing/process_slashings.rs +++ b/consensus/state_processing/src/per_epoch_processing/process_slashings.rs @@ -1,4 +1,4 @@ -use safe_arith::SafeArith; +use safe_arith::{SafeArith, SafeArithIter}; use types::{BeaconStateError as Error, *}; /// Process slashings. @@ -10,7 +10,7 @@ pub fn process_slashings( spec: &ChainSpec, ) -> Result<(), Error> { let epoch = state.current_epoch(); - let sum_slashings = state.get_all_slashings().iter().sum::(); + let sum_slashings = state.get_all_slashings().iter().copied().safe_sum()?; for (index, validator) in state.validators.iter().enumerate() { if validator.slashed