diff --git a/contracts/PRBMathCommon.sol b/contracts/PRBMathCommon.sol index 5efc656b..28e05361 100644 --- a/contracts/PRBMathCommon.sol +++ b/contracts/PRBMathCommon.sol @@ -301,6 +301,56 @@ library PRBMathCommon { } } + /// @notice Calculates floor(x*y÷denominator) with full precision. + /// + /// @dev An extension of "mulDiv" for signed numbers. Works by computing the signs and the absolute values separately. + /// + /// Requirements: + /// - None of the inputs can be type(int256).min. + /// - The result must fit within int256. + /// + /// @param x The multiplicand as an int256. + /// @param y The multiplier as an int256. + /// @param denominator The divisor as an int256. + /// @return result The result as an int256. + function mulDivSigned( + int256 x, + int256 y, + int256 denominator + ) internal pure returns (int256 result) { + require(x > type(int256).min); + require(y > type(int256).min); + require(denominator > type(int256).min); + + // Get hold of the absolute values of x, y and the denominator. + uint256 ax; + uint256 ay; + uint256 ad; + unchecked { + ax = x < 0 ? uint256(-x) : uint256(x); + ay = y < 0 ? uint256(-y) : uint256(y); + ad = denominator < 0 ? uint256(-denominator) : uint256(denominator); + } + + // Compute the absolute value of (x*y)÷denominator. The result must fit within int256. + uint256 resultUnsigned = mulDiv(ax, ay, ad); + require(resultUnsigned <= uint256(type(int256).max)); + + // Get the signs of x, y and the denominator. + uint256 sx; + uint256 sy; + uint256 sd; + assembly { + sx := sgt(x, sub(0, 1)) + sy := sgt(y, sub(0, 1)) + sd := sgt(denominator, sub(0, 1)) + } + + // XOR over sx, sy and sd. This is checking whether there are one or three negative signs in the inputs. + // If yes, the result should be negative. + result = sx ^ sy ^ sd == 0 ? -int256(resultUnsigned) : int256(resultUnsigned); + } + /// @notice Calculates the square root of x, rounding down. /// @dev Uses the Babylonian method https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method. ///