diff --git a/script/DeployDeterministicInvoiceModule.s.sol b/script/DeployDeterministicInvoiceModule.s.sol index bcc42bb1..a8d30819 100644 --- a/script/DeployDeterministicInvoiceModule.s.sol +++ b/script/DeployDeterministicInvoiceModule.s.sol @@ -15,11 +15,13 @@ contract DeployDeterministicInvoiceModule is BaseScript { string memory create2Salt, ISablierV2LockupLinear sablierLockupLinear, ISablierV2LockupTranched sablierLockupTranched, - address brokerAdmin + address brokerAdmin, + string memory baseURI ) public virtual broadcast returns (InvoiceModule invoiceModule) { bytes32 salt = bytes32(abi.encodePacked(create2Salt)); // Deterministically deploy the {InvoiceModule} contracts - invoiceModule = new InvoiceModule{ salt: salt }(sablierLockupLinear, sablierLockupTranched, brokerAdmin); + invoiceModule = + new InvoiceModule{ salt: salt }(sablierLockupLinear, sablierLockupTranched, brokerAdmin, baseURI); } } diff --git a/src/modules/invoice-module/InvoiceModule.sol b/src/modules/invoice-module/InvoiceModule.sol index 5b7296f6..d0679a73 100644 --- a/src/modules/invoice-module/InvoiceModule.sol +++ b/src/modules/invoice-module/InvoiceModule.sol @@ -3,6 +3,8 @@ pragma solidity ^0.8.26; import { SafeERC20 } from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import { ERC721 } from "@openzeppelin/contracts/token/ERC721/ERC721.sol"; +import { Strings } from "@openzeppelin/contracts/utils/Strings.sol"; import { ISablierV2LockupLinear } from "@sablier/v2-core/src/interfaces/ISablierV2LockupLinear.sol"; import { ISablierV2LockupTranched } from "@sablier/v2-core/src/interfaces/ISablierV2LockupTranched.sol"; @@ -15,22 +17,23 @@ import { Helpers } from "./libraries/Helpers.sol"; /// @title InvoiceModule /// @notice See the documentation in {IInvoiceModule} -contract InvoiceModule is IInvoiceModule, StreamManager { +contract InvoiceModule is IInvoiceModule, StreamManager, ERC721 { using SafeERC20 for IERC20; + using Strings for uint256; /*////////////////////////////////////////////////////////////////////////// PRIVATE STORAGE //////////////////////////////////////////////////////////////////////////*/ - /// @dev Array with invoice IDs created through the `container` container contract - mapping(address container => uint256[]) private _invoicesOf; - /// @dev Invoice details mapped by the `id` invoice ID mapping(uint256 id => Types.Invoice) private _invoices; /// @dev Counter to keep track of the next ID used to create a new invoice uint256 private _nextInvoiceId; + /// @dev Base URI used to get the ERC-721 `tokenURI` metadata JSON schema + string private _collectionURI; + /*////////////////////////////////////////////////////////////////////////// CONSTRUCTOR //////////////////////////////////////////////////////////////////////////*/ @@ -39,9 +42,17 @@ contract InvoiceModule is IInvoiceModule, StreamManager { constructor( ISablierV2LockupLinear _sablierLockupLinear, ISablierV2LockupTranched _sablierLockupTranched, - address _brokerAdmin - ) StreamManager(_sablierLockupLinear, _sablierLockupTranched, _brokerAdmin) { + address _brokerAdmin, + string memory _URI + ) + StreamManager(_sablierLockupLinear, _sablierLockupTranched, _brokerAdmin) + ERC721("Metadock Invoice NFT", "MD-INVOICES") + { + // Start the invoice IDs from 1 _nextInvoiceId = 1; + + // Set the ERC721 baseURI + _collectionURI = _URI; } /*////////////////////////////////////////////////////////////////////////// @@ -75,7 +86,7 @@ contract InvoiceModule is IInvoiceModule, StreamManager { //////////////////////////////////////////////////////////////////////////*/ /// @inheritdoc IInvoiceModule - function createInvoice(Types.Invoice calldata invoice) external onlyContainer returns (uint256 id) { + function createInvoice(Types.Invoice calldata invoice) external onlyContainer returns (uint256 invoiceId) { // Checks: the amount is non-zero if (invoice.payment.amount == 0) { revert Errors.ZeroPaymentAmount(); @@ -132,11 +143,10 @@ contract InvoiceModule is IInvoiceModule, StreamManager { } // Get the next invoice ID - id = _nextInvoiceId; + invoiceId = _nextInvoiceId; // Effects: create the invoice - _invoices[id] = Types.Invoice({ - recipient: invoice.recipient, + _invoices[invoiceId] = Types.Invoice({ status: Types.Status.Pending, startTime: invoice.startTime, endTime: invoice.endTime, @@ -153,16 +163,16 @@ contract InvoiceModule is IInvoiceModule, StreamManager { // Effects: increment the next invoice id // Use unchecked because the invoice id cannot realistically overflow unchecked { - _nextInvoiceId = id + 1; + ++_nextInvoiceId; } - // Effects: add the invoice on the list of invoices generated by the container - _invoicesOf[invoice.recipient].push(id); + // Effects: mint the invoice NFT to the recipient container + _mint({ to: msg.sender, tokenId: invoiceId }); // Log the invoice creation emit InvoiceCreated({ - id: id, - recipient: invoice.recipient, + id: invoiceId, + recipient: msg.sender, status: Types.Status.Pending, startTime: invoice.startTime, endTime: invoice.endTime, @@ -175,10 +185,9 @@ contract InvoiceModule is IInvoiceModule, StreamManager { // Load the invoice from storage Types.Invoice memory invoice = _invoices[id]; - // Checks: the invoice is not null - if (invoice.recipient == address(0)) { - revert Errors.InvoiceNull(); - } + // Retrieve the recipient of the invoice + // This will also check if the invoice is minted or not burned + address recipient = ownerOf(id); // Checks: the invoice is not already paid or canceled if (invoice.status == Types.Status.Paid) { @@ -190,14 +199,14 @@ contract InvoiceModule is IInvoiceModule, StreamManager { // Handle the payment workflow depending on the payment method type if (invoice.payment.method == Types.Method.Transfer) { // Effects: pay the invoice and update its status to `Paid` or `Ongoing` depending on the payment type - _payByTransfer(id, invoice); + _payByTransfer(id, invoice, recipient); } else { uint256 streamId; // Check to see whether the invoice must be paid through a linear or tranched stream if (invoice.payment.method == Types.Method.LinearStream) { - streamId = _payByLinearStream(invoice); + streamId = _payByLinearStream(invoice, recipient); } else { - streamId = _payByTranchedStream(invoice); + streamId = _payByTranchedStream(invoice, recipient); } // Effects: update the status of the invoice to `Ongoing` and the stream ID @@ -222,25 +231,27 @@ contract InvoiceModule is IInvoiceModule, StreamManager { revert Errors.InvoiceAlreadyCanceled(); } - // Checks: `msg.sender` is the recipient if dealing with a transfer-based invoice - // or a linear/tranched stream-based invoice which was not paid yet (not streaming) + // Checks: `msg.sender` is the recipient if invoice status is pending // // Notes: // - Once a linear or tranched stream is created, the `msg.sender` is checked in the // {SablierV2Lockup} `cancel` method - if (invoice.payment.method == Types.Method.Transfer || invoice.status == Types.Status.Pending) { - if (invoice.recipient != msg.sender) { + if (invoice.status == Types.Status.Pending) { + // Retrieve the recipient of the invoice + address recipient = ownerOf(id); + + if (recipient != msg.sender) { revert Errors.OnlyInvoiceRecipient(); } } - // Effects: cancel the stream accordingly depending on its type + // Checks, Effects, Interactions: cancel the stream if status is ongoing // // Notes: // - A transfer-based invoice can be canceled directly // - A linear or tranched stream MUST be canceled by calling the `cancel` method on the according // {ISablierV2Lockup} contract else if (invoice.status == Types.Status.Ongoing) { - cancelStream({ streamType: invoice.payment.method, streamId: invoice.payment.streamId }); + _cancelStream({ streamType: invoice.payment.method, streamId: invoice.payment.streamId }); } // Effects: mark the invoice as canceled @@ -251,10 +262,13 @@ contract InvoiceModule is IInvoiceModule, StreamManager { } /// @inheritdoc IInvoiceModule - function withdrawInvoiceStream(uint256 id) external { + function withdrawInvoiceStream(uint256 id) public returns (uint128 withdrawnAmount) { // Load the invoice from storage Types.Invoice memory invoice = _invoices[id]; + // Retrieve the recipient of the invoice + address recipient = ownerOf(id); + // Effects: update the invoice status to `Paid` once the full payment amount has been successfully streamed uint128 streamedAmount = streamedAmountOf({ streamType: invoice.payment.method, streamId: invoice.payment.streamId }); @@ -263,7 +277,39 @@ contract InvoiceModule is IInvoiceModule, StreamManager { } // Check, Effects, Interactions: withdraw from the stream - withdrawStream({ streamType: invoice.payment.method, streamId: invoice.payment.streamId, to: invoice.recipient }); + return + _withdrawStream({ streamType: invoice.payment.method, streamId: invoice.payment.streamId, to: recipient }); + } + + /// @inheritdoc ERC721 + function tokenURI(uint256 tokenId) public view override returns (string memory) { + // Checks: the `tokenId` was minted or is not burned + _requireOwned(tokenId); + + // Create the `tokenURI` by concatenating the `baseURI`, `tokenId` and metadata extension (.json) + string memory baseURI = _baseURI(); + return string.concat(baseURI, tokenId.toString(), ".json"); + } + + /// @inheritdoc ERC721 + function transferFrom(address from, address to, uint256 tokenId) public override { + // Retrieve the invoice details + Types.Invoice memory invoice = _invoices[tokenId]; + + // Checks: the payment request has been accepted and a stream has already been + // created if dealing with a stream-based payment + if (invoice.payment.streamId != 0) { + // Checks and Effects: withdraw the maximum withdrawable amount to the current stream recipient + // and transfer the stream NFT to the new recipient + _withdrawMaxAndTransferStream({ + streamType: invoice.payment.method, + streamId: invoice.payment.streamId, + newRecipient: to + }); + } + + // Checks, Effects and Interactions: transfer the invoice NFT + super.transferFrom(from, to, tokenId); } /*////////////////////////////////////////////////////////////////////////// @@ -271,7 +317,7 @@ contract InvoiceModule is IInvoiceModule, StreamManager { //////////////////////////////////////////////////////////////////////////*/ /// @dev Pays the `id` invoice by transfer - function _payByTransfer(uint256 id, Types.Invoice memory invoice) internal { + function _payByTransfer(uint256 id, Types.Invoice memory invoice, address recipient) internal { // Effects: update the invoice status to `Paid` if the required number of payments has been made // Using unchecked because the number of payments left cannot underflow as the invoice status // will be updated to `Paid` once `paymentLeft` is zero @@ -293,31 +339,34 @@ contract InvoiceModule is IInvoiceModule, StreamManager { } // Interactions: pay the recipient with native token (ETH) - (bool success,) = payable(invoice.recipient).call{ value: invoice.payment.amount }(""); + (bool success,) = payable(recipient).call{ value: invoice.payment.amount }(""); if (!success) revert Errors.NativeTokenPaymentFailed(); } else { // Interactions: pay the recipient with the ERC-20 token IERC20(invoice.payment.asset).safeTransferFrom({ from: msg.sender, - to: address(invoice.recipient), + to: recipient, value: invoice.payment.amount }); } } /// @dev Create the linear stream payment - function _payByLinearStream(Types.Invoice memory invoice) internal returns (uint256 streamId) { + function _payByLinearStream(Types.Invoice memory invoice, address recipient) internal returns (uint256 streamId) { streamId = StreamManager.createLinearStream({ asset: IERC20(invoice.payment.asset), totalAmount: invoice.payment.amount, startTime: invoice.startTime, endTime: invoice.endTime, - recipient: invoice.recipient + recipient: recipient }); } /// @dev Create the tranched stream payment - function _payByTranchedStream(Types.Invoice memory invoice) internal returns (uint256 streamId) { + function _payByTranchedStream( + Types.Invoice memory invoice, + address recipient + ) internal returns (uint256 streamId) { uint40 numberOfTranches = Helpers.computeNumberOfPayments(invoice.payment.recurrence, invoice.endTime - invoice.startTime); @@ -325,7 +374,7 @@ contract InvoiceModule is IInvoiceModule, StreamManager { asset: IERC20(invoice.payment.asset), totalAmount: invoice.payment.amount, startTime: invoice.startTime, - recipient: invoice.recipient, + recipient: recipient, numberOfTranches: numberOfTranches, recurrence: invoice.payment.recurrence }); @@ -353,4 +402,9 @@ contract InvoiceModule is IInvoiceModule, StreamManager { revert Errors.PaymentIntervalTooShortForSelectedRecurrence(); } } + + /// @inheritdoc ERC721 + function _baseURI() internal view override returns (string memory) { + return _collectionURI; + } } diff --git a/src/modules/invoice-module/interfaces/IInvoiceModule.sol b/src/modules/invoice-module/interfaces/IInvoiceModule.sol index 4f5e81b2..14a8b921 100644 --- a/src/modules/invoice-module/interfaces/IInvoiceModule.sol +++ b/src/modules/invoice-module/interfaces/IInvoiceModule.sol @@ -92,5 +92,5 @@ interface IInvoiceModule { /// - reverts if the payment method of the `id` invoice is not linear or tranched stream based /// /// @param id The ID of the invoice - function withdrawInvoiceStream(uint256 id) external; + function withdrawInvoiceStream(uint256 id) external returns (uint128 withdrawnAmount); } diff --git a/src/modules/invoice-module/libraries/Types.sol b/src/modules/invoice-module/libraries/Types.sol index 0254dc2f..5bc14c3f 100644 --- a/src/modules/invoice-module/libraries/Types.sol +++ b/src/modules/invoice-module/libraries/Types.sol @@ -64,7 +64,6 @@ library Types { /// @param payment The payment struct describing the invoice payment struct Invoice { // slot 0 - address recipient; Status status; uint40 startTime; uint40 endTime; diff --git a/src/modules/invoice-module/sablier-v2/StreamManager.sol b/src/modules/invoice-module/sablier-v2/StreamManager.sol index 2c343a69..bf27b56c 100644 --- a/src/modules/invoice-module/sablier-v2/StreamManager.sol +++ b/src/modules/invoice-module/sablier-v2/StreamManager.sol @@ -61,13 +61,30 @@ abstract contract StreamManager is IStreamManager { } /*////////////////////////////////////////////////////////////////////////// - MODIFIERS + CONSTANT FUNCTIONS //////////////////////////////////////////////////////////////////////////*/ - /// @notice Reverts if the `msg.sender` is not the broker admin account or contract - modifier onlyBrokerAdmin() { - if (msg.sender != brokerAdmin) revert Errors.OnlyBrokerAdmin(); - _; + /// @inheritdoc IStreamManager + function getLinearStream(uint256 streamId) public view returns (LockupLinear.StreamLL memory stream) { + stream = LOCKUP_LINEAR.getStream(streamId); + } + + /// @inheritdoc IStreamManager + function getTranchedStream(uint256 streamId) public view returns (LockupTranched.StreamLT memory stream) { + stream = LOCKUP_TRANCHED.getStream(streamId); + } + + /// @inheritdoc IStreamManager + function withdrawableAmountOf( + Types.Method streamType, + uint256 streamId + ) public view returns (uint128 withdrawableAmount) { + withdrawableAmount = _getISablierV2Lockup(streamType).withdrawableAmountOf(streamId); + } + + /// @inheritdoc IStreamManager + function streamedAmountOf(Types.Method streamType, uint256 streamId) public view returns (uint128 streamedAmount) { + streamedAmount = _getISablierV2Lockup(streamType).streamedAmountOf(streamId); } /*////////////////////////////////////////////////////////////////////////// @@ -112,7 +129,10 @@ abstract contract StreamManager is IStreamManager { } /// @inheritdoc IStreamManager - function updateStreamBrokerFee(UD60x18 newBrokerFee) public onlyBrokerAdmin { + function updateStreamBrokerFee(UD60x18 newBrokerFee) public { + // Checks: the `msg.sender` is the broker admin + if (msg.sender != brokerAdmin) revert Errors.OnlyBrokerAdmin(); + // Log the broker fee update emit BrokerFeeUpdated({ oldFee: brokerFee, newFee: newBrokerFee }); @@ -121,69 +141,11 @@ abstract contract StreamManager is IStreamManager { } /*////////////////////////////////////////////////////////////////////////// - WITHDRAW FUNCTIONS + INTERNAL MANAGEMENT FUNCTIONS //////////////////////////////////////////////////////////////////////////*/ - /// @inheritdoc IStreamManager - function withdrawStream( - Types.Method streamType, - uint256 streamId, - address to - ) public returns (uint128 withdrawnAmount) { - // Set the according {ISablierV2Lockup} based on the stream type - ISablierV2Lockup sablier = _getISablierV2Lockup(streamType); - - // Withdraw the maximum withdrawable amount - withdrawnAmount = _withdrawStream(sablier, streamId, to); - } - - /// @inheritdoc IStreamManager - function withdrawableAmountOf( - Types.Method streamType, - uint256 streamId - ) public view returns (uint128 withdrawableAmount) { - withdrawableAmount = _getISablierV2Lockup(streamType).withdrawableAmountOf(streamId); - } - - /// @inheritdoc IStreamManager - function streamedAmountOf(Types.Method streamType, uint256 streamId) public view returns (uint128 streamedAmount) { - streamedAmount = _getISablierV2Lockup(streamType).streamedAmountOf(streamId); - } - - /*////////////////////////////////////////////////////////////////////////// - CANCEL FUNCTIONS - //////////////////////////////////////////////////////////////////////////*/ - - /// @inheritdoc IStreamManager - function cancelStream(Types.Method streamType, uint256 streamId) public { - // Set the according {ISablierV2Lockup} based on the stream type - ISablierV2Lockup sablier = _getISablierV2Lockup(streamType); - - // Checks, Effect, Interactions - _cancelStream(sablier, streamId); - } - - /*////////////////////////////////////////////////////////////////////////// - CONSTANT FUNCTIONS - //////////////////////////////////////////////////////////////////////////*/ - - /// @inheritdoc IStreamManager - function getLinearStream(uint256 streamId) public view returns (LockupLinear.StreamLL memory stream) { - stream = LOCKUP_LINEAR.getStream(streamId); - } - - /// @inheritdoc IStreamManager - function getTranchedStream(uint256 streamId) public view returns (LockupTranched.StreamLT memory stream) { - stream = LOCKUP_TRANCHED.getStream(streamId); - } - - /*////////////////////////////////////////////////////////////////////////// - INTERNAL FUNCTIONS - //////////////////////////////////////////////////////////////////////////*/ - - /// @notice Creates a Lockup Linear stream + /// @dev Creates a Lockup Linear stream /// See https://docs.sablier.com/concepts/protocol/stream-types#lockup-linear - /// @dev See https://docs.sablier.com/contracts/v2/guides/create-stream/lockup-linear function _createLinearStream( IERC20 asset, uint128 totalAmount, @@ -208,9 +170,8 @@ abstract contract StreamManager is IStreamManager { streamId = LOCKUP_LINEAR.createWithTimestamps(params); } - /// @notice Creates a Lockup Tranched stream + /// @dev Creates a Lockup Tranched stream /// See https://docs.sablier.com/concepts/protocol/stream-types#unlock-monthly - /// @dev See https://docs.sablier.com/contracts/v2/guides/create-stream/lockup-linear function _createTranchedStream( IERC20 asset, uint128 totalAmount, @@ -267,25 +228,62 @@ abstract contract StreamManager is IStreamManager { streamId = LOCKUP_TRANCHED.createWithTimestamps(params); } - /// @dev Withdraws the maximum withdrawable amount from either a linear or tranched stream + /// @dev See the documentation in {ISablierV2Lockup-withdrawMax} + /// Notes: + /// - `streamType` parameter has been added to withdraw from the according {ISablierV2Lockup} contract function _withdrawStream( - ISablierV2Lockup sablier, + Types.Method streamType, uint256 streamId, address to ) internal returns (uint128 withdrawnAmount) { + // Set the according {ISablierV2Lockup} based on the stream type + ISablierV2Lockup sablier = _getISablierV2Lockup(streamType); + + // Withdraw the maximum withdrawable amount return sablier.withdrawMax(streamId, to); } - /// @dev Cancels the `streamId` stream - function _cancelStream(ISablierV2Lockup sablier, uint256 streamId) internal { + /// @dev Withdraws the maximum withdrawable amount and transfers the stream NFT to the new recipient + /// Notes: + /// - `streamType` parameter has been added to withdraw from the according {ISablierV2Lockup} contract + function _withdrawMaxAndTransferStream( + Types.Method streamType, + uint256 streamId, + address newRecipient + ) internal returns (uint128 withdrawnAmount) { + // Set the according {ISablierV2Lockup} based on the stream type + ISablierV2Lockup sablier = _getISablierV2Lockup(streamType); + + // Checks: the caller is the current recipient. This also checks that the NFT was not burned. + address currentRecipient = sablier.ownerOf(streamId); + + // Checks, Effects and Interactions: withdraw the maximum withdrawable amount + withdrawnAmount = sablier.withdrawMax(streamId, currentRecipient); + + // Interactions: transfer the stream to the new recipient + sablier.transferFrom({ from: msg.sender, to: newRecipient, tokenId: streamId }); + } + + /// @dev See the documentation in {ISablierV2Lockup-cancel} + /// + /// Notes: + /// - `msg.sender` must be the initial stream creator + function _cancelStream(Types.Method streamType, uint256 streamId) internal { + // Set the according {ISablierV2Lockup} based on the stream type + ISablierV2Lockup sablier = _getISablierV2Lockup(streamType); + // Checks: the `msg.sender` is the initial stream creator address initialSender = _initialStreamSender[streamId]; if (msg.sender != initialSender) revert Errors.OnlyInitialStreamSender(initialSender); - // Cancel the stream + // Checks, Effect, Interactions: cancel the stream sablier.cancel(streamId); } + /*////////////////////////////////////////////////////////////////////////// + OTHER INTERNAL FUNCTIONS + //////////////////////////////////////////////////////////////////////////*/ + /// @dev Transfers the `amount` of `asset` tokens to this address (or the contract inherting from) /// and approves either the `SablierV2LockupLinear` or `SablierV2LockupTranched` to spend the amount function _transferFromAndApprove(IERC20 asset, uint128 amount, address spender) internal { diff --git a/src/modules/invoice-module/sablier-v2/interfaces/IStreamManager.sol b/src/modules/invoice-module/sablier-v2/interfaces/IStreamManager.sol index d5269c5d..6ec2a41d 100644 --- a/src/modules/invoice-module/sablier-v2/interfaces/IStreamManager.sol +++ b/src/modules/invoice-module/sablier-v2/interfaces/IStreamManager.sol @@ -51,6 +51,22 @@ interface IStreamManager { /// @param streamId The ID of the stream to be retrieved function getTranchedStream(uint256 streamId) external view returns (LockupTranched.StreamLT memory stream); + /// @notice See the documentation in {ISablierV2Lockup-withdrawableAmountOf} + /// Notes: + /// - `streamType` parameter has been added to retrieve from the according {ISablierV2Lockup} contract + function withdrawableAmountOf( + Types.Method streamType, + uint256 streamId + ) external view returns (uint128 withdrawableAmount); + + /// @notice See the documentation in {ISablierV2Lockup-streamedAmountOf} + /// Notes: + /// - `streamType` parameter has been added to retrieve from the according {ISablierV2Lockup} contract + function streamedAmountOf( + Types.Method streamType, + uint256 streamId + ) external view returns (uint128 streamedAmount); + /*////////////////////////////////////////////////////////////////////////// NON-CONSTANT FUNCTIONS //////////////////////////////////////////////////////////////////////////*/ @@ -88,39 +104,9 @@ interface IStreamManager { /// @notice Updates the fee charged by the broker /// /// Notes: + /// - `msg.sender` must be the broker admin /// - The new fee will be applied only to the new streams hence it can't be retrospectively updated /// /// @param newBrokerFee The new broker fee function updateStreamBrokerFee(UD60x18 newBrokerFee) external; - - /// @notice See the documentation in {ISablierV2Lockup-withdrawMax} - /// Notes: - /// - `streamType` parameter has been added to withdraw from the according {ISablierV2Lockup} contract - function withdrawStream( - Types.Method streamType, - uint256 streamId, - address to - ) external returns (uint128 withdrawnAmount); - - /// @notice See the documentation in {ISablierV2Lockup-withdrawableAmountOf} - /// Notes: - /// - `streamType` parameter has been added to retrieve from the according {ISablierV2Lockup} contract - function withdrawableAmountOf( - Types.Method streamType, - uint256 streamId - ) external view returns (uint128 withdrawableAmount); - - /// @notice See the documentation in {ISablierV2Lockup-streamedAmountOf} - /// Notes: - /// - `streamType` parameter has been added to retrieve from the according {ISablierV2Lockup} contract - function streamedAmountOf( - Types.Method streamType, - uint256 streamId - ) external view returns (uint128 streamedAmount); - - /// @notice See the documentation in {ISablierV2Lockup-cancel} - /// - /// Notes: - /// - Reverts with {OnlyInitialStreamSender} if `msg.sender` is not the initial stream creator - function cancelStream(Types.Method streamType, uint256 streamId) external; } diff --git a/test/Base.t.sol b/test/Base.t.sol index 7be67285..c6f2917d 100644 --- a/test/Base.t.sol +++ b/test/Base.t.sol @@ -14,6 +14,7 @@ import { DockRegistry } from "./../src/DockRegistry.sol"; import { ERC1967Proxy } from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; import { MockERC721Collection } from "./mocks/MockERC721Collection.sol"; import { MockERC1155Collection } from "./mocks/MockERC1155Collection.sol"; +import { MockBadContainer } from "./mocks/MockBadContainer.sol"; abstract contract Base_Test is Test, Events { /*////////////////////////////////////////////////////////////////////////// @@ -97,6 +98,25 @@ abstract contract Base_Test is Test, Events { vm.stopPrank(); } + /// @dev Deploys a new {MockBadContainer} contract based on the provided `owner`, `moduleKeeper` and `initialModules` input params + function deployBadContainer( + address _owner, + uint256 _dockId, + address[] memory _initialModules + ) internal returns (MockBadContainer _container) { + vm.startPrank({ msgSender: users.admin }); + for (uint256 i; i < _initialModules.length; ++i) { + allowlistModule(_initialModules[i]); + } + vm.stopPrank(); + + vm.prank({ msgSender: _owner }); + _container = MockBadContainer( + payable(dockRegistry.createContainer({ dockId: _dockId, initialModules: _initialModules })) + ); + vm.stopPrank(); + } + function allowlistModule(address _module) internal { moduleKeeper.addToAllowlist({ module: _module }); } diff --git a/test/integration/Integration.t.sol b/test/integration/Integration.t.sol index 6ac6ad94..25159832 100644 --- a/test/integration/Integration.t.sol +++ b/test/integration/Integration.t.sol @@ -7,6 +7,8 @@ import { SablierV2LockupLinear } from "@sablier/v2-core/src/SablierV2LockupLinea import { SablierV2LockupTranched } from "@sablier/v2-core/src/SablierV2LockupTranched.sol"; import { NFTDescriptorMock } from "@sablier/v2-core/test/mocks/NFTDescriptorMock.sol"; import { MockStreamManager } from "../mocks/MockStreamManager.sol"; +import { MockBadContainer } from "../mocks/MockBadContainer.sol"; +import { Container } from "./../../src/Container.sol"; abstract contract Integration_Test is Base_Test { /*////////////////////////////////////////////////////////////////////////// @@ -19,6 +21,7 @@ abstract contract Integration_Test is Base_Test { SablierV2LockupLinear internal sablierV2LockupLinear; SablierV2LockupTranched internal sablierV2LockupTranched; MockStreamManager internal mockStreamManager; + MockBadContainer internal badContainer; /*////////////////////////////////////////////////////////////////////////// SET-UP FUNCTION @@ -37,6 +40,9 @@ abstract contract Integration_Test is Base_Test { // Deploy the {Container} contract with the {InvoiceModule} enabled by default container = deployContainer({ _owner: users.eve, _dockId: 0, _initialModules: modules }); + // Deploy a "bad" {Container} with the `mockBadReceiver` as the owner + badContainer = deployBadContainer({ _owner: address(mockBadReceiver), _dockId: 0, _initialModules: modules }); + // Deploy the mock {StreamManager} mockStreamManager = new MockStreamManager(sablierV2LockupLinear, sablierV2LockupTranched, users.admin); @@ -44,6 +50,8 @@ abstract contract Integration_Test is Base_Test { vm.label({ account: address(invoiceModule), newLabel: "InvoiceModule" }); vm.label({ account: address(sablierV2LockupLinear), newLabel: "SablierV2LockupLinear" }); vm.label({ account: address(sablierV2LockupTranched), newLabel: "SablierV2LockupTranched" }); + vm.label({ account: address(container), newLabel: "Eve's Container" }); + vm.label({ account: address(badContainer), newLabel: "Bad receiver's Container" }); } /*////////////////////////////////////////////////////////////////////////// @@ -63,7 +71,8 @@ abstract contract Integration_Test is Base_Test { invoiceModule = new InvoiceModule({ _sablierLockupLinear: sablierV2LockupLinear, _sablierLockupTranched: sablierV2LockupTranched, - _brokerAdmin: users.admin + _brokerAdmin: users.admin, + _URI: "ipfs://CID/" }); } } diff --git a/test/integration/concrete/invoice-module/cancel-invoice/cancelInvoice.t.sol b/test/integration/concrete/invoice-module/cancel-invoice/cancelInvoice.t.sol index 0c122181..78838a23 100644 --- a/test/integration/concrete/invoice-module/cancel-invoice/cancelInvoice.t.sol +++ b/test/integration/concrete/invoice-module/cancel-invoice/cancelInvoice.t.sol @@ -35,8 +35,8 @@ contract CancelInvoice_Integration_Concret_Test is CancelInvoice_Integration_Sha // Set the one-off ETH transfer invoice as current one uint256 invoiceId = 2; - // Make Eve the caller who is the recipient of the invoice - vm.startPrank({ msgSender: users.eve }); + // Make Eve's container the caller which is the recipient of the invoice + vm.startPrank({ msgSender: address(container) }); // Cancel the invoice first invoiceModule.cancelInvoice({ id: invoiceId }); @@ -77,8 +77,8 @@ contract CancelInvoice_Integration_Concret_Test is CancelInvoice_Integration_Sha // Set the one-off ETH transfer invoice as current one uint256 invoiceId = 2; - // Make Eve the caller who is the recipient of the invoice - vm.startPrank({ msgSender: users.eve }); + // Make Eve's container the caller which is the recipient of the invoice + vm.startPrank({ msgSender: address(container) }); // Expect the {InvoiceCanceled} event to be emitted vm.expectEmit(); @@ -123,8 +123,8 @@ contract CancelInvoice_Integration_Concret_Test is CancelInvoice_Integration_Sha // Set current invoice as a linear stream-based one uint256 invoiceId = 5; - // Make Eve the caller who is the recipient of the invoice - vm.startPrank({ msgSender: users.eve }); + // Make Eve's container the caller which is the recipient of the invoice + vm.startPrank({ msgSender: address(container) }); // Expect the {InvoiceCanceled} event to be emitted vm.expectEmit(); @@ -232,8 +232,8 @@ contract CancelInvoice_Integration_Concret_Test is CancelInvoice_Integration_Sha // Set current invoice as a tranched stream-based one uint256 invoiceId = 5; - // Make Eve the caller who is the recipient of the invoice - vm.startPrank({ msgSender: users.eve }); + // Make Eve's container the caller which is the recipient of the invoice + vm.startPrank({ msgSender: address(container) }); // Expect the {InvoiceCanceled} event to be emitted vm.expectEmit(); diff --git a/test/integration/concrete/invoice-module/create-invoice/createInvoice.t.sol b/test/integration/concrete/invoice-module/create-invoice/createInvoice.t.sol index 55bb1a20..22d735a3 100644 --- a/test/integration/concrete/invoice-module/create-invoice/createInvoice.t.sol +++ b/test/integration/concrete/invoice-module/create-invoice/createInvoice.t.sol @@ -21,7 +21,7 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha vm.expectRevert(Errors.ContainerZeroCodeSize.selector); // Create an one-off transfer invoice - invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt), recipient: users.eve }); + invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt) }); // Run the test invoiceModule.createInvoice(invoice); @@ -32,11 +32,11 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create an one-off transfer invoice - invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt), recipient: users.eve }); + invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt) }); // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( - "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice + "createInvoice((uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice ); // Expect the call to revert with the {ContainerUnsupportedInterface} error @@ -51,14 +51,14 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create an one-off transfer invoice - invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt), recipient: users.eve }); + invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt) }); // Set the payment amount to zero to simulate the error invoice.payment.amount = 0; // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( - "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice + "createInvoice((uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice ); // Expect the call to revert with the {ZeroPaymentAmount} error @@ -78,7 +78,7 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create an one-off transfer invoice - invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt), recipient: users.eve }); + invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt) }); // Set the start time to be the current timestamp and the end time one second earlier invoice.startTime = uint40(block.timestamp); @@ -86,7 +86,7 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( - "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice + "createInvoice((uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice ); // Expect the call to revert with the {StartTimeGreaterThanEndTime} error @@ -107,7 +107,7 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create an one-off transfer invoice - invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt), recipient: users.eve }); + invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt) }); // Set the block.timestamp to 1641070800 vm.warp(1_641_070_800); @@ -119,7 +119,7 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( - "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice + "createInvoice((uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice ); // Expect the call to revert with the {EndTimeInThePast} error @@ -143,18 +143,18 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha // Create a recurring transfer invoice that must be paid on a monthly basis // Hence, the interval between the start and end time must be at least 1 month - invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt), recipient: users.eve }); + invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt) }); // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( - "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice + "createInvoice((uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice ); // Expect the module call to emit an {InvoiceCreated} event vm.expectEmit(); emit Events.InvoiceCreated({ id: 1, - recipient: users.eve, + recipient: address(container), status: Types.Status.Pending, startTime: invoice.startTime, endTime: invoice.endTime, @@ -170,7 +170,9 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha // Assert the actual and expected invoice state Types.Invoice memory actualInvoice = invoiceModule.getInvoice({ id: 1 }); - assertEq(actualInvoice.recipient, users.eve); + address expectedRecipient = invoiceModule.ownerOf(1); + + assertEq(expectedRecipient, address(container)); assertEq(uint8(actualInvoice.status), uint8(Types.Status.Pending)); assertEq(actualInvoice.startTime, invoice.startTime); assertEq(actualInvoice.endTime, invoice.endTime); @@ -196,14 +198,14 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha // Create a recurring transfer invoice that must be paid on a monthly basis // Hence, the interval between the start and end time must be at least 1 month - invoice = createInvoiceWithRecurringTransfer({ recurrence: Types.Recurrence.Monthly, recipient: users.eve }); + invoice = createInvoiceWithRecurringTransfer({ recurrence: Types.Recurrence.Monthly }); // Alter the end time to be 3 weeks from now invoice.endTime = uint40(block.timestamp) + 3 weeks; // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( - "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice + "createInvoice((uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice ); // Expect the call to revert with the {PaymentIntervalTooShortForSelectedRecurrence} error @@ -227,18 +229,18 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create a recurring transfer invoice that must be paid on weekly basis - invoice = createInvoiceWithRecurringTransfer({ recurrence: Types.Recurrence.Weekly, recipient: users.eve }); + invoice = createInvoiceWithRecurringTransfer({ recurrence: Types.Recurrence.Weekly }); // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( - "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice + "createInvoice((uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice ); // Expect the module call to emit an {InvoiceCreated} event vm.expectEmit(); emit Events.InvoiceCreated({ id: 1, - recipient: users.eve, + recipient: address(container), status: Types.Status.Pending, startTime: invoice.startTime, endTime: invoice.endTime, @@ -254,7 +256,9 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha // Assert the actual and expected invoice state Types.Invoice memory actualInvoice = invoiceModule.getInvoice({ id: 1 }); - assertEq(actualInvoice.recipient, users.eve); + address expectedRecipient = invoiceModule.ownerOf(1); + + assertEq(expectedRecipient, address(container)); assertEq(uint8(actualInvoice.status), uint8(Types.Status.Pending)); assertEq(actualInvoice.startTime, invoice.startTime); assertEq(actualInvoice.endTime, invoice.endTime); @@ -279,7 +283,7 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create a new invoice with a tranched stream payment - invoice = createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Weekly, recipient: users.eve }); + invoice = createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Weekly }); // Alter the payment recurrence by setting it to one-off invoice.payment.recurrence = Types.Recurrence.OneOff; @@ -289,7 +293,7 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( - "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice + "createInvoice((uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice ); // Run the test @@ -310,7 +314,7 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create a new invoice with a tranched stream payment - invoice = createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Monthly, recipient: users.eve }); + invoice = createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Monthly }); // Alter the end time to be 3 weeks from now invoice.endTime = uint40(block.timestamp) + 3 weeks; @@ -320,7 +324,7 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( - "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice + "createInvoice((uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice ); // Run the test @@ -342,7 +346,7 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create a new invoice with a linear stream payment - invoice = createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Weekly, recipient: users.eve }); + invoice = createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Weekly }); // Alter the payment asset by setting it to invoice.payment.asset = address(0); @@ -352,7 +356,7 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( - "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice + "createInvoice((uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice ); // Run the test @@ -373,18 +377,18 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create a new invoice with a tranched stream payment - invoice = createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Weekly, recipient: users.eve }); + invoice = createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Weekly }); // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( - "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice + "createInvoice((uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice ); // Expect the module call to emit an {InvoiceCreated} event vm.expectEmit(); emit Events.InvoiceCreated({ id: 1, - recipient: users.eve, + recipient: address(container), status: Types.Status.Pending, startTime: invoice.startTime, endTime: invoice.endTime, @@ -400,7 +404,9 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha // Assert the actual and expected invoice state Types.Invoice memory actualInvoice = invoiceModule.getInvoice({ id: 1 }); - assertEq(actualInvoice.recipient, users.eve); + address expectedRecipient = invoiceModule.ownerOf(1); + + assertEq(expectedRecipient, address(container)); assertEq(uint8(actualInvoice.status), uint8(Types.Status.Pending)); assertEq(actualInvoice.startTime, invoice.startTime); assertEq(actualInvoice.endTime, invoice.endTime); @@ -425,7 +431,7 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create a new invoice with a linear stream payment - invoice = createInvoiceWithLinearStream({ recipient: users.eve }); + invoice = createInvoiceWithLinearStream(); // Alter the payment asset by setting it to invoice.payment.asset = address(0); @@ -435,7 +441,7 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( - "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice + "createInvoice((uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice ); // Run the test @@ -456,18 +462,18 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create a new invoice with a linear stream payment - invoice = createInvoiceWithLinearStream({ recipient: users.eve }); + invoice = createInvoiceWithLinearStream(); // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( - "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice + "createInvoice((uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice ); // Expect the module call to emit an {InvoiceCreated} event vm.expectEmit(); emit Events.InvoiceCreated({ id: 1, - recipient: users.eve, + recipient: address(container), status: Types.Status.Pending, startTime: invoice.startTime, endTime: invoice.endTime, @@ -483,7 +489,9 @@ contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Sha // Assert the actual and expected invoice state Types.Invoice memory actualInvoice = invoiceModule.getInvoice({ id: 1 }); - assertEq(actualInvoice.recipient, users.eve); + address expectedRecipient = invoiceModule.ownerOf(1); + + assertEq(expectedRecipient, address(container)); assertEq(uint8(actualInvoice.status), uint8(Types.Status.Pending)); assertEq(actualInvoice.startTime, invoice.startTime); assertEq(actualInvoice.endTime, invoice.endTime); diff --git a/test/integration/concrete/invoice-module/pay-invoice/payInvoice.t.sol b/test/integration/concrete/invoice-module/pay-invoice/payInvoice.t.sol index e990c829..854ab34c 100644 --- a/test/integration/concrete/invoice-module/pay-invoice/payInvoice.t.sol +++ b/test/integration/concrete/invoice-module/pay-invoice/payInvoice.t.sol @@ -11,12 +11,11 @@ import { LockupLinear, LockupTranched } from "@sablier/v2-core/src/types/DataTyp contract PayInvoice_Integration_Concret_Test is PayInvoice_Integration_Shared_Test { function setUp() public virtual override { PayInvoice_Integration_Shared_Test.setUp(); - createMockInvoices(); } function test_RevertWhen_InvoiceNull() external { - // Expect the call to revert with the {InvoiceNull} error - vm.expectRevert(Errors.InvoiceNull.selector); + // Expect the call to revert with the {ERC721NonexistentToken} error + vm.expectRevert(abi.encodeWithSelector(Errors.ERC721NonexistentToken.selector, 99)); // Run the test invoiceModule.payInvoice({ id: 99 }); @@ -46,8 +45,8 @@ contract PayInvoice_Integration_Concret_Test is PayInvoice_Integration_Shared_Te // Set the one-off USDT transfer invoice as current one uint256 invoiceId = 1; - // Make Eve the caller in this test suite as she's the owner of the {Container} contract - vm.startPrank({ msgSender: users.eve }); + // Make Eve's container the caller in this test suite as his container is the owner of the invoice + vm.startPrank({ msgSender: address(container) }); // Cancel the invoice first invoiceModule.cancelInvoice({ id: invoiceId }); @@ -96,19 +95,29 @@ contract PayInvoice_Integration_Concret_Test is PayInvoice_Integration_Shared_Te givenPaymentAmountInNativeToken whenPaymentAmountEqualToInvoiceValue { - // Create a mock invoice with a one-off ETH transfer and set {MockBadReceiver} as the recipient - Types.Invoice memory invoice = - createInvoiceWithOneOffTransfer({ asset: address(0), recipient: address(mockBadReceiver) }); + // Create a mock invoice with a one-off ETH transfer from the Eve's container + Types.Invoice memory invoice = createInvoiceWithOneOffTransfer({ asset: address(0) }); executeCreateInvoice({ invoice: invoice, user: users.eve }); - // Make {MockBadReceiver} the payer for this invoice + uint256 invoiceId = _nextInvoiceId; + + // Make Eve's container the caller for the next call to approve & transfer the invoice NFT to a bad receiver + vm.startPrank({ msgSender: address(container) }); + + // Approve the {InvoiceModule} to transfer the token + invoiceModule.approve({ to: address(invoiceModule), tokenId: invoiceId }); + + // Transfer the invoice to a bad receiver so we can test against `NativeTokenPaymentFailed` + invoiceModule.transferFrom({ from: address(container), to: address(mockBadReceiver), tokenId: invoiceId }); + + // Make Bob the payer for this invoice vm.startPrank({ msgSender: users.bob }); // Expect the call to be reverted with the {NativeTokenPaymentFailed} error vm.expectRevert(Errors.NativeTokenPaymentFailed.selector); // Run the test - invoiceModule.payInvoice{ value: invoice.payment.amount }({ id: 6 }); + invoiceModule.payInvoice{ value: invoice.payment.amount }({ id: invoiceId }); } function test_PayInvoice_PaymentMethodTransfer_NativeToken_OneOff() @@ -129,7 +138,7 @@ contract PayInvoice_Integration_Concret_Test is PayInvoice_Integration_Shared_Te // Store the ETH balances of Bob and recipient before paying the invoice uint256 balanceOfBobBefore = address(users.bob).balance; - uint256 balanceOfRecipientBefore = address(invoices[invoiceId].recipient).balance; + uint256 balanceOfRecipientBefore = address(container).balance; // Expect the {InvoicePaid} event to be emitted vm.expectEmit(); @@ -157,10 +166,7 @@ contract PayInvoice_Integration_Concret_Test is PayInvoice_Integration_Shared_Te // Assert the balances of payer and recipient assertEq(address(users.bob).balance, balanceOfBobBefore - invoices[invoiceId].payment.amount); - assertEq( - address(invoices[invoiceId].recipient).balance, - balanceOfRecipientBefore + invoices[invoiceId].payment.amount - ); + assertEq(address(container).balance, balanceOfRecipientBefore + invoices[invoiceId].payment.amount); } function test_PayInvoice_PaymentMethodTransfer_ERC20Token_Recurring() @@ -180,7 +186,7 @@ contract PayInvoice_Integration_Concret_Test is PayInvoice_Integration_Shared_Te // Store the USDT balances of Bob and recipient before paying the invoice uint256 balanceOfBobBefore = usdt.balanceOf(users.bob); - uint256 balanceOfRecipientBefore = usdt.balanceOf(invoices[invoiceId].recipient); + uint256 balanceOfRecipientBefore = usdt.balanceOf(address(container)); // Approve the {InvoiceModule} to transfer the ERC-20 tokens on Bob's behalf usdt.approve({ spender: address(invoiceModule), amount: invoices[invoiceId].payment.amount }); @@ -211,9 +217,7 @@ contract PayInvoice_Integration_Concret_Test is PayInvoice_Integration_Shared_Te // Assert the balances of payer and recipient assertEq(usdt.balanceOf(users.bob), balanceOfBobBefore - invoices[invoiceId].payment.amount); - assertEq( - usdt.balanceOf(invoices[invoiceId].recipient), balanceOfRecipientBefore + invoices[invoiceId].payment.amount - ); + assertEq(usdt.balanceOf(address(container)), balanceOfRecipientBefore + invoices[invoiceId].payment.amount); } function test_PayInvoice_PaymentMethodLinearStream() @@ -262,7 +266,7 @@ contract PayInvoice_Integration_Concret_Test is PayInvoice_Integration_Shared_Te // Assert the actual and the expected state of the Sablier v2 linear stream LockupLinear.StreamLL memory stream = invoiceModule.getLinearStream({ streamId: 1 }); assertEq(stream.sender, address(invoiceModule)); - assertEq(stream.recipient, users.eve); + assertEq(stream.recipient, address(container)); assertEq(address(stream.asset), address(usdt)); assertEq(stream.startTime, invoice.startTime); assertEq(stream.endTime, invoice.endTime); @@ -314,7 +318,7 @@ contract PayInvoice_Integration_Concret_Test is PayInvoice_Integration_Shared_Te // Assert the actual and the expected state of the Sablier v2 tranched stream LockupTranched.StreamLT memory stream = invoiceModule.getTranchedStream({ streamId: 1 }); assertEq(stream.sender, address(invoiceModule)); - assertEq(stream.recipient, users.eve); + assertEq(stream.recipient, address(container)); assertEq(address(stream.asset), address(usdt)); assertEq(stream.startTime, invoice.startTime); assertEq(stream.endTime, invoice.endTime); diff --git a/test/integration/concrete/invoice-module/pay-invoice/payInvoice.tree b/test/integration/concrete/invoice-module/pay-invoice/payInvoice.tree index 084ff6dc..3369cad8 100644 --- a/test/integration/concrete/invoice-module/pay-invoice/payInvoice.tree +++ b/test/integration/concrete/invoice-module/pay-invoice/payInvoice.tree @@ -1,6 +1,6 @@ payInvoice.t.sol -├── when the invoice IS null -│ └── it should revert with the {InvoiceNull} error +├── when the invoice IS null (there is no ERC-721 token minted) +│ └── it should revert with the {ERC721NonexistentToken} error └── when the invoice IS NOT null ├── when the invoice IS already paid │ └── it should revert with the {InvoiceAlreadyPaid} error diff --git a/test/integration/concrete/invoice-module/transfer-from/transferFrom.t.sol b/test/integration/concrete/invoice-module/transfer-from/transferFrom.t.sol new file mode 100644 index 00000000..e4def1d6 --- /dev/null +++ b/test/integration/concrete/invoice-module/transfer-from/transferFrom.t.sol @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.26; + +import { TransferFrom_Integration_Shared_Test } from "../../../shared/transferFrom.t.sol"; +import { Errors } from "../../../../utils/Errors.sol"; +import { Events } from "../../../../utils/Events.sol"; +import { Types } from "./../../../../../src/modules/invoice-module/libraries/Types.sol"; + +contract TransferFrom_Integration_Concret_Test is TransferFrom_Integration_Shared_Test { + function setUp() public virtual override { + TransferFrom_Integration_Shared_Test.setUp(); + } + + function test_RevertWhen_TokenDoesNotExist() external { + // Make Eve's container the caller which is the recipient of the invoice + vm.startPrank({ msgSender: address(container) }); + + // Expect the call to revert with the {ERC721NonexistentToken} error + vm.expectRevert(abi.encodeWithSelector(Errors.ERC721NonexistentToken.selector, 99)); + + // Run the test + invoiceModule.transferFrom({ from: address(container), to: users.eve, tokenId: 99 }); + } + + function test_TransferFrom_PaymentMethodStream() external whenTokenExists { + uint256 invoiceId = 4; + uint256 streamId = 1; + + // Make Bob the payer for the invoice + vm.startPrank({ msgSender: users.bob }); + + // Approve the {InvoiceModule} to transfer the USDT tokens on Bob's behalf + usdt.approve({ spender: address(invoiceModule), amount: invoices[invoiceId].payment.amount }); + + // Pay the invoice + invoiceModule.payInvoice{ value: invoices[invoiceId].payment.amount }({ id: invoiceId }); + + // Simulate the passage of time so that the maximum withdrawable amount is non-zero + vm.warp(block.timestamp + 5 weeks); + + // Store Eve's container balance before withdrawing the USDT tokens + uint256 balanceOfBefore = usdt.balanceOf(address(container)); + + // Get the maximum withdrawable amount from the stream before transferring the stream NFT + uint128 maxWithdrawableAmount = + invoiceModule.withdrawableAmountOf({ streamType: Types.Method.LinearStream, streamId: streamId }); + + // Make Eve's container the caller which is the recipient of the invoice + vm.startPrank({ msgSender: address(container) }); + + // Approve the {InvoiceModule} to transfer the `streamId` stream on behalf of the Eve's container + sablierV2LockupLinear.approve({ to: address(invoiceModule), tokenId: streamId }); + + // Run the test + invoiceModule.transferFrom({ from: address(container), to: users.eve, tokenId: invoiceId }); + + // Assert the current and expected Eve's container USDT balance + assertEq(balanceOfBefore + maxWithdrawableAmount, usdt.balanceOf(address(container))); + + // Assert the current and expected owner of the invoice NFT + assertEq(invoiceModule.ownerOf({ tokenId: invoiceId }), users.eve); + + // Assert the current and expected owner of the invoice stream NFT + assertEq(sablierV2LockupLinear.ownerOf({ tokenId: streamId }), users.eve); + } + + function test_TransferFrom_PaymentTransfer() external whenTokenExists { + uint256 invoiceId = 1; + + // Make Eve's container the caller which is the recipient of the invoice + vm.startPrank({ msgSender: address(container) }); + + // Run the test + invoiceModule.transferFrom({ from: address(container), to: users.eve, tokenId: invoiceId }); + + // Assert the current and expected owner of the invoice NFT + assertEq(invoiceModule.ownerOf({ tokenId: invoiceId }), users.eve); + } +} diff --git a/test/integration/concrete/invoice-module/transfer-from/transferFrom.tree b/test/integration/concrete/invoice-module/transfer-from/transferFrom.tree new file mode 100644 index 00000000..93f9779e --- /dev/null +++ b/test/integration/concrete/invoice-module/transfer-from/transferFrom.tree @@ -0,0 +1,10 @@ +transferFrom.t.sol +├── when the token does not exist +│ └── it should revert with the {ERC721NonexistentToken} error +└── when the token exist + ├── when the payment method is stream-based + │ ├── it should withdraw the maximum withdrawable amount of the Sablier stream + │ ├── it should transfer the Sablier stream NFT + │ └── it should transfer the invoice NFT + └── when the payment is transfer-based + └── it should transfer the invoice NFT \ No newline at end of file diff --git a/test/integration/concrete/invoice-module/withdraw-stream/withdrawStream.sol b/test/integration/concrete/invoice-module/withdraw-invoice-stream/withdrawStream.t.sol similarity index 77% rename from test/integration/concrete/invoice-module/withdraw-stream/withdrawStream.sol rename to test/integration/concrete/invoice-module/withdraw-invoice-stream/withdrawStream.t.sol index 89af4f5f..e7a96871 100644 --- a/test/integration/concrete/invoice-module/withdraw-stream/withdrawStream.sol +++ b/test/integration/concrete/invoice-module/withdraw-invoice-stream/withdrawStream.t.sol @@ -14,7 +14,7 @@ contract WithdrawLinearStream_Integration_Concret_Test is WithdrawLinearStream_I uint256 invoiceId = 4; uint256 streamId = 1; - // The invoice must be paid for its status to be updated to `Ongoing` + // The invoice must be paid in order to update its status to `Ongoing` // Make Bob the payer of the invoice (also Bob will be the initial stream sender) vm.startPrank({ msgSender: users.bob }); @@ -27,21 +27,21 @@ contract WithdrawLinearStream_Integration_Concret_Test is WithdrawLinearStream_I // Advance the timestamp by 5 weeks to simulate the withdrawal vm.warp(block.timestamp + 5 weeks); - // Store Eve's balance before withdrawing the USDT tokens - uint256 balanceOfBefore = usdt.balanceOf(users.eve); + // Store Eve's container balance before withdrawing the USDT tokens + uint256 balanceOfBefore = usdt.balanceOf(address(container)); // Get the maximum withdrawable amount from the stream uint128 maxWithdrawableAmount = invoiceModule.withdrawableAmountOf({ streamType: Types.Method.LinearStream, streamId: streamId }); - // Make Eve the caller in this test suite as she's the recipient of the invoice - vm.startPrank({ msgSender: users.eve }); + // Make Eve's container the caller in this test suite as his container is the recipient of the invoice + vm.startPrank({ msgSender: address(container) }); // Run the test - invoiceModule.withdrawStream({ streamType: Types.Method.LinearStream, streamId: streamId, to: users.eve }); + invoiceModule.withdrawInvoiceStream(invoiceId); // Assert the current and expected USDT balance of Eve - assertEq(balanceOfBefore + maxWithdrawableAmount, usdt.balanceOf(users.eve)); + assertEq(balanceOfBefore + maxWithdrawableAmount, usdt.balanceOf(address(container))); } function test_WithdrawStream_TranchedStream() external givenPaymentMethodTranchedStream givenInvoiceStatusOngoing { @@ -62,20 +62,20 @@ contract WithdrawLinearStream_Integration_Concret_Test is WithdrawLinearStream_I // Advance the timestamp by 5 weeks to simulate the withdrawal vm.warp(block.timestamp + 5 weeks); - // Store Eve's balance before withdrawing the USDT tokens - uint256 balanceOfBefore = usdt.balanceOf(users.eve); + // Store Eve's container balance before withdrawing the USDT tokens + uint256 balanceOfBefore = usdt.balanceOf(address(container)); // Get the maximum withdrawable amount from the stream uint128 maxWithdrawableAmount = invoiceModule.withdrawableAmountOf({ streamType: Types.Method.TranchedStream, streamId: streamId }); - // Make Eve the caller in this test suite as she's the recipient of the invoice - vm.startPrank({ msgSender: users.eve }); + // Make Eve's container the caller in this test suite as her container is the owner of the invoice + vm.startPrank({ msgSender: address(container) }); // Run the test invoiceModule.withdrawInvoiceStream(invoiceId); - // Assert the current and expected USDT balance of Eve - assertEq(balanceOfBefore + maxWithdrawableAmount, usdt.balanceOf(users.eve)); + // Assert the current and expected USDT balance of Eve's container + assertEq(balanceOfBefore + maxWithdrawableAmount, usdt.balanceOf(address(container))); } } diff --git a/test/integration/concrete/invoice-module/withdraw-stream/withdrawStream.tree b/test/integration/concrete/invoice-module/withdraw-invoice-stream/withdrawStream.tree similarity index 100% rename from test/integration/concrete/invoice-module/withdraw-stream/withdrawStream.tree rename to test/integration/concrete/invoice-module/withdraw-invoice-stream/withdrawStream.tree diff --git a/test/integration/fuzz/createInvoice.t.sol b/test/integration/fuzz/createInvoice.t.sol index ec8df4c3..11911b10 100644 --- a/test/integration/fuzz/createInvoice.t.sol +++ b/test/integration/fuzz/createInvoice.t.sol @@ -19,7 +19,6 @@ contract CreateInvoice_Integration_Fuzz_Test is CreateInvoice_Integration_Shared function testFuzz_CreateInvoice( uint8 recurrence, uint8 paymentMethod, - address recipient, uint40 startTime, uint40 endTime, uint128 amount @@ -37,7 +36,6 @@ contract CreateInvoice_Integration_Fuzz_Test is CreateInvoice_Integration_Shared vm.assume(recurrence < 4); // Assume recurrence is within Types.Method enum values (Transfer, LinearStream, TranchedStream) (0, 1, 2) vm.assume(paymentMethod < 3); - vm.assume(recipient != address(0) && recipient != address(this)); vm.assume(startTime >= uint40(block.timestamp) && startTime < endTime); vm.assume(amount > 0); @@ -48,7 +46,6 @@ contract CreateInvoice_Integration_Fuzz_Test is CreateInvoice_Integration_Shared // Create a new invoice with a transfer-based payment invoice = Types.Invoice({ - recipient: recipient, status: Types.Status.Pending, startTime: startTime, endTime: endTime, @@ -64,14 +61,14 @@ contract CreateInvoice_Integration_Fuzz_Test is CreateInvoice_Integration_Shared // Create the calldata for the {InvoiceModule} execution bytes memory data = abi.encodeWithSignature( - "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice + "createInvoice((uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice ); // Expect the module call to emit an {InvoiceCreated} event vm.expectEmit(); emit Events.InvoiceCreated({ id: 1, - recipient: invoice.recipient, + recipient: address(container), status: Types.Status.Pending, startTime: invoice.startTime, endTime: invoice.endTime, @@ -87,7 +84,9 @@ contract CreateInvoice_Integration_Fuzz_Test is CreateInvoice_Integration_Shared // Assert the actual and expected invoice state Types.Invoice memory actualInvoice = invoiceModule.getInvoice({ id: 1 }); - assertEq(actualInvoice.recipient, invoice.recipient); + address actualRecipient = invoiceModule.ownerOf(1); + + assertEq(actualRecipient, address(container)); assertEq(uint8(actualInvoice.status), uint8(Types.Status.Pending)); assertEq(actualInvoice.startTime, invoice.startTime); assertEq(actualInvoice.endTime, invoice.endTime); diff --git a/test/integration/fuzz/payInvoice.t.sol b/test/integration/fuzz/payInvoice.t.sol index 2e95eee3..5aefd327 100644 --- a/test/integration/fuzz/payInvoice.t.sol +++ b/test/integration/fuzz/payInvoice.t.sol @@ -42,9 +42,8 @@ contract PayInvoice_Integration_Fuzz_Test is PayInvoice_Integration_Shared_Test Helpers.checkFuzzedPaymentMethod(paymentMethod, recurrence, startTime, endTime); if (!valid) return; - // Create a new invoice with a transfer-based payment + // Create a new invoice with the fuzzed payment method invoice = Types.Invoice({ - recipient: users.eve, status: Types.Status.Pending, startTime: startTime, endTime: endTime, @@ -60,10 +59,12 @@ contract PayInvoice_Integration_Fuzz_Test is PayInvoice_Integration_Shared_Test // Create the calldata for the {InvoiceModule} execution bytes memory data = abi.encodeWithSignature( - "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice + "createInvoice((uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice ); - // Make Eve the caller to create the fuzzed invoice + uint256 invoiceId = _nextInvoiceId; + + // Make Eve the caller to create the fuzzed invoice vm.startPrank({ msgSender: users.eve }); // Create the fuzzed invoice @@ -80,7 +81,7 @@ contract PayInvoice_Integration_Fuzz_Test is PayInvoice_Integration_Shared_Test // Store the USDT balances of the payer and recipient before paying the invoice uint256 balanceOfPayerBefore = usdt.balanceOf(users.bob); - uint256 balanceOfRecipientBefore = usdt.balanceOf(users.eve); + uint256 balanceOfRecipientBefore = usdt.balanceOf(address(container)); uint256 streamId = paymentMethod == 0 ? 0 : 1; numberOfPayments = numberOfPayments > 0 ? numberOfPayments - 1 : 0; @@ -92,7 +93,7 @@ contract PayInvoice_Integration_Fuzz_Test is PayInvoice_Integration_Shared_Test // Expect the {InvoicePaid} event to be emitted vm.expectEmit(); emit Events.InvoicePaid({ - id: 1, + id: invoiceId, payer: users.bob, status: expectedInvoiceStatus, payment: Types.Payment({ @@ -106,19 +107,19 @@ contract PayInvoice_Integration_Fuzz_Test is PayInvoice_Integration_Shared_Test }); // Run the test - invoiceModule.payInvoice({ id: 1 }); + invoiceModule.payInvoice({ id: invoiceId }); // Assert the actual and the expected state of the invoice - Types.Invoice memory actualInvoice = invoiceModule.getInvoice({ id: 1 }); + Types.Invoice memory actualInvoice = invoiceModule.getInvoice({ id: invoiceId }); assertEq(uint8(actualInvoice.status), uint8(expectedInvoiceStatus)); assertEq(actualInvoice.payment.paymentsLeft, numberOfPayments); // Assert the actual and expected balances of the payer and recipient assertEq(usdt.balanceOf(users.bob), balanceOfPayerBefore - invoice.payment.amount); if (invoice.payment.method == Types.Method.Transfer) { - assertEq(usdt.balanceOf(users.eve), balanceOfRecipientBefore + invoice.payment.amount); + assertEq(usdt.balanceOf(address(container)), balanceOfRecipientBefore + invoice.payment.amount); } else { - assertEq(usdt.balanceOf(users.eve), balanceOfRecipientBefore); + assertEq(usdt.balanceOf(address(container)), balanceOfRecipientBefore); } } } diff --git a/test/integration/shared/createInvoice.t.sol b/test/integration/shared/createInvoice.t.sol index cdfd1514..afb4994f 100644 --- a/test/integration/shared/createInvoice.t.sol +++ b/test/integration/shared/createInvoice.t.sol @@ -3,12 +3,45 @@ pragma solidity ^0.8.26; import { Integration_Test } from "../Integration.t.sol"; import { Types } from "./../../../src/modules/invoice-module/libraries/Types.sol"; +import { IContainer } from "./../../../src/interfaces/IContainer.sol"; abstract contract CreateInvoice_Integration_Shared_Test is Integration_Test { + mapping(uint256 invoiceId => Types.Invoice) invoices; + uint256 public _nextInvoiceId; + function setUp() public virtual override { Integration_Test.setUp(); } + function createMockInvoices() internal { + // Create a mock invoice with a one-off USDT transfer + Types.Invoice memory invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt) }); + invoices[1] = invoice; + executeCreateInvoice({ invoice: invoice, user: users.eve }); + + // Create a mock invoice with a one-off ETH transfer + invoice = createInvoiceWithOneOffTransfer({ asset: address(0) }); + invoices[2] = invoice; + executeCreateInvoice({ invoice: invoice, user: users.eve }); + + // Create a mock invoice with a recurring USDT transfer + invoice = createInvoiceWithRecurringTransfer({ recurrence: Types.Recurrence.Weekly }); + invoices[3] = invoice; + executeCreateInvoice({ invoice: invoice, user: users.eve }); + + // Create a mock invoice with a linear stream payment + invoice = createInvoiceWithLinearStream(); + invoices[4] = invoice; + executeCreateInvoice({ invoice: invoice, user: users.eve }); + + // Create a mock invoice with a tranched stream payment + invoice = createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Weekly }); + invoices[5] = invoice; + executeCreateInvoice({ invoice: invoice, user: users.eve }); + + _nextInvoiceId = 6; + } + modifier whenCallerContract() { _; } @@ -58,15 +91,8 @@ abstract contract CreateInvoice_Integration_Shared_Test is Integration_Test { } /// @dev Creates an invoice with a one-off transfer payment - function createInvoiceWithOneOffTransfer( - address asset, - address recipient - ) internal view returns (Types.Invoice memory invoice) { - invoice.recipient = recipient; - invoice.status = Types.Status.Pending; - - invoice.startTime = uint40(block.timestamp); - invoice.endTime = uint40(block.timestamp) + 4 weeks; + function createInvoiceWithOneOffTransfer(address asset) internal view returns (Types.Invoice memory invoice) { + invoice = _createInvoice(uint40(block.timestamp), uint40(block.timestamp) + 4 weeks); invoice.payment = Types.Payment({ method: Types.Method.Transfer, @@ -79,15 +105,12 @@ abstract contract CreateInvoice_Integration_Shared_Test is Integration_Test { } /// @dev Creates an invoice with a recurring transfer payment - function createInvoiceWithRecurringTransfer( - Types.Recurrence recurrence, - address recipient - ) internal view returns (Types.Invoice memory invoice) { - invoice.recipient = recipient; - invoice.status = Types.Status.Pending; - - invoice.startTime = uint40(block.timestamp); - invoice.endTime = uint40(block.timestamp) + 4 weeks; + function createInvoiceWithRecurringTransfer(Types.Recurrence recurrence) + internal + view + returns (Types.Invoice memory invoice) + { + invoice = _createInvoice(uint40(block.timestamp), uint40(block.timestamp) + 4 weeks); invoice.payment = Types.Payment({ method: Types.Method.Transfer, @@ -100,12 +123,8 @@ abstract contract CreateInvoice_Integration_Shared_Test is Integration_Test { } /// @dev Creates an invoice with a linear stream-based payment - function createInvoiceWithLinearStream(address recipient) internal view returns (Types.Invoice memory invoice) { - invoice.recipient = recipient; - invoice.status = Types.Status.Pending; - - invoice.startTime = uint40(block.timestamp); - invoice.endTime = uint40(block.timestamp) + 4 weeks; + function createInvoiceWithLinearStream() internal view returns (Types.Invoice memory invoice) { + invoice = _createInvoice(uint40(block.timestamp), uint40(block.timestamp) + 4 weeks); invoice.payment = Types.Payment({ method: Types.Method.LinearStream, @@ -118,15 +137,12 @@ abstract contract CreateInvoice_Integration_Shared_Test is Integration_Test { } /// @dev Creates an invoice with a tranched stream-based payment - function createInvoiceWithTranchedStream( - Types.Recurrence recurrence, - address recipient - ) internal view returns (Types.Invoice memory invoice) { - invoice.recipient = recipient; - invoice.status = Types.Status.Pending; - - invoice.startTime = uint40(block.timestamp); - invoice.endTime = uint40(block.timestamp) + 4 weeks; + function createInvoiceWithTranchedStream(Types.Recurrence recurrence) + internal + view + returns (Types.Invoice memory invoice) + { + invoice = _createInvoice(uint40(block.timestamp), uint40(block.timestamp) + 4 weeks); invoice.payment = Types.Payment({ method: Types.Method.TranchedStream, @@ -142,16 +158,11 @@ abstract contract CreateInvoice_Integration_Shared_Test is Integration_Test { function createFuzzedInvoice( Types.Method method, Types.Recurrence recurrence, - address recipient, uint40 startTime, uint40 endTime, uint128 amount ) internal view returns (Types.Invoice memory invoice) { - invoice.recipient = recipient; - invoice.status = Types.Status.Pending; - - invoice.startTime = startTime; - invoice.endTime = endTime; + invoice = _createInvoice(startTime, endTime); invoice.payment = Types.Payment({ method: method, @@ -167,13 +178,27 @@ abstract contract CreateInvoice_Integration_Shared_Test is Integration_Test { // Make the `user` account the caller who must be the owner of the {Container} contract vm.startPrank({ msgSender: user }); + // Select the according {Container} of the user + IContainer _container; + if (user == users.eve) { + _container = container; + } else { + _container = badContainer; + } + // Create the invoice bytes memory data = abi.encodeWithSignature( - "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice + "createInvoice((uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", invoice ); - container.execute({ module: address(invoiceModule), value: 0, data: data }); + _container.execute({ module: address(invoiceModule), value: 0, data: data }); // Stop the active prank vm.stopPrank(); } + + function _createInvoice(uint40 startTime, uint40 endTime) internal pure returns (Types.Invoice memory invoice) { + invoice.status = Types.Status.Pending; + invoice.startTime = startTime; + invoice.endTime = endTime; + } } diff --git a/test/integration/shared/payInvoice.t.sol b/test/integration/shared/payInvoice.t.sol index 301798ff..f13997dd 100644 --- a/test/integration/shared/payInvoice.t.sol +++ b/test/integration/shared/payInvoice.t.sol @@ -3,40 +3,11 @@ pragma solidity ^0.8.26; import { Integration_Test } from "../Integration.t.sol"; import { CreateInvoice_Integration_Shared_Test } from "./createInvoice.t.sol"; -import { Types } from "./../../../src/modules/invoice-module/libraries/Types.sol"; abstract contract PayInvoice_Integration_Shared_Test is Integration_Test, CreateInvoice_Integration_Shared_Test { - mapping(uint256 invoiceId => Types.Invoice) invoices; - function setUp() public virtual override(Integration_Test, CreateInvoice_Integration_Shared_Test) { CreateInvoice_Integration_Shared_Test.setUp(); - } - - function createMockInvoices() internal { - // Create a mock invoice with a one-off USDT transfer - Types.Invoice memory invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt), recipient: users.eve }); - invoices[1] = invoice; - executeCreateInvoice({ invoice: invoice, user: users.eve }); - - // Create a mock invoice with a one-off ETH transfer - invoice = createInvoiceWithOneOffTransfer({ asset: address(0), recipient: users.eve }); - invoices[2] = invoice; - executeCreateInvoice({ invoice: invoice, user: users.eve }); - - // Create a mock invoice with a recurring USDT transfer - invoice = createInvoiceWithRecurringTransfer({ recurrence: Types.Recurrence.Weekly, recipient: users.eve }); - invoices[3] = invoice; - executeCreateInvoice({ invoice: invoice, user: users.eve }); - - // Create a mock invoice with a linear stream payment - invoice = createInvoiceWithLinearStream({ recipient: users.eve }); - invoices[4] = invoice; - executeCreateInvoice({ invoice: invoice, user: users.eve }); - - // Create a mock invoice with a tranched stream payment - invoice = createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Weekly, recipient: users.eve }); - invoices[5] = invoice; - executeCreateInvoice({ invoice: invoice, user: users.eve }); + createMockInvoices(); } modifier whenInvoiceNotNull() { diff --git a/test/integration/shared/transferFrom.t.sol b/test/integration/shared/transferFrom.t.sol new file mode 100644 index 00000000..16d2a05a --- /dev/null +++ b/test/integration/shared/transferFrom.t.sol @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.26; + +import { Integration_Test } from "../Integration.t.sol"; +import { PayInvoice_Integration_Shared_Test } from "./payInvoice.t.sol"; + +abstract contract TransferFrom_Integration_Shared_Test is Integration_Test, PayInvoice_Integration_Shared_Test { + function setUp() public virtual override(Integration_Test, PayInvoice_Integration_Shared_Test) { + PayInvoice_Integration_Shared_Test.setUp(); + } + + modifier whenTokenExists() { + _; + } +} diff --git a/test/mocks/MockBadContainer.sol b/test/mocks/MockBadContainer.sol new file mode 100644 index 00000000..ebc210d6 --- /dev/null +++ b/test/mocks/MockBadContainer.sol @@ -0,0 +1,208 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity ^0.8.26; + +import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import { SafeERC20 } from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; +import { IERC165 } from "@openzeppelin/contracts/utils/introspection/IERC165.sol"; +import { IERC721 } from "@openzeppelin/contracts/token/ERC721/IERC721.sol"; +import { IERC1155 } from "@openzeppelin/contracts/token/ERC1155/IERC1155.sol"; +import { IERC721Receiver } from "@openzeppelin/contracts/token/ERC721/IERC721Receiver.sol"; +import { IERC1155Receiver } from "@openzeppelin/contracts/token/ERC1155/IERC1155Receiver.sol"; +import { ExcessivelySafeCall } from "@nomad-xyz/excessively-safe-call/src/ExcessivelySafeCall.sol"; + +import { IContainer } from "./../../src/interfaces/IContainer.sol"; +import { ModuleManager } from "./../../src/abstracts/ModuleManager.sol"; +import { IModuleManager } from "./../../src/interfaces/IModuleManager.sol"; +import { Errors } from "./../../src/libraries/Errors.sol"; +import { ModuleKeeper } from "./../../src/ModuleKeeper.sol"; +import { DockRegistry } from "./../../src/DockRegistry.sol"; + +/// @title MockBadContainer +/// @notice Container that reverts when receiving native tokens (ETH) +contract MockBadContainer is IContainer, ModuleManager { + using SafeERC20 for IERC20; + using ExcessivelySafeCall for address; + + /*////////////////////////////////////////////////////////////////////////// + CONSTRUCTOR + //////////////////////////////////////////////////////////////////////////*/ + + /// @dev Initializes the address of the {Container} owner, {ModuleKeeper} and enables the initial module(s) + constructor( + DockRegistry _dockRegistry, + address[] memory _initialModules + ) ModuleManager(_dockRegistry, _initialModules) { + dockRegistry = _dockRegistry; + } + + /*////////////////////////////////////////////////////////////////////////// + RECEIVE & FALLBACK + //////////////////////////////////////////////////////////////////////////*/ + + /// @dev Revert on receiving ETH + receive() external payable { + revert(); + } + + /// @dev Fallback function to handle incoming calls with data + fallback() external payable { + revert(); + } + + /*////////////////////////////////////////////////////////////////////////// + MODIFIERS + //////////////////////////////////////////////////////////////////////////*/ + + /// @notice Reverts if the `msg.sender` is not the owner of the {Container} assigned in the registry + modifier onlyOwner() { + if (msg.sender != dockRegistry.ownerOfContainer(address(this))) revert Errors.CallerNotContainerOwner(); + _; + } + + /*////////////////////////////////////////////////////////////////////////// + NON-CONSTANT FUNCTIONS + //////////////////////////////////////////////////////////////////////////*/ + + /// @inheritdoc IContainer + function execute( + address module, + uint256 value, + bytes memory data + ) public onlyOwner onlyEnabledModule(module) returns (bool success) { + // Allocate all the gas to the executed module method + uint256 txGas = gasleft(); + + // Execute the call via assembly and get only the first 4 bytes of the returndata + // which will be the selector of the error in case of a revert in the module contract + // See https://github.com/nomad-xyz/ExcessivelySafeCall + bytes memory result; + (success, result) = module.excessivelySafeCall({ _gas: txGas, _value: 0, _maxCopy: 4, _calldata: data }); + + // Revert with the same error returned by the module contract if the call failed + if (!success) { + assembly { + revert(add(result, 0x20), 4) + } + // Otherwise log the execution success + } else { + emit ModuleExecutionSucceded(module, value, data); + } + } + + /// @inheritdoc IContainer + function withdrawERC20(IERC20 asset, uint256 amount) public onlyOwner { + // Checks: the available ERC20 balance of the container is greater enough to support the withdrawal + if (amount > asset.balanceOf(address(this))) revert Errors.InsufficientERC20ToWithdraw(); + + // Interactions: withdraw by transferring the amount to the sender + asset.safeTransfer({ to: msg.sender, value: amount }); + + // Log the successful ERC-20 token withdrawal + emit AssetWithdrawn({ to: msg.sender, asset: address(asset), amount: amount }); + } + + /// @inheritdoc IContainer + function withdrawERC721(IERC721 collection, uint256 tokenId) public onlyOwner { + // Checks, Effects, Interactions: withdraw by transferring the token to the container owner + // Notes: + // - we're using `safeTransferFrom` as the owner can be an ERC-4337 smart account + // therefore the `onERC721Received` hook must be implemented + collection.safeTransferFrom(address(this), msg.sender, tokenId); + + // Log the successful ERC-721 token withdrawal + emit ERC721Withdrawn({ to: msg.sender, collection: address(collection), tokenId: tokenId }); + } + + /// @inheritdoc IContainer + function withdrawERC1155(IERC1155 collection, uint256[] memory ids, uint256[] memory amounts) public onlyOwner { + // Checks, Effects, Interactions: withdraw by transferring the tokens to the container owner + // Notes: + // - we're using `safeTransferFrom` as the owner can be an ERC-4337 smart account + // therefore the `onERC1155Received` hook must be implemented + // - depending on the length of the `ids` array, we're using `safeBatchTransferFrom` or `safeTransferFrom` + if (ids.length > 1) { + collection.safeBatchTransferFrom({ from: address(this), to: msg.sender, ids: ids, values: amounts, data: "" }); + } else { + collection.safeTransferFrom({ from: address(this), to: msg.sender, id: ids[0], value: amounts[0], data: "" }); + } + + // Log the successful ERC-1155 token withdrawal + emit ERC1155Withdrawn(msg.sender, address(collection), ids, amounts); + } + + /// @inheritdoc IContainer + function withdrawNative(uint256 amount) public onlyOwner { + // Checks: the native balance of the container minus the amount locked for operations is greater than the requested amount + if (amount > address(this).balance) revert Errors.InsufficientNativeToWithdraw(); + + // Interactions: withdraw by transferring the amount to the sender + (bool success,) = msg.sender.call{ value: amount }(""); + // Revert if the call failed + if (!success) revert Errors.NativeWithdrawFailed(); + + // Log the successful native token withdrawal + emit AssetWithdrawn({ to: msg.sender, asset: address(0), amount: amount }); + } + + /// @inheritdoc IModuleManager + function enableModule(address module) public override onlyOwner { + super.enableModule(module); + } + + /// @inheritdoc IModuleManager + function disableModule(address module) public override onlyOwner { + super.disableModule(module); + } + + /*////////////////////////////////////////////////////////////////////////// + CONSTANT FUNCTIONS + //////////////////////////////////////////////////////////////////////////*/ + + /// @inheritdoc IERC165 + function supportsInterface(bytes4 interfaceId) public pure override returns (bool) { + return interfaceId == type(IContainer).interfaceId || interfaceId == type(IERC165).interfaceId; + } + + /// @inheritdoc IERC721Receiver + function onERC721Received( + address, + address from, + uint256 tokenId, + bytes calldata + ) external override returns (bytes4) { + // Log the successful ERC-721 token receipt + emit ERC721Received(from, tokenId); + + return this.onERC721Received.selector; + } + + /// @inheritdoc IERC1155Receiver + function onERC1155Received( + address, + address from, + uint256 id, + uint256 value, + bytes calldata + ) external override returns (bytes4) { + // Log the successful ERC-1155 token receipt + emit ERC1155Received(from, id, value); + + return this.onERC1155Received.selector; + } + + /// @inheritdoc IERC1155Receiver + function onERC1155BatchReceived( + address, + address from, + uint256[] calldata ids, + uint256[] calldata values, + bytes calldata + ) external override returns (bytes4) { + for (uint256 i; i < ids.length; ++i) { + // Log the successful ERC-1155 token receipt + emit ERC1155Received(from, ids[i], values[i]); + } + + return this.onERC1155BatchReceived.selector; + } +} diff --git a/test/utils/Helpers.sol b/test/utils/Helpers.sol index 4d1fcf82..4fd00420 100644 --- a/test/utils/Helpers.sol +++ b/test/utils/Helpers.sol @@ -5,9 +5,8 @@ import { Types } from "./../../src/modules/invoice-module/libraries/Types.sol"; import { Helpers as InvoiceHelpers } from "./../../src/modules/invoice-module/libraries/Helpers.sol"; library Helpers { - function createInvoiceDataType(address recipient) public view returns (Types.Invoice memory) { + function createInvoiceDataType() public view returns (Types.Invoice memory) { return Types.Invoice({ - recipient: recipient, status: Types.Status.Pending, startTime: 0, endTime: uint40(block.timestamp) + 1 weeks,