diff --git a/x/capability/genesis.go b/x/capability/genesis.go index 07bf6df7ef68..9bcca42a35f0 100644 --- a/x/capability/genesis.go +++ b/x/capability/genesis.go @@ -9,7 +9,9 @@ import ( // InitGenesis initializes the capability module's state from a provided genesis // state. func InitGenesis(ctx sdk.Context, k keeper.Keeper, genState types.GenesisState) { - k.SetIndex(ctx, genState.Index) + if err := k.SetIndex(ctx, genState.Index); err != nil { + panic(err) + } // set owners for each index and initialize capability for _, genOwner := range genState.Owners { diff --git a/x/capability/keeper/keeper.go b/x/capability/keeper/keeper.go index cb4d31df0bcf..492d56c91591 100644 --- a/x/capability/keeper/keeper.go +++ b/x/capability/keeper/keeper.go @@ -117,12 +117,22 @@ func (k *Keeper) InitializeAndSeal(ctx sdk.Context) { k.sealed = true } -// SetIndex sets the index to one in InitChain -// Since it is an exported function, we check that index is indeed unset, before initializing -func (k Keeper) SetIndex(ctx sdk.Context, index uint64) { +// SetIndex sets the index to one (or greater) in InitChain according +// to the GenesisState. It must only be called once. +// It will panic if the provided index is 0, or if the index is already set. +func (k Keeper) SetIndex(ctx sdk.Context, index uint64) error { + if index == 0 { + panic("SetIndex requires index > 0") + } + latest := k.GetLatestIndex(ctx) + if latest > 0 { + panic("SetIndex requires index to not be set") + } + // set the global index to the passed index store := ctx.KVStore(k.storeKey) store.Set(types.KeyIndex, types.IndexToKey(index)) + return nil } // GetLatestIndex returns the latest index of the CapabilityKeeper @@ -156,7 +166,8 @@ func (k Keeper) GetOwners(ctx sdk.Context, index uint64) (types.CapabilityOwners } // InitializeCapability takes in an index and an owners array. It creates the capability in memory -// and sets the fwd and reverse keys for each owner in the memstore +// and sets the fwd and reverse keys for each owner in the memstore. +// It is used during initialization from genesis. func (k Keeper) InitializeCapability(ctx sdk.Context, index uint64, owners types.CapabilityOwners) { memStore := ctx.KVStore(k.memKey) @@ -278,14 +289,12 @@ func (sk ScopedKeeper) ReleaseCapability(ctx sdk.Context, cap *types.Capability) memStore := ctx.KVStore(sk.memKey) - // Set the forward mapping between the module and capability tuple and the + // Delete the forward mapping between the module and capability tuple and the // capability name in the memKVStore memStore.Delete(types.FwdCapabilityKey(sk.module, cap)) - // Set the reverse mapping between the module and capability name and the - // index in the in-memory store. Since marshalling and unmarshalling into a store - // will change memory address of capability, we simply store index as value here - // and retrieve the in-memory pointer to the capability from our map + // Delete the reverse mapping between the module and capability name and the + // index in the in-memory store. memStore.Delete(types.RevCapabilityKey(sk.module, name)) // remove owner @@ -298,7 +307,7 @@ func (sk ScopedKeeper) ReleaseCapability(ctx sdk.Context, cap *types.Capability) if len(capOwners.Owners) == 0 { // remove capability owner set prefixStore.Delete(indexKey) - // since no one ones capability, we can delete capability from map + // since no one owns capability, we can delete capability from map delete(sk.capMap, cap.GetIndex()) } else { // update capability owner set @@ -370,7 +379,7 @@ func (sk ScopedKeeper) GetOwners(ctx sdk.Context, name string) (*types.Capabilit // LookupModules returns all the module owners for a given capability // as a string array and the capability itself. -// The method returns an errors if either the capability or the owners cannot be +// The method returns an error if either the capability or the owners cannot be // retreived from the memstore. func (sk ScopedKeeper) LookupModules(ctx sdk.Context, name string) ([]string, *types.Capability, error) { cap, ok := sk.GetCapability(ctx, name) @@ -413,16 +422,13 @@ func (sk ScopedKeeper) getOwners(ctx sdk.Context, cap *types.Capability) *types. bz := prefixStore.Get(indexKey) - var owners *types.CapabilityOwners if len(bz) == 0 { - owners = types.NewCapabilityOwners() - } else { - var capOwners types.CapabilityOwners - sk.cdc.MustUnmarshalBinaryBare(bz, &capOwners) - owners = &capOwners + return types.NewCapabilityOwners() } - return owners + var capOwners types.CapabilityOwners + sk.cdc.MustUnmarshalBinaryBare(bz, &capOwners) + return &capOwners } func logger(ctx sdk.Context) log.Logger { diff --git a/x/capability/keeper/keeper_test.go b/x/capability/keeper/keeper_test.go index fdd795341010..c6f1d174bdb9 100644 --- a/x/capability/keeper/keeper_test.go +++ b/x/capability/keeper/keeper_test.go @@ -75,11 +75,15 @@ func (suite *KeeperTestSuite) TestInitializeAndSeal() { func (suite *KeeperTestSuite) TestNewCapability() { sk := suite.keeper.ScopeToModule(banktypes.ModuleName) + got, ok := sk.GetCapability(suite.ctx, "transfer") + suite.Require().False(ok) + suite.Require().Nil(got) + cap, err := sk.NewCapability(suite.ctx, "transfer") suite.Require().NoError(err) suite.Require().NotNil(cap) - got, ok := sk.GetCapability(suite.ctx, "transfer") + got, ok = sk.GetCapability(suite.ctx, "transfer") suite.Require().True(ok) suite.Require().Equal(cap, got) suite.Require().True(cap == got, "expected memory addresses to be equal") @@ -88,9 +92,19 @@ func (suite *KeeperTestSuite) TestNewCapability() { suite.Require().False(ok) suite.Require().Nil(got) - cap, err = sk.NewCapability(suite.ctx, "transfer") + got, ok = sk.GetCapability(suite.ctx, "transfer") + suite.Require().True(ok) + suite.Require().Equal(cap, got) + suite.Require().True(cap == got, "expected memory addresses to be equal") + + cap2, err := sk.NewCapability(suite.ctx, "transfer") suite.Require().Error(err) - suite.Require().Nil(cap) + suite.Require().Nil(cap2) + + got, ok = sk.GetCapability(suite.ctx, "transfer") + suite.Require().True(ok) + suite.Require().Equal(cap, got) + suite.Require().True(cap == got, "expected memory addresses to be equal") } func (suite *KeeperTestSuite) TestOriginalCapabilityKeeper() { @@ -111,7 +125,8 @@ func (suite *KeeperTestSuite) TestAuthenticateCapability() { suite.Require().NoError(err) suite.Require().NotNil(cap1) - forgedCap := types.NewCapability(0) // index should be the same index as the first capability + forgedCap := types.NewCapability(cap1.Index) // index should be the same index as the first capability + suite.Require().False(sk1.AuthenticateCapability(suite.ctx, forgedCap, "transfer")) suite.Require().False(sk2.AuthenticateCapability(suite.ctx, forgedCap, "transfer")) cap2, err := sk2.NewCapability(suite.ctx, "bond") @@ -176,14 +191,15 @@ func (suite *KeeperTestSuite) TestGetOwners() { // Ensure all scoped keepers can get owners for _, sk := range sks { owners, ok := sk.GetOwners(suite.ctx, "transfer") - mods, cap, err := sk.LookupModules(suite.ctx, "transfer") + mods, gotCap, err := sk.LookupModules(suite.ctx, "transfer") suite.Require().True(ok, "could not retrieve owners") suite.Require().NotNil(owners, "owners is nil") suite.Require().NoError(err, "could not retrieve modules") - suite.Require().NotNil(cap, "capability is nil") + suite.Require().NotNil(gotCap, "capability is nil") suite.Require().NotNil(mods, "modules is nil") + suite.Require().Equal(cap, gotCap, "caps not equal") suite.Require().Equal(len(expectedOrder), len(owners.Owners), "length of owners is unexpected") for i, o := range owners.Owners {