Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Bank hooks #2

Merged
merged 4 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions simapp/test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/cometbft/cometbft/libs/log"
tmproto "github.com/cometbft/cometbft/proto/tendermint/types"
tmtypes "github.com/cometbft/cometbft/types"
bankkeeper "github.com/cosmos/cosmos-sdk/x/bank/keeper"
"github.com/stretchr/testify/require"

"cosmossdk.io/math"
Expand Down Expand Up @@ -246,3 +247,17 @@ func NewTestNetworkFixture() network.TestFixture {
},
}
}

// FundModuleAccount is a utility function that funds a module account by
// minting and sending the coins to the address. This should be used for testing
// purposes only!
//
// TODO: Instead of using the mint module account, which has the
// permission of minting, create a "faucet" account. (@fdymylja)
func FundModuleAccount(bankKeeper bankkeeper.Keeper, ctx sdk.Context, recipientMod string, amounts sdk.Coins) error {
if err := bankKeeper.MintCoins(ctx, minttypes.ModuleName, amounts); err != nil {
return err
}

return bankKeeper.SendCoinsFromModuleToModule(ctx, minttypes.ModuleName, recipientMod, amounts)
}
120 changes: 120 additions & 0 deletions x/bank/app_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package bank_test

import (
"fmt"
"testing"

tmproto "github.com/cometbft/cometbft/proto/tendermint/types"
stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -447,3 +449,121 @@ func TestMsgSetSendEnabled(t *testing.T) {
})
}
}

var _ types.BankHooks = &MockBankHooksReceiver{}

// BankHooks event hooks for bank (noalias)
type MockBankHooksReceiver struct{}

// Mock BlockBeforeSend bank hook that doesn't allow the sending of exactly 100 coins of any denom.
func (h *MockBankHooksReceiver) BlockBeforeSend(ctx sdk.Context, from, to sdk.AccAddress, amount sdk.Coins) error {
for _, coin := range amount {
if coin.Amount.Equal(sdk.NewInt(100)) {
return fmt.Errorf("not allowed; expected %v, got: %v", 100, coin.Amount)
}
}
return nil
}

// variable for counting `TrackBeforeSend`
var (
countTrackBeforeSend = 0
expNextCount = 1
)

// Mock TrackBeforeSend bank hook that simply tracks the sending of exactly 50 coins of any denom.
func (h *MockBankHooksReceiver) TrackBeforeSend(ctx sdk.Context, from, to sdk.AccAddress, amount sdk.Coins) {
for _, coin := range amount {
if coin.Amount.Equal(sdk.NewInt(50)) {
countTrackBeforeSend += 1
}
}
}

func TestHooks(t *testing.T) {
acc1 := authtypes.NewBaseAccountWithAddress(addr1)

genAccs := []authtypes.GenesisAccount{acc1}
s := createTestSuite(t, genAccs)

ctx := s.App.BaseApp.NewContext(false, tmproto.Header{})

require.NoError(t, testutil.FundAccount(s.BankKeeper, ctx, addr1, sdk.NewCoins(sdk.NewCoin("stake", sdk.NewInt(1000)))))
require.NoError(t, testutil.FundModuleAccount(s.BankKeeper, ctx, stakingtypes.BondedPoolName, sdk.NewCoins(sdk.NewCoin("stake", sdk.NewInt(1000)))))

// create a valid send amount which is 1 coin, and an invalidSendAmount which is 100 coins
validSendAmount := sdk.NewCoins(sdk.NewCoin("stake", sdk.NewInt(1)))
triggerTrackSendAmount := sdk.NewCoins(sdk.NewCoin("stake", sdk.NewInt(50)))
invalidBlockSendAmount := sdk.NewCoins(sdk.NewCoin("stake", sdk.NewInt(100)))

// setup our mock bank hooks receiver that prevents the send of 100 coins
bankHooksReceiver := MockBankHooksReceiver{}
baseBankKeeper, ok := s.BankKeeper.(bankkeeper.BaseKeeper)
require.True(t, ok)
bankkeeper.UnsafeSetHooks(
&baseBankKeeper, types.NewMultiBankHooks(&bankHooksReceiver),
)
s.BankKeeper = baseBankKeeper

// try sending a validSendAmount and it should work
err := s.BankKeeper.SendCoins(ctx, addr1, addr2, validSendAmount)
require.NoError(t, err)

// try sending an trigger track send amount and it should work
err = s.BankKeeper.SendCoins(ctx, addr1, addr2, triggerTrackSendAmount)
require.NoError(t, err)

require.Equal(t, countTrackBeforeSend, expNextCount)
expNextCount++

// try sending an invalidSendAmount and it should not work
err = s.BankKeeper.SendCoins(ctx, addr1, addr2, invalidBlockSendAmount)
require.Error(t, err)

// make sure that account to module doesn't bypass hook
err = s.BankKeeper.SendCoinsFromAccountToModule(ctx, addr1, stakingtypes.BondedPoolName, validSendAmount)
require.NoError(t, err)
err = s.BankKeeper.SendCoinsFromAccountToModule(ctx, addr1, stakingtypes.BondedPoolName, invalidBlockSendAmount)
require.Error(t, err)
err = s.BankKeeper.SendCoinsFromAccountToModule(ctx, addr1, stakingtypes.BondedPoolName, triggerTrackSendAmount)
require.Equal(t, countTrackBeforeSend, expNextCount)
expNextCount++

// make sure that module to account doesn't bypass hook
err = s.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addr1, validSendAmount)
require.NoError(t, err)
err = s.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addr1, invalidBlockSendAmount)
require.Error(t, err)
err = s.BankKeeper.SendCoinsFromModuleToAccount(ctx, stakingtypes.BondedPoolName, addr1, triggerTrackSendAmount)
require.Equal(t, countTrackBeforeSend, expNextCount)
expNextCount++

// make sure that module to module doesn't bypass hook
err = s.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, validSendAmount)
require.NoError(t, err)
err = s.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, invalidBlockSendAmount)
// there should be no error since module to module does not call block before send hooks
require.NoError(t, err)
err = s.BankKeeper.SendCoinsFromModuleToModule(ctx, stakingtypes.BondedPoolName, stakingtypes.NotBondedPoolName, triggerTrackSendAmount)
require.Equal(t, countTrackBeforeSend, expNextCount)
expNextCount++

// make sure that DelegateCoins doesn't bypass the hook
err = s.BankKeeper.DelegateCoins(ctx, addr1, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), validSendAmount)
require.NoError(t, err)
err = s.BankKeeper.DelegateCoins(ctx, addr1, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), invalidBlockSendAmount)
require.Error(t, err)
err = s.BankKeeper.DelegateCoins(ctx, addr1, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), triggerTrackSendAmount)
require.Equal(t, countTrackBeforeSend, expNextCount)
expNextCount++

// make sure that UndelegateCoins doesn't bypass the hook
err = s.BankKeeper.UndelegateCoins(ctx, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addr1, validSendAmount)
require.NoError(t, err)
err = s.BankKeeper.UndelegateCoins(ctx, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addr1, invalidBlockSendAmount)
require.Error(t, err)

err = s.BankKeeper.UndelegateCoins(ctx, s.AccountKeeper.GetModuleAddress(stakingtypes.BondedPoolName), addr1, triggerTrackSendAmount)
require.Equal(t, countTrackBeforeSend, expNextCount)
expNextCount++
}
24 changes: 24 additions & 0 deletions x/bank/keeper/hooks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package keeper

import (
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/bank/types"
)

// Implements StakingHooks interface
var _ types.BankHooks = BaseSendKeeper{}

// TrackBeforeSend executes the TrackBeforeSend hook if registered.
func (k BaseSendKeeper) TrackBeforeSend(ctx sdk.Context, from, to sdk.AccAddress, amount sdk.Coins) {
if k.hooks != nil {
k.hooks.TrackBeforeSend(ctx, from, to, amount)
}
}

// BlockBeforeSend executes the BlockBeforeSend hook if registered.
func (k BaseSendKeeper) BlockBeforeSend(ctx sdk.Context, from, to sdk.AccAddress, amount sdk.Coins) error {
if k.hooks != nil {
return k.hooks.BlockBeforeSend(ctx, from, to, amount)
}
return nil
}
11 changes: 11 additions & 0 deletions x/bank/keeper/internal_unsafe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package keeper

import "github.com/cosmos/cosmos-sdk/x/bank/types"

// UnsafeSetHooks updates the x/bank keeper's hooks, overriding any potential
// pre-existing hooks.
//
// WARNING: this function should only be used in tests.
func UnsafeSetHooks(k *BaseKeeper, h types.BankHooks) {
k.hooks = h
}
21 changes: 18 additions & 3 deletions x/bank/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,13 @@ func (k BaseKeeper) DelegateCoins(ctx sdk.Context, delegatorAddr, moduleAccAddr
return sdkerrors.Wrap(sdkerrors.ErrInvalidCoins, amt.String())
}

err := k.BlockBeforeSend(ctx, delegatorAddr, moduleAccAddr, amt)
if err != nil {
return err
}
// call the TrackBeforeSend hooks and the BlockBeforeSend hooks
k.TrackBeforeSend(ctx, delegatorAddr, moduleAccAddr, amt)

balances := sdk.NewCoins()

for _, coin := range amt {
Expand All @@ -174,7 +181,7 @@ func (k BaseKeeper) DelegateCoins(ctx sdk.Context, delegatorAddr, moduleAccAddr
types.NewCoinSpentEvent(delegatorAddr, amt),
)

err := k.addCoins(ctx, moduleAccAddr, amt)
err = k.addCoins(ctx, moduleAccAddr, amt)
if err != nil {
return err
}
Expand All @@ -197,7 +204,15 @@ func (k BaseKeeper) UndelegateCoins(ctx sdk.Context, moduleAccAddr, delegatorAdd
return sdkerrors.Wrap(sdkerrors.ErrInvalidCoins, amt.String())
}

err := k.subUnlockedCoins(ctx, moduleAccAddr, amt)
// call the TrackBeforeSend hooks and the BlockBeforeSend hooks
err := k.BlockBeforeSend(ctx, moduleAccAddr, delegatorAddr, amt)
if err != nil {
return err
}

k.TrackBeforeSend(ctx, moduleAccAddr, delegatorAddr, amt)

err = k.subUnlockedCoins(ctx, moduleAccAddr, amt)
if err != nil {
return err
}
Expand Down Expand Up @@ -343,7 +358,7 @@ func (k BaseKeeper) SendCoinsFromModuleToModule(
panic(sdkerrors.Wrapf(sdkerrors.ErrUnknownAddress, "module account %s does not exist", recipientModule))
}

return k.SendCoins(ctx, senderAddr, recipientAcc.GetAddress(), amt)
return k.SendCoinsWithoutBlockHook(ctx, senderAddr, recipientAcc.GetAddress(), amt)
}

// SendCoinsFromAccountToModule transfers coins from an AccAddress to a ModuleAccount.
Expand Down
32 changes: 32 additions & 0 deletions x/bank/keeper/send.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type BaseSendKeeper struct {
cdc codec.BinaryCodec
ak types.AccountKeeper
storeKey storetypes.StoreKey
hooks types.BankHooks

// list of addresses that are restricted from receiving transactions
blockedAddrs map[string]bool
Expand Down Expand Up @@ -83,6 +84,17 @@ func NewBaseSendKeeper(
}
}

// Set the bank hooks
func (k *BaseSendKeeper) SetHooks(bh types.BankHooks) *BaseSendKeeper {
if k.hooks != nil {
panic("cannot set bank hooks twice")
}

k.hooks = bh

return k
}

// GetAuthority returns the x/bank module's authority.
func (k BaseSendKeeper) GetAuthority() string {
return k.authority
Expand Down Expand Up @@ -187,9 +199,29 @@ func (k BaseSendKeeper) InputOutputCoins(ctx sdk.Context, inputs []types.Input,
return nil
}

// SendCoinsWithoutBlockHook calls sendCoins without calling the `BlockBeforeSend` hook.
func (k BaseSendKeeper) SendCoinsWithoutBlockHook(ctx sdk.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error {
return k.sendCoins(ctx, fromAddr, toAddr, amt)
}

// SendCoins transfers amt coins from a sending account to a receiving account.
// An error is returned upon failure.
func (k BaseSendKeeper) SendCoins(ctx sdk.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error {
// BlockBeforeSend hook should always be called before the TrackBeforeSend hook.
err := k.BlockBeforeSend(ctx, fromAddr, toAddr, amt)
if err != nil {
return err
}

return k.sendCoins(ctx, fromAddr, toAddr, amt)
}

// SendCoins transfers amt coins from a sending account to a receiving account.
// An error is returned upon failure.
func (k BaseSendKeeper) sendCoins(ctx sdk.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error {
// call the TrackBeforeSend hooks
k.TrackBeforeSend(ctx, fromAddr, toAddr, amt)

err := k.subUnlockedCoins(ctx, fromAddr, amt)
if err != nil {
return err
Expand Down
12 changes: 12 additions & 0 deletions x/bank/types/expected_keepers.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,15 @@ type AccountKeeper interface {
SetModuleAccount(ctx sdk.Context, macc types.ModuleAccountI)
GetModulePermissions() map[string]types.PermissionsForAddress
}

// Event Hooks
// These can be utilized to communicate between a bank keeper and another
// keeper which must take particular actions when sends happen.
// The second keeper must implement this interface, which then the
// bank keeper can call.

// BankHooks event hooks for bank sends
type BankHooks interface {
TrackBeforeSend(ctx sdk.Context, from, to sdk.AccAddress, amount sdk.Coins) // Must be before any send is executed
BlockBeforeSend(ctx sdk.Context, from, to sdk.AccAddress, amount sdk.Coins) error // Must be before any send is executed
}
31 changes: 31 additions & 0 deletions x/bank/types/hooks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package types

import (
sdk "github.com/cosmos/cosmos-sdk/types"
)

// MultiBankHooks combine multiple bank hooks, all hook functions are run in array sequence
type MultiBankHooks []BankHooks

// NewMultiBankHooks takes a list of BankHooks and returns a MultiBankHooks
func NewMultiBankHooks(hooks ...BankHooks) MultiBankHooks {
return hooks
}

// TrackBeforeSend runs the TrackBeforeSend hooks in order for each BankHook in a MultiBankHooks struct
func (h MultiBankHooks) TrackBeforeSend(ctx sdk.Context, from, to sdk.AccAddress, amount sdk.Coins) {
for i := range h {
h[i].TrackBeforeSend(ctx, from, to, amount)
}
}

// BlockBeforeSend runs the BlockBeforeSend hooks in order for each BankHook in a MultiBankHooks struct
func (h MultiBankHooks) BlockBeforeSend(ctx sdk.Context, from, to sdk.AccAddress, amount sdk.Coins) error {
for i := range h {
err := h[i].BlockBeforeSend(ctx, from, to, amount)
if err != nil {
return err
}
}
return nil
}
Loading