diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fac2075792..04c3d0ea555 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ * `Address`: optimize `functionCall` functions by checking contract size only if there is no returned data. ([#3469](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3469)) * `GovernorCompatibilityBravo`: remove unused `using` statements. ([#3506](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3506)) * `ERC20`: optimize `_transfer`, `_mint` and `_burn` by using `unchecked` arithmetic when possible. ([#3513](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3513)) + * `ERC20Votes`, `ERC721Votes`: optimize `getPastVotes` for looking up recent checkpoints. ([#3673](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3673)) * `ERC20FlashMint`: add an internal `_flashFee` function for overriding. ([#3551](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3551)) * `ERC4626`: use the same `decimals()` as the underlying asset by default (if available). ([#3639](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3639)) * `ERC4626`: add internal `_initialConvertToShares` and `_initialConvertToAssets` functions to customize empty vaults behavior. ([#3639](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3639)) diff --git a/contracts/governance/utils/Votes.sol b/contracts/governance/utils/Votes.sol index b0cdbb42fec..053bfa9e8e4 100644 --- a/contracts/governance/utils/Votes.sol +++ b/contracts/governance/utils/Votes.sol @@ -56,7 +56,7 @@ abstract contract Votes is IVotes, Context, EIP712 { * - `blockNumber` must have been already mined */ function getPastVotes(address account, uint256 blockNumber) public view virtual override returns (uint256) { - return _delegateCheckpoints[account].getAtBlock(blockNumber); + return _delegateCheckpoints[account].getAtProbablyRecentBlock(blockNumber); } /** @@ -72,7 +72,7 @@ abstract contract Votes is IVotes, Context, EIP712 { */ function getPastTotalSupply(uint256 blockNumber) public view virtual override returns (uint256) { require(blockNumber < block.number, "Votes: block not yet mined"); - return _totalCheckpoints.getAtBlock(blockNumber); + return _totalCheckpoints.getAtProbablyRecentBlock(blockNumber); } /** diff --git a/contracts/mocks/CheckpointsMock.sol b/contracts/mocks/CheckpointsMock.sol index 591d2fc3029..d630f2e3443 100644 --- a/contracts/mocks/CheckpointsMock.sol +++ b/contracts/mocks/CheckpointsMock.sol @@ -22,8 +22,8 @@ contract CheckpointsMock { return _totalCheckpoints.getAtBlock(blockNumber); } - function getAtRecentBlock(uint256 blockNumber) public view returns (uint256) { - return _totalCheckpoints.getAtRecentBlock(blockNumber); + function getAtProbablyRecentBlock(uint256 blockNumber) public view returns (uint256) { + return _totalCheckpoints.getAtProbablyRecentBlock(blockNumber); } function length() public view returns (uint256) { @@ -52,10 +52,6 @@ contract Checkpoints224Mock { return _totalCheckpoints.upperLookup(key); } - function upperLookupRecent(uint32 key) public view returns (uint224) { - return _totalCheckpoints.upperLookupRecent(key); - } - function length() public view returns (uint256) { return _totalCheckpoints._checkpoints.length; } @@ -82,10 +78,6 @@ contract Checkpoints160Mock { return _totalCheckpoints.upperLookup(key); } - function upperLookupRecent(uint96 key) public view returns (uint224) { - return _totalCheckpoints.upperLookupRecent(key); - } - function length() public view returns (uint256) { return _totalCheckpoints._checkpoints.length; } diff --git a/contracts/token/ERC20/extensions/ERC20Votes.sol b/contracts/token/ERC20/extensions/ERC20Votes.sol index c0e88bc19e8..0ce489927fb 100644 --- a/contracts/token/ERC20/extensions/ERC20Votes.sol +++ b/contracts/token/ERC20/extensions/ERC20Votes.sol @@ -97,6 +97,7 @@ abstract contract ERC20Votes is IVotes, ERC20Permit { function _checkpointsLookup(Checkpoint[] storage ckpts, uint256 blockNumber) private view returns (uint256) { // We run a binary search to look for the earliest checkpoint taken after `blockNumber`. // + // Initially we check if the block is recent to narrow the search range. // During the loop, the index of the wanted checkpoint remains in the range [low-1, high). // With each iteration, either `low` or `high` is moved towards the middle of the range to maintain the invariant. // - If the middle checkpoint is after `blockNumber`, we look in [low, mid) @@ -106,18 +107,30 @@ abstract contract ERC20Votes is IVotes, ERC20Permit { // Note that if the latest checkpoint available is exactly for `blockNumber`, we end up with an index that is // past the end of the array, so we technically don't find a checkpoint after `blockNumber`, but it works out // the same. - uint256 high = ckpts.length; + uint256 length = ckpts.length; + uint256 low = 0; + uint256 high = length; + + if (length > 5) { + uint256 mid = length - Math.sqrt(length); + if (_unsafeAccess(ckpts, mid).fromBlock > blockNumber) { + high = mid; + } else { + low = mid + 1; + } + } + while (low < high) { uint256 mid = Math.average(low, high); - if (ckpts[mid].fromBlock > blockNumber) { + if (_unsafeAccess(ckpts, mid).fromBlock > blockNumber) { high = mid; } else { low = mid + 1; } } - return high == 0 ? 0 : ckpts[high - 1].votes; + return high == 0 ? 0 : _unsafeAccess(ckpts, high - 1).votes; } /** @@ -229,11 +242,14 @@ abstract contract ERC20Votes is IVotes, ERC20Permit { uint256 delta ) private returns (uint256 oldWeight, uint256 newWeight) { uint256 pos = ckpts.length; - oldWeight = pos == 0 ? 0 : ckpts[pos - 1].votes; + + Checkpoint memory oldCkpt = pos == 0 ? Checkpoint(0, 0) : _unsafeAccess(ckpts, pos - 1); + + oldWeight = oldCkpt.votes; newWeight = op(oldWeight, delta); - if (pos > 0 && ckpts[pos - 1].fromBlock == block.number) { - ckpts[pos - 1].votes = SafeCast.toUint224(newWeight); + if (pos > 0 && oldCkpt.fromBlock == block.number) { + _unsafeAccess(ckpts, pos - 1).votes = SafeCast.toUint224(newWeight); } else { ckpts.push(Checkpoint({fromBlock: SafeCast.toUint32(block.number), votes: SafeCast.toUint224(newWeight)})); } @@ -246,4 +262,11 @@ abstract contract ERC20Votes is IVotes, ERC20Permit { function _subtract(uint256 a, uint256 b) private pure returns (uint256) { return a - b; } + + function _unsafeAccess(Checkpoint[] storage ckpts, uint256 pos) private view returns (Checkpoint storage result) { + assembly { + mstore(0, ckpts.slot) + result.slot := add(keccak256(0, 0x20), pos) + } + } } diff --git a/contracts/utils/Checkpoints.sol b/contracts/utils/Checkpoints.sol index 1692a6d15dc..6199d8d7773 100644 --- a/contracts/utils/Checkpoints.sol +++ b/contracts/utils/Checkpoints.sol @@ -49,22 +49,28 @@ library Checkpoints { /** * @dev Returns the value at a given block number. If a checkpoint is not available at that block, the closest one - * before it is returned, or zero otherwise. Similarly to {upperLookup} but optimized for the case when the search - * key is known to be recent. + * before it is returned, or zero otherwise. Similar to {upperLookup} but optimized for the case when the searched + * checkpoint is probably "recent", defined as being among the last sqrt(N) checkpoints where N is the number of + * checkpoints. */ - function getAtRecentBlock(History storage self, uint256 blockNumber) internal view returns (uint256) { + function getAtProbablyRecentBlock(History storage self, uint256 blockNumber) internal view returns (uint256) { require(blockNumber < block.number, "Checkpoints: block not yet mined"); uint32 key = SafeCast.toUint32(blockNumber); uint256 length = self._checkpoints.length; - uint256 offset = 1; - while (offset <= length && _unsafeAccess(self._checkpoints, length - offset)._blockNumber > key) { - offset <<= 1; + uint256 low = 0; + uint256 high = length; + + if (length > 5) { + uint256 mid = length - Math.sqrt(length); + if (key < _unsafeAccess(self._checkpoints, mid)._blockNumber) { + high = mid; + } else { + low = mid + 1; + } } - uint256 low = offset < length ? length - offset : 0; - uint256 high = length - (offset >> 1); uint256 pos = _upperBinaryLookup(self._checkpoints, key, low, high); return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value; @@ -225,25 +231,6 @@ library Checkpoints { return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value; } - /** - * @dev Returns the value in the most recent checkpoint with key lower or equal than the search key (similarly to - * {upperLookup}), optimized for the case when the search key is known to be recent. - */ - function upperLookupRecent(Trace224 storage self, uint32 key) internal view returns (uint224) { - uint256 length = self._checkpoints.length; - uint256 offset = 1; - - while (offset <= length && _unsafeAccess(self._checkpoints, length - offset)._key > key) { - offset <<= 1; - } - - uint256 low = 0 < offset && offset < length ? length - offset : 0; - uint256 high = length - (offset >> 1); - uint256 pos = _upperBinaryLookup(self._checkpoints, key, low, high); - - return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value; - } - /** * @dev Pushes a (`key`, `value`) pair into an ordered list of checkpoints, either by inserting a new checkpoint, * or by updating the last one. @@ -380,25 +367,6 @@ library Checkpoints { return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value; } - /** - * @dev Returns the value in the most recent checkpoint with key lower or equal than the search key (similarly to - * {upperLookup}), optimized for the case when the search key is known to be recent. - */ - function upperLookupRecent(Trace160 storage self, uint96 key) internal view returns (uint160) { - uint256 length = self._checkpoints.length; - uint256 offset = 1; - - while (offset <= length && _unsafeAccess(self._checkpoints, length - offset)._key > key) { - offset <<= 1; - } - - uint256 low = 0 < offset && offset < length ? length - offset : 0; - uint256 high = length - (offset >> 1); - uint256 pos = _upperBinaryLookup(self._checkpoints, key, low, high); - - return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value; - } - /** * @dev Pushes a (`key`, `value`) pair into an ordered list of checkpoints, either by inserting a new checkpoint, * or by updating the last one. diff --git a/scripts/generate/templates/Checkpoints.js b/scripts/generate/templates/Checkpoints.js index b6a9ce04c1d..e0bdd2055ee 100644 --- a/scripts/generate/templates/Checkpoints.js +++ b/scripts/generate/templates/Checkpoints.js @@ -70,25 +70,6 @@ function upperLookup(${opts.historyTypeName} storage self, ${opts.keyTypeName} k uint256 pos = _upperBinaryLookup(self.${opts.checkpointFieldName}, key, 0, length); return pos == 0 ? 0 : _unsafeAccess(self.${opts.checkpointFieldName}, pos - 1).${opts.valueFieldName}; } - -/** - * @dev Returns the value in the most recent checkpoint with key lower or equal than the search key (similarly to - * {upperLookup}), optimized for the case when the search key is known to be recent. - */ -function upperLookupRecent(${opts.historyTypeName} storage self, ${opts.keyTypeName} key) internal view returns (${opts.valueTypeName}) { - uint256 length = self.${opts.checkpointFieldName}.length; - uint256 offset = 1; - - while (offset <= length && _unsafeAccess(self.${opts.checkpointFieldName}, length - offset).${opts.keyFieldName} > key) { - offset <<= 1; - } - - uint256 low = 0 < offset && offset < length ? length - offset : 0; - uint256 high = length - (offset >> 1); - uint256 pos = _upperBinaryLookup(self.${opts.checkpointFieldName}, key, low, high); - - return pos == 0 ? 0 : _unsafeAccess(self.${opts.checkpointFieldName}, pos - 1).${opts.valueFieldName}; -} `; const legacyOperations = opts => `\ @@ -115,22 +96,28 @@ function getAtBlock(${opts.historyTypeName} storage self, uint256 blockNumber) i /** * @dev Returns the value at a given block number. If a checkpoint is not available at that block, the closest one - * before it is returned, or zero otherwise. Similarly to {upperLookup} but optimized for the case when the search - * key is known to be recent. + * before it is returned, or zero otherwise. Similar to {upperLookup} but optimized for the case when the searched + * checkpoint is probably "recent", defined as being among the last sqrt(N) checkpoints where N is the number of + * checkpoints. */ -function getAtRecentBlock(${opts.historyTypeName} storage self, uint256 blockNumber) internal view returns (uint256) { +function getAtProbablyRecentBlock(${opts.historyTypeName} storage self, uint256 blockNumber) internal view returns (uint256) { require(blockNumber < block.number, "Checkpoints: block not yet mined"); uint32 key = SafeCast.toUint32(blockNumber); uint256 length = self.${opts.checkpointFieldName}.length; - uint256 offset = 1; - while (offset <= length && _unsafeAccess(self.${opts.checkpointFieldName}, length - offset).${opts.keyFieldName} > key) { - offset <<= 1; + uint256 low = 0; + uint256 high = length; + + if (length > 5) { + uint256 mid = length - Math.sqrt(length); + if (key < _unsafeAccess(self.${opts.checkpointFieldName}, mid)._blockNumber) { + high = mid; + } else { + low = mid + 1; + } } - uint256 low = offset < length ? length - offset : 0; - uint256 high = length - (offset >> 1); uint256 pos = _upperBinaryLookup(self.${opts.checkpointFieldName}, key, low, high); return pos == 0 ? 0 : _unsafeAccess(self.${opts.checkpointFieldName}, pos - 1).${opts.valueFieldName}; diff --git a/scripts/generate/templates/CheckpointsMock.js b/scripts/generate/templates/CheckpointsMock.js index 6ce8e534b0b..2feb112409e 100755 --- a/scripts/generate/templates/CheckpointsMock.js +++ b/scripts/generate/templates/CheckpointsMock.js @@ -26,8 +26,8 @@ contract CheckpointsMock { return _totalCheckpoints.getAtBlock(blockNumber); } - function getAtRecentBlock(uint256 blockNumber) public view returns (uint256) { - return _totalCheckpoints.getAtRecentBlock(blockNumber); + function getAtProbablyRecentBlock(uint256 blockNumber) public view returns (uint256) { + return _totalCheckpoints.getAtProbablyRecentBlock(blockNumber); } function length() public view returns (uint256) { @@ -58,10 +58,6 @@ contract Checkpoints${length}Mock { return _totalCheckpoints.upperLookup(key); } - function upperLookupRecent(uint${256 - length} key) public view returns (uint224) { - return _totalCheckpoints.upperLookupRecent(key); - } - function length() public view returns (uint256) { return _totalCheckpoints._checkpoints.length; } diff --git a/test/token/ERC20/extensions/ERC20Votes.test.js b/test/token/ERC20/extensions/ERC20Votes.test.js index be28f66f304..9d3160bd18c 100644 --- a/test/token/ERC20/extensions/ERC20Votes.test.js +++ b/test/token/ERC20/extensions/ERC20Votes.test.js @@ -56,6 +56,19 @@ contract('ERC20Votes', function (accounts) { ); }); + it('recent checkpoints', async function () { + await this.token.delegate(holder, { from: holder }); + for (let i = 0; i < 6; i++) { + await this.token.mint(holder, 1); + } + const block = await web3.eth.getBlockNumber(); + expect(await this.token.numCheckpoints(holder)).to.be.bignumber.equal('6'); + // recent + expect(await this.token.getPastVotes(holder, block - 1)).to.be.bignumber.equal('5'); + // non-recent + expect(await this.token.getPastVotes(holder, block - 6)).to.be.bignumber.equal('0'); + }); + describe('set delegation', function () { describe('call', function () { it('delegation with balance', async function () { diff --git a/test/utils/Checkpoints.test.js b/test/utils/Checkpoints.test.js index af8a1a13bb9..28525729f0d 100644 --- a/test/utils/Checkpoints.test.js +++ b/test/utils/Checkpoints.test.js @@ -22,8 +22,10 @@ contract('Checkpoints', function (accounts) { it('returns zero as past value', async function () { await time.advanceBlock(); - expect(await this.checkpoint.getAtBlock(await web3.eth.getBlockNumber() - 1)).to.be.bignumber.equal('0'); - expect(await this.checkpoint.getAtRecentBlock(await web3.eth.getBlockNumber() - 1)).to.be.bignumber.equal('0'); + expect(await this.checkpoint.getAtBlock(await web3.eth.getBlockNumber() - 1)) + .to.be.bignumber.equal('0'); + expect(await this.checkpoint.getAtProbablyRecentBlock(await web3.eth.getBlockNumber() - 1)) + .to.be.bignumber.equal('0'); }); }); @@ -41,7 +43,7 @@ contract('Checkpoints', function (accounts) { expect(await this.checkpoint.latest()).to.be.bignumber.equal('3'); }); - for (const fn of [ 'getAtBlock(uint256)', 'getAtRecentBlock(uint256)' ]) { + for (const fn of [ 'getAtBlock(uint256)', 'getAtProbablyRecentBlock(uint256)' ]) { describe(`lookup: ${fn}`, function () { it('returns past values', async function () { expect(await this.checkpoint.methods[fn](this.tx1.receipt.blockNumber - 1)).to.be.bignumber.equal('0'); @@ -78,6 +80,18 @@ contract('Checkpoints', function (accounts) { expect(await this.checkpoint.length()).to.be.bignumber.equal(lengthBefore.addn(1)); expect(await this.checkpoint.latest()).to.be.bignumber.equal('10'); }); + + it('more than 5 checkpoints', async function () { + for (let i = 4; i <= 6; i++) { + await this.checkpoint.push(i); + } + expect(await this.checkpoint.length()).to.be.bignumber.equal('6'); + const block = await web3.eth.getBlockNumber(); + // recent + expect(await this.checkpoint.getAtProbablyRecentBlock(block - 1)).to.be.bignumber.equal('5'); + // non-recent + expect(await this.checkpoint.getAtProbablyRecentBlock(block - 9)).to.be.bignumber.equal('0'); + }); }); }); @@ -95,7 +109,6 @@ contract('Checkpoints', function (accounts) { it('lookup returns 0', async function () { expect(await this.contract.lowerLookup(0)).to.be.bignumber.equal('0'); expect(await this.contract.upperLookup(0)).to.be.bignumber.equal('0'); - expect(await this.contract.upperLookupRecent(0)).to.be.bignumber.equal('0'); }); }); @@ -149,7 +162,6 @@ contract('Checkpoints', function (accounts) { const value = last(this.checkpoints.filter(x => i >= x.key))?.value || '0'; expect(await this.contract.upperLookup(i)).to.be.bignumber.equal(value); - expect(await this.contract.upperLookupRecent(i)).to.be.bignumber.equal(value); } }); });