diff --git a/modules/apps/27-interchain-accounts/keeper/handshake.go b/modules/apps/27-interchain-accounts/keeper/handshake.go index 6b929f547d9..607f9524256 100644 --- a/modules/apps/27-interchain-accounts/keeper/handshake.go +++ b/modules/apps/27-interchain-accounts/keeper/handshake.go @@ -6,6 +6,7 @@ import ( capabilitytypes "github.com/cosmos/cosmos-sdk/x/capability/types" "github.com/cosmos/ibc-go/v2/modules/apps/27-interchain-accounts/types" + connectiontypes "github.com/cosmos/ibc-go/v2/modules/core/03-connection/types" channeltypes "github.com/cosmos/ibc-go/v2/modules/core/04-channel/types" porttypes "github.com/cosmos/ibc-go/v2/modules/core/05-port/types" host "github.com/cosmos/ibc-go/v2/modules/core/24-host" @@ -31,10 +32,25 @@ func (k Keeper) OnChanOpenInit( version string, ) error { if order != channeltypes.ORDERED { - return sdkerrors.Wrapf(channeltypes.ErrInvalidChannelOrdering, "invalid channel ordering: %s, expected %s", order.String(), channeltypes.ORDERED.String()) + return sdkerrors.Wrapf(channeltypes.ErrInvalidChannelOrdering, "expected %s channel, got %s", channeltypes.ORDERED, order) } + + connSequence, err := types.ParseControllerConnSequence(portID) + if err != nil { + return sdkerrors.Wrapf(err, "expected format %s, got %s", types.ControllerPortFormat, portID) + } + + counterpartyConnSequence, err := types.ParseHostConnSequence(portID) + if err != nil { + return sdkerrors.Wrapf(err, "expected format %s, got %s", types.ControllerPortFormat, portID) + } + + if err := k.validateControllerPortParams(ctx, channelID, portID, connSequence, counterpartyConnSequence); err != nil { + return sdkerrors.Wrapf(err, "failed to validate controller port %s", portID) + } + if counterparty.PortId != types.PortID { - return sdkerrors.Wrapf(porttypes.ErrInvalidPort, "counterparty port-id must be '%s', (%s != %s)", types.PortID, counterparty.PortId, types.PortID) + return sdkerrors.Wrapf(porttypes.ErrInvalidPort, "expected %s, got %s", types.PortID, counterparty.PortId) } if err := types.ValidateVersion(version); err != nil { @@ -43,7 +59,7 @@ func (k Keeper) OnChanOpenInit( existingChannelID, found := k.GetActiveChannel(ctx, portID) if found { - return sdkerrors.Wrapf(porttypes.ErrInvalidPort, "existing active channel (%s) for portID (%s)", existingChannelID, portID) + return sdkerrors.Wrapf(porttypes.ErrInvalidPort, "existing active channel %s for portID %s", existingChannelID, portID) } // Claim channel capability passed back by IBC module @@ -70,7 +86,25 @@ func (k Keeper) OnChanOpenTry( counterpartyVersion string, ) error { if order != channeltypes.ORDERED { - return sdkerrors.Wrapf(channeltypes.ErrInvalidChannelOrdering, "invalid channel ordering: %s, expected %s", order.String(), channeltypes.ORDERED.String()) + return sdkerrors.Wrapf(channeltypes.ErrInvalidChannelOrdering, "expected %s channel, got %s", channeltypes.ORDERED, order) + } + + if portID != types.PortID { + return sdkerrors.Wrapf(porttypes.ErrInvalidPort, "expected %s, got %s", types.PortID, portID) + } + + connSequence, err := types.ParseHostConnSequence(counterparty.PortId) + if err != nil { + return sdkerrors.Wrapf(err, "expected format %s, got %s", types.ControllerPortFormat, counterparty.PortId) + } + + counterpartyConnSequence, err := types.ParseControllerConnSequence(counterparty.PortId) + if err != nil { + return sdkerrors.Wrapf(err, "expected format %s, got %s", types.ControllerPortFormat, counterparty.PortId) + } + + if err := k.validateControllerPortParams(ctx, channelID, portID, connSequence, counterpartyConnSequence); err != nil { + return sdkerrors.Wrapf(err, "failed to validate controller port %s", counterparty.PortId) } if err := types.ValidateVersion(version); err != nil { @@ -89,7 +123,11 @@ func (k Keeper) OnChanOpenTry( // Check to ensure that the version string contains the expected address generated from the Counterparty portID accAddr := types.GenerateAddress(k.accountKeeper.GetModuleAddress(types.ModuleName), counterparty.PortId) - parsedAddr := types.ParseAddressFromVersion(version) + parsedAddr, err := types.ParseAddressFromVersion(version) + if err != nil { + return sdkerrors.Wrapf(err, "expected format , got %s", types.Delimiter, version) + } + if parsedAddr != accAddr.String() { return sdkerrors.Wrapf(types.ErrInvalidAccountAddress, "version contains invalid account address: expected %s, got %s", parsedAddr, accAddr) } @@ -116,7 +154,11 @@ func (k Keeper) OnChanOpenAck( k.SetActiveChannel(ctx, portID, channelID) - accAddr := types.ParseAddressFromVersion(counterpartyVersion) + accAddr, err := types.ParseAddressFromVersion(counterpartyVersion) + if err != nil { + return sdkerrors.Wrapf(err, "expected format , got %s", types.Delimiter, counterpartyVersion) + } + k.SetInterchainAccountAddress(ctx, portID, accAddr) return nil @@ -130,3 +172,37 @@ func (k Keeper) OnChanOpenConfirm( ) error { return nil } + +// validateControllerPortParams asserts the provided connection sequence and counterparty connection sequence +// match that of the associated connection stored in state +func (k Keeper) validateControllerPortParams(ctx sdk.Context, channelID, portID string, connectionSeq, counterpartyConnectionSeq uint64) error { + channel, found := k.channelKeeper.GetChannel(ctx, portID, channelID) + if !found { + return sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID %s channel ID %s", portID, channelID) + } + + counterpartyHops, found := k.channelKeeper.CounterpartyHops(ctx, channel) + if !found { + return sdkerrors.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0]) + } + + connSeq, err := connectiontypes.ParseConnectionSequence(channel.ConnectionHops[0]) + if err != nil { + return sdkerrors.Wrapf(err, "failed to parse connection sequence %s", channel.ConnectionHops[0]) + } + + counterpartyConnSeq, err := connectiontypes.ParseConnectionSequence(counterpartyHops[0]) + if err != nil { + return sdkerrors.Wrapf(err, "failed to parse counterparty connection sequence %s", counterpartyHops[0]) + } + + if connSeq != connectionSeq { + return sdkerrors.Wrapf(connectiontypes.ErrInvalidConnection, "sequence mismatch, expected %d, got %d", connSeq, connectionSeq) + } + + if counterpartyConnSeq != counterpartyConnectionSeq { + return sdkerrors.Wrapf(connectiontypes.ErrInvalidConnection, "counterparty sequence mismatch, expected %d, got %d", counterpartyConnSeq, counterpartyConnectionSeq) + } + + return nil +} diff --git a/modules/apps/27-interchain-accounts/keeper/handshake_test.go b/modules/apps/27-interchain-accounts/keeper/handshake_test.go index ba2a1eea9ca..63afb203f11 100644 --- a/modules/apps/27-interchain-accounts/keeper/handshake_test.go +++ b/modules/apps/27-interchain-accounts/keeper/handshake_test.go @@ -23,33 +23,94 @@ func (suite *KeeperTestSuite) TestOnChanOpenInit() { }{ { - "success", func() {}, true, + "success", + func() { + path.EndpointA.SetChannel(*channel) + }, + true, }, { - "invalid order - UNORDERED", func() { + "invalid order - UNORDERED", + func() { channel.Ordering = channeltypes.UNORDERED - }, false, + }, + false, }, { - "invalid counterparty port ID", func() { - channel.Counterparty.PortId = ibctesting.MockPort - }, false, + "invalid port ID", + func() { + path.EndpointA.ChannelConfig.PortID = "invalid-port-id" + }, + false, }, { - "invalid version", func() { + "invalid counterparty port ID", + func() { + path.EndpointA.SetChannel(*channel) + channel.Counterparty.PortId = "invalid-port-id" + }, + false, + }, + { + "invalid version", + func() { + path.EndpointA.SetChannel(*channel) channel.Version = "version" - }, false, + }, + false, }, { - "channel is already active", func() { + "channel not found", + func() { + path.EndpointA.ChannelID = "invalid-channel-id" + }, + false, + }, + { + "connection not found", + func() { + channel.ConnectionHops = []string{"invalid-connnection-id"} + path.EndpointA.SetChannel(*channel) + }, + false, + }, + { + "invalid connection sequence", + func() { + portID, err := types.GeneratePortID(TestOwnerAddress, "connection-1", "connection-0") + suite.Require().NoError(err) + + path.EndpointA.ChannelConfig.PortID = portID + path.EndpointA.SetChannel(*channel) + }, + false, + }, + { + "invalid counterparty connection sequence", + func() { + portID, err := types.GeneratePortID(TestOwnerAddress, "connection-0", "connection-1") + suite.Require().NoError(err) + + path.EndpointA.ChannelConfig.PortID = portID + path.EndpointA.SetChannel(*channel) + }, + false, + }, + { + "channel is already active", + func() { suite.chainA.GetSimApp().ICAKeeper.SetActiveChannel(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) - }, false, + }, + false, }, { - "capability already claimed", func() { + "capability already claimed", + func() { + path.EndpointA.SetChannel(*channel) err := suite.chainA.GetSimApp().ScopedICAKeeper.ClaimCapability(suite.chainA.GetContext(), chanCap, host.ChannelCapabilityPath(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) suite.Require().NoError(err) - }, false, + }, + false, }, } @@ -97,7 +158,6 @@ func (suite *KeeperTestSuite) TestOnChanOpenInit() { } } -// ChainA is controller, ChainB is host chain func (suite *KeeperTestSuite) TestOnChanOpenTry() { var ( channel *channeltypes.Channel @@ -113,33 +173,105 @@ func (suite *KeeperTestSuite) TestOnChanOpenTry() { }{ { - "success", func() {}, true, + "success", + func() { + path.EndpointB.SetChannel(*channel) + }, + true, }, { - "invalid order - UNORDERED", func() { + "invalid order - UNORDERED", + func() { channel.Ordering = channeltypes.UNORDERED - }, false, + }, + false, + }, + { + "invalid port", + func() { + path.EndpointB.ChannelConfig.PortID = "invalid-port-id" + }, + false, + }, + { + "invalid counterparty port", + func() { + channel.Counterparty.PortId = "invalid-port-id" + }, + false, }, { - "invalid version", func() { + "channel not found", + func() { + path.EndpointB.ChannelID = "invalid-channel-id" + }, + false, + }, + { + "connection not found", + func() { + channel.ConnectionHops = []string{"invalid-connnection-id"} + path.EndpointB.SetChannel(*channel) + }, + false, + }, + { + "invalid connection sequence", + func() { + portID, err := types.GeneratePortID(TestOwnerAddress, "connection-0", "connection-1") + suite.Require().NoError(err) + + channel.Counterparty.PortId = portID + path.EndpointB.SetChannel(*channel) + }, + false, + }, + { + "invalid counterparty connection sequence", + func() { + portID, err := types.GeneratePortID(TestOwnerAddress, "connection-1", "connection-0") + suite.Require().NoError(err) + + channel.Counterparty.PortId = portID + path.EndpointB.SetChannel(*channel) + }, + false, + }, + { + "invalid version", + func() { channel.Version = "version" - }, false, + path.EndpointB.SetChannel(*channel) + }, + false, }, { - "invalid counterparty version", func() { + "invalid counterparty version", + func() { counterpartyVersion = "version" - }, false, + path.EndpointB.SetChannel(*channel) + }, + false, }, { - "capability already claimed", func() { + "capability already claimed", + func() { + path.EndpointB.SetChannel(*channel) err := suite.chainB.GetSimApp().ScopedICAKeeper.ClaimCapability(suite.chainB.GetContext(), chanCap, host.ChannelCapabilityPath(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)) suite.Require().NoError(err) - }, false, + }, + false, }, { - "invalid account address", func() { - channel.Counterparty.PortId = "invalid-port-id" - }, false, + "invalid account address", + func() { + portID, err := types.GeneratePortID("invalid-owner-addr", "connection-0", "connection-0") + suite.Require().NoError(err) + + channel.Counterparty.PortId = portID + path.EndpointB.SetChannel(*channel) + }, + false, }, } @@ -155,6 +287,10 @@ func (suite *KeeperTestSuite) TestOnChanOpenTry() { err := InitInterchainAccount(path.EndpointA, TestOwnerAddress) suite.Require().NoError(err) + // set the channel id on host + channelSequence := path.EndpointB.Chain.App.GetIBCKeeper().ChannelKeeper.GetNextChannelSequence(path.EndpointB.Chain.GetContext()) + path.EndpointB.ChannelID = channeltypes.FormatChannelIdentifier(channelSequence) + // default values counterparty := channeltypes.NewCounterparty(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) channel = &channeltypes.Channel{ diff --git a/modules/apps/27-interchain-accounts/keeper/keeper_test.go b/modules/apps/27-interchain-accounts/keeper/keeper_test.go index 1930f327bd9..35d61cb0295 100644 --- a/modules/apps/27-interchain-accounts/keeper/keeper_test.go +++ b/modules/apps/27-interchain-accounts/keeper/keeper_test.go @@ -1,7 +1,6 @@ package keeper_test import ( - "fmt" "testing" sdk "github.com/cosmos/cosmos-sdk/types" @@ -21,7 +20,7 @@ var ( // TestOwnerAddress defines a reusable bech32 address for testing purposes TestOwnerAddress = "cosmos17dtl0mjt3t77kpuhg2edqzjpszulwhgzuj9ljs" // TestPortID defines a resuable port identifier for testing purposes - TestPortID = fmt.Sprintf("%s-0-0-%s", types.VersionPrefix, TestOwnerAddress) + TestPortID, _ = types.GeneratePortID(TestOwnerAddress, "connection-0", "connection-0") // TestVersion defines a resuable interchainaccounts version string for testing purposes TestVersion = types.NewAppVersion(types.VersionPrefix, TestAccAddress.String()) ) diff --git a/modules/apps/27-interchain-accounts/module_test.go b/modules/apps/27-interchain-accounts/module_test.go index 778a676f20b..a2ff103d3eb 100644 --- a/modules/apps/27-interchain-accounts/module_test.go +++ b/modules/apps/27-interchain-accounts/module_test.go @@ -1,7 +1,6 @@ package interchain_accounts_test import ( - "fmt" "testing" sdk "github.com/cosmos/cosmos-sdk/types" @@ -21,7 +20,7 @@ var ( // TestOwnerAddress defines a reusable bech32 address for testing purposes TestOwnerAddress = "cosmos17dtl0mjt3t77kpuhg2edqzjpszulwhgzuj9ljs" // TestPortID defines a resuable port identifier for testing purposes - TestPortID = fmt.Sprintf("%s-0-0-%s", types.VersionPrefix, TestOwnerAddress) + TestPortID, _ = types.GeneratePortID(TestOwnerAddress, "connection-0", "connection-0") // TestVersion defines a resuable interchainaccounts version string for testing purposes TestVersion = types.NewAppVersion(types.VersionPrefix, TestAccAddress.String()) ) @@ -124,6 +123,9 @@ func (suite *InterchainAccountsTestSuite) TestOnChanOpenInit() { Version: types.VersionPrefix, } + // set channel + path.EndpointA.SetChannel(*channel) + module, _, err := suite.chainA.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainA.GetContext(), types.PortID) suite.Require().NoError(err) @@ -142,7 +144,6 @@ func (suite *InterchainAccountsTestSuite) TestOnChanOpenInit() { func (suite *InterchainAccountsTestSuite) TestOnChanOpenTry() { suite.SetupTest() // reset path := NewICAPath(suite.chainA, suite.chainB) - counterpartyVersion := types.VersionPrefix suite.coordinator.SetupConnections(path) err := InitInterchainAccount(path.EndpointA, TestOwnerAddress) @@ -158,24 +159,29 @@ func (suite *InterchainAccountsTestSuite) TestOnChanOpenTry() { Version: types.VersionPrefix, } - module, _, err := suite.chainA.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainA.GetContext(), types.PortID) + // set channel + channelSequence := path.EndpointB.Chain.App.GetIBCKeeper().ChannelKeeper.GetNextChannelSequence(path.EndpointB.Chain.GetContext()) + path.EndpointB.ChannelID = channeltypes.FormatChannelIdentifier(channelSequence) + path.EndpointB.SetChannel(*channel) + + module, _, err := suite.chainB.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainB.GetContext(), types.PortID) suite.Require().NoError(err) - chanCap, err := suite.chainA.App.GetScopedIBCKeeper().NewCapability(suite.chainA.GetContext(), host.ChannelCapabilityPath(ibctesting.TransferPort, path.EndpointA.ChannelID)) + chanCap, err := suite.chainB.App.GetScopedIBCKeeper().NewCapability(suite.chainB.GetContext(), host.ChannelCapabilityPath(ibctesting.TransferPort, path.EndpointB.ChannelID)) suite.Require().NoError(err) - cbs, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module) + cbs, ok := suite.chainB.App.GetIBCKeeper().Router.GetRoute(module) suite.Require().True(ok) - err = cbs.OnChanOpenTry(suite.chainA.GetContext(), channel.Ordering, channel.GetConnectionHops(), - path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, chanCap, channel.Counterparty, channel.GetVersion(), counterpartyVersion, + err = cbs.OnChanOpenTry(suite.chainB.GetContext(), channel.Ordering, channel.GetConnectionHops(), + path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, chanCap, channel.Counterparty, TestVersion, types.VersionPrefix, ) + suite.Require().NoError(err) } func (suite *InterchainAccountsTestSuite) TestOnChanOpenAck() { suite.SetupTest() // reset path := NewICAPath(suite.chainA, suite.chainB) - counterpartyVersion := types.VersionPrefix suite.coordinator.SetupConnections(path) err := InitInterchainAccount(path.EndpointA, TestOwnerAddress) @@ -190,7 +196,7 @@ func (suite *InterchainAccountsTestSuite) TestOnChanOpenAck() { cbs, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module) suite.Require().True(ok) - err = cbs.OnChanOpenAck(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, counterpartyVersion) + err = cbs.OnChanOpenAck(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, TestVersion) suite.Require().NoError(err) } diff --git a/modules/apps/27-interchain-accounts/types/account.go b/modules/apps/27-interchain-accounts/types/account.go index a9f1fdcde1a..87cd0634900 100644 --- a/modules/apps/27-interchain-accounts/types/account.go +++ b/modules/apps/27-interchain-accounts/types/account.go @@ -21,11 +21,6 @@ func GenerateAddress(moduleAccAddr sdk.AccAddress, portID string) sdk.AccAddress return sdk.AccAddress(sdkaddress.Derive(moduleAccAddr, []byte(portID))) } -// ParseAddressFromVersion trims the interchainaccounts version prefix and returns the associated account address -func ParseAddressFromVersion(version string) string { - return strings.TrimPrefix(version, fmt.Sprint(VersionPrefix, Delimiter)) -} - // GeneratePortID generates the portID for a specific owner // on the controller chain in the format: // @@ -47,7 +42,12 @@ func GeneratePortID(owner, connectionID, counterpartyConnectionID string) (strin return "", sdkerrors.Wrap(err, "invalid counterparty connection identifier") } - return fmt.Sprintf("%s-%d-%d-%s", VersionPrefix, connectionSeq, counterpartyConnectionSeq, owner), nil + return fmt.Sprint( + VersionPrefix, Delimiter, + connectionSeq, Delimiter, + counterpartyConnectionSeq, Delimiter, + owner, + ), nil } type InterchainAccountI interface { diff --git a/modules/apps/27-interchain-accounts/types/account_test.go b/modules/apps/27-interchain-accounts/types/account_test.go index acca53e39c4..751bc0f0fa6 100644 --- a/modules/apps/27-interchain-accounts/types/account_test.go +++ b/modules/apps/27-interchain-accounts/types/account_test.go @@ -17,6 +17,8 @@ import ( var ( // TestOwnerAddress defines a reusable bech32 address for testing purposes TestOwnerAddress = "cosmos17dtl0mjt3t77kpuhg2edqzjpszulwhgzuj9ljs" + // TestPortID defines a resuable port identifier for testing purposes + TestPortID, _ = types.GeneratePortID(TestOwnerAddress, "connection-0", "connection-0") ) type TypesTestSuite struct { @@ -47,13 +49,6 @@ func (suite *TypesTestSuite) TestGenerateAddress() { suite.Require().NotEmpty(accAddr) } -func (suite *TypesTestSuite) TestParseAddressFromVersion() { - version := types.NewAppVersion(types.VersionPrefix, TestOwnerAddress) - - addr := types.ParseAddressFromVersion(version) - suite.Require().Equal(TestOwnerAddress, addr) -} - func (suite *TypesTestSuite) TestGeneratePortID() { var ( path *ibctesting.Path @@ -69,7 +64,7 @@ func (suite *TypesTestSuite) TestGeneratePortID() { { "success", func() {}, - fmt.Sprintf("%s-0-0-%s", types.VersionPrefix, TestOwnerAddress), + fmt.Sprint(types.VersionPrefix, types.Delimiter, "0", types.Delimiter, "0", types.Delimiter, TestOwnerAddress), true, }, { @@ -77,7 +72,7 @@ func (suite *TypesTestSuite) TestGeneratePortID() { func() { path.EndpointA.ConnectionID = "connection-1" }, - fmt.Sprintf("%s-1-0-%s", types.VersionPrefix, TestOwnerAddress), + fmt.Sprint(types.VersionPrefix, types.Delimiter, "1", types.Delimiter, "0", types.Delimiter, TestOwnerAddress), true, }, { diff --git a/modules/apps/27-interchain-accounts/types/expected_keepers.go b/modules/apps/27-interchain-accounts/types/expected_keepers.go index 8b355e504a0..73a9b2ffe85 100644 --- a/modules/apps/27-interchain-accounts/types/expected_keepers.go +++ b/modules/apps/27-interchain-accounts/types/expected_keepers.go @@ -31,6 +31,7 @@ type ChannelKeeper interface { SendPacket(ctx sdk.Context, channelCap *capabilitytypes.Capability, packet ibcexported.PacketI) error ChanCloseInit(ctx sdk.Context, portID, channelID string, chanCap *capabilitytypes.Capability) error ChanOpenInit(ctx sdk.Context, order channeltypes.Order, connectionHops []string, portID string, portCap *capabilitytypes.Capability, counterparty channeltypes.Counterparty, version string) (string, *capabilitytypes.Capability, error) + CounterpartyHops(ctx sdk.Context, channel channeltypes.Channel) ([]string, bool) } // ClientKeeper defines the expected IBC client keeper diff --git a/modules/apps/27-interchain-accounts/types/keys.go b/modules/apps/27-interchain-accounts/types/keys.go index 8640226191f..d8e3e1c1ec4 100644 --- a/modules/apps/27-interchain-accounts/types/keys.go +++ b/modules/apps/27-interchain-accounts/types/keys.go @@ -2,6 +2,12 @@ package types import ( "fmt" + "strconv" + "strings" + + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + + porttypes "github.com/cosmos/ibc-go/v2/modules/core/05-port/types" ) const ( @@ -24,7 +30,11 @@ const ( QuerierRoute = ModuleName // Delimiter is the delimiter used for the interchain accounts version string - Delimiter = "|" + Delimiter = "." + + // ControllerPortFormat is the expected port identifier format to which controller chains must conform + // See (TODO: Link to spec when updated) + ControllerPortFormat = "..." ) var ( @@ -46,3 +56,45 @@ func KeyActiveChannel(portID string) []byte { func KeyOwnerAccount(portID string) []byte { return []byte(fmt.Sprintf("owner/%s", portID)) } + +// ParseControllerConnSequence attempts to parse the controller connection sequence from the provided port identifier +// The port identifier must match the controller chain format outlined in (TODO: link spec), otherwise an empty string is returned +func ParseControllerConnSequence(portID string) (uint64, error) { + s := strings.Split(portID, Delimiter) + if len(s) != 4 { + return 0, sdkerrors.Wrap(porttypes.ErrInvalidPort, "failed to parse port identifier") + } + + seq, err := strconv.ParseUint(s[1], 10, 64) + if err != nil { + return 0, sdkerrors.Wrapf(err, "failed to parse connection sequence (%s)", s[1]) + } + + return seq, nil +} + +// ParseHostConnSequence attempts to parse the host connection sequence from the provided port identifier +// The port identifier must match the controller chain format outlined in (TODO: link spec), otherwise an empty string is returned +func ParseHostConnSequence(portID string) (uint64, error) { + s := strings.Split(portID, Delimiter) + if len(s) != 4 { + return 0, sdkerrors.Wrap(porttypes.ErrInvalidPort, "failed to parse port identifier") + } + + seq, err := strconv.ParseUint(s[2], 10, 64) + if err != nil { + return 0, sdkerrors.Wrapf(err, "failed to parse connection sequence (%s)", s[2]) + } + + return seq, nil +} + +// ParseAddressFromVersion attempts to extract the associated account address from the provided version string +func ParseAddressFromVersion(version string) (string, error) { + s := strings.Split(version, Delimiter) + if len(s) != 2 { + return "", sdkerrors.Wrap(ErrInvalidVersion, "failed to parse version") + } + + return s[1], nil +} diff --git a/modules/apps/27-interchain-accounts/types/keys_test.go b/modules/apps/27-interchain-accounts/types/keys_test.go index 037061a3d3e..0bb4c7cea84 100644 --- a/modules/apps/27-interchain-accounts/types/keys_test.go +++ b/modules/apps/27-interchain-accounts/types/keys_test.go @@ -1,6 +1,8 @@ package types_test import ( + "fmt" + "github.com/cosmos/ibc-go/v2/modules/apps/27-interchain-accounts/types" ) @@ -13,3 +15,132 @@ func (suite *TypesTestSuite) TestKeyOwnerAccount() { key := types.KeyOwnerAccount("port-id") suite.Require().Equal("owner/port-id", string(key)) } + +func (suite *TypesTestSuite) TestParseControllerConnSequence() { + + testCases := []struct { + name string + portID string + expValue uint64 + expPass bool + }{ + { + "success", + TestPortID, + 0, + true, + }, + { + "failed to parse port identifier", + "invalid-port-id", + 0, + false, + }, + { + "failed to parse connection sequence", + "ics27-1.x.y.cosmos1", + 0, + false, + }, + } + + for _, tc := range testCases { + suite.Run(tc.name, func() { + connSeq, err := types.ParseControllerConnSequence(tc.portID) + + if tc.expPass { + suite.Require().Equal(tc.expValue, connSeq) + suite.Require().NoError(err, tc.name) + } else { + suite.Require().Zero(connSeq) + suite.Require().Error(err, tc.name) + } + }) + } +} + +func (suite *TypesTestSuite) TestParseHostConnSequence() { + + testCases := []struct { + name string + portID string + expValue uint64 + expPass bool + }{ + { + "success", + TestPortID, + 0, + true, + }, + { + "failed to parse port identifier", + "invalid-port-id", + 0, + false, + }, + { + "failed to parse connection sequence", + "ics27-1.x.y.cosmos1", + 0, + false, + }, + } + + for _, tc := range testCases { + suite.Run(tc.name, func() { + connSeq, err := types.ParseHostConnSequence(tc.portID) + + if tc.expPass { + suite.Require().Equal(tc.expValue, connSeq) + suite.Require().NoError(err, tc.name) + } else { + suite.Require().Zero(connSeq) + suite.Require().Error(err, tc.name) + } + }) + } +} + +func (suite *TypesTestSuite) TestParseAddressFromVersion() { + + testCases := []struct { + name string + version string + expValue string + expPass bool + }{ + { + "success", + types.NewAppVersion(types.VersionPrefix, TestOwnerAddress), + TestOwnerAddress, + true, + }, + { + "failed to parse address from version", + "invalid-version-string", + "", + false, + }, + { + "failure with multiple delimiters", + fmt.Sprint(types.NewAppVersion(types.VersionPrefix, TestOwnerAddress), types.Delimiter, types.NewAppVersion(types.VersionPrefix, TestOwnerAddress)), + "", + false, + }, + } + + for _, tc := range testCases { + suite.Run(tc.name, func() { + addr, err := types.ParseAddressFromVersion(tc.version) + + if tc.expPass { + suite.Require().Equal(tc.expValue, addr) + suite.Require().NoError(err, tc.name) + } else { + suite.Require().Empty(addr) + suite.Require().Error(err, tc.name) + } + }) + } +}