Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: restructure timeout type #5404

Merged
merged 16 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions modules/core/04-channel/keeper/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,12 @@ func (k Keeper) AcknowledgePacket(
counterpartyUpgrade, found := k.GetCounterpartyUpgrade(ctx, packet.GetSourcePort(), packet.GetSourceChannel())
if found {
timeout := counterpartyUpgrade.Timeout
if hasPassed, err := timeout.HasPassed(ctx); hasPassed {
selfHeight, selfTimestamp := clienttypes.GetSelfHeight(ctx), uint64(ctx.BlockTime().UnixNano())

if timeout.Elapsed(selfHeight, selfTimestamp) {
// packet flushing timeout has expired, abort the upgrade and return nil,
// committing an error receipt to state, restoring the channel and successfully acknowledging the packet.
k.MustAbortUpgrade(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), err)
k.MustAbortUpgrade(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), timeout.ErrTimeoutElapsed(selfHeight, selfTimestamp))
return nil
}

Expand Down
7 changes: 5 additions & 2 deletions modules/core/04-channel/keeper/timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
sdk "github.com/cosmos/cosmos-sdk/types"

capabilitytypes "github.com/cosmos/ibc-go/modules/capability/types"
clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types"
connectiontypes "github.com/cosmos/ibc-go/v8/modules/core/03-connection/types"
"github.com/cosmos/ibc-go/v8/modules/core/04-channel/types"
host "github.com/cosmos/ibc-go/v8/modules/core/24-host"
Expand Down Expand Up @@ -157,10 +158,12 @@ func (k Keeper) TimeoutExecuted(
// then we can move to flushing complete if the timeout has not passed and there are no in-flight packets
if found {
timeout := counterpartyUpgrade.Timeout
if hasPassed, err := timeout.HasPassed(ctx); hasPassed {
selfHeight, selfTimestamp := clienttypes.GetSelfHeight(ctx), uint64(ctx.BlockTime().UnixNano())

if timeout.Elapsed(selfHeight, selfTimestamp) {
// packet flushing timeout has expired, abort the upgrade and return nil,
// committing an error receipt to state, restoring the channel and successfully timing out the packet.
k.MustAbortUpgrade(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), err)
k.MustAbortUpgrade(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), timeout.ErrTimeoutElapsed(selfHeight, selfTimestamp))
} else if !k.HasInflightPackets(ctx, packet.GetSourcePort(), packet.GetSourceChannel()) {
// set the channel state to flush complete if all packets have been flushed.
channel.State = types.FLUSHCOMPLETE
Expand Down
12 changes: 8 additions & 4 deletions modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,10 @@ func (k Keeper) ChanUpgradeAck(
}

timeout := counterpartyUpgrade.Timeout
if hasPassed, err := timeout.HasPassed(ctx); hasPassed {
return types.NewUpgradeError(channel.UpgradeSequence, errorsmod.Wrap(err, "counterparty upgrade timeout has passed"))
selfHeight, selfTimestamp := clienttypes.GetSelfHeight(ctx), uint64(ctx.BlockTime().UnixNano())

if timeout.Elapsed(selfHeight, selfTimestamp) {
return types.NewUpgradeError(channel.UpgradeSequence, errorsmod.Wrap(timeout.ErrTimeoutElapsed(selfHeight, selfTimestamp), "counterparty upgrade timeout elapsed"))
}

return nil
Expand Down Expand Up @@ -409,8 +411,10 @@ func (k Keeper) ChanUpgradeConfirm(
}

timeout := counterpartyUpgrade.Timeout
if hasPassed, err := timeout.HasPassed(ctx); hasPassed {
return types.NewUpgradeError(channel.UpgradeSequence, errorsmod.Wrap(err, "counterparty upgrade timeout has passed"))
selfHeight, selfTimestamp := clienttypes.GetSelfHeight(ctx), uint64(ctx.BlockTime().UnixNano())

if timeout.Elapsed(selfHeight, selfTimestamp) {
return types.NewUpgradeError(channel.UpgradeSequence, errorsmod.Wrap(timeout.ErrTimeoutElapsed(selfHeight, selfTimestamp), "counterparty upgrade timeout elapsed"))
}

return nil
Expand Down
4 changes: 2 additions & 2 deletions modules/core/04-channel/keeper/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() {
err = path.EndpointA.UpdateClient()
suite.Require().NoError(err)
},
types.NewUpgradeError(1, types.ErrInvalidUpgrade),
types.NewUpgradeError(1, types.ErrTimeoutElapsed),
},
}

Expand Down Expand Up @@ -813,7 +813,7 @@ func (suite *KeeperTestSuite) TestChanUpgradeConfirm() {
err := path.EndpointB.UpdateClient()
suite.Require().NoError(err)
},
types.NewUpgradeError(1, types.ErrInvalidUpgrade),
types.NewUpgradeError(1, types.ErrTimeoutElapsed),
},
}

Expand Down
2 changes: 2 additions & 0 deletions modules/core/04-channel/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,6 @@ var (
ErrPendingInflightPackets = errorsmod.Register(SubModuleName, 36, "pending inflight packets exist")
ErrUpgradeTimeoutFailed = errorsmod.Register(SubModuleName, 37, "upgrade timeout failed")
ErrInvalidPruningLimit = errorsmod.Register(SubModuleName, 38, "invalid pruning limit")
ErrTimeoutNotReached = errorsmod.Register(SubModuleName, 39, "timeout not reached")
ErrTimeoutElapsed = errorsmod.Register(SubModuleName, 40, "timeout elapsed")
)
52 changes: 34 additions & 18 deletions modules/core/04-channel/types/timeout.go
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what I like about this now is that we have a generic timeout type rather than a timeout specific to channel upgrades

Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
package types

import (
"time"

errorsmod "cosmossdk.io/errors"

sdk "github.com/cosmos/cosmos-sdk/types"

clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types"
)

Expand All @@ -18,27 +14,47 @@ func NewTimeout(height clienttypes.Height, timestamp uint64) Timeout {
}
}

// IsValid returns true if either the height or timestamp is non-zero
// IsValid returns true if either the height or timestamp is non-zero.
func (t Timeout) IsValid() bool {
return !t.Height.IsZero() || t.Timestamp != 0
}

// TODO: Update after https://github.com/cosmos/ibc-go/issues/3483 has been resolved
// HasPassed returns true if the upgrade has passed the timeout height or timestamp
func (t Timeout) HasPassed(ctx sdk.Context) (bool, error) {
if !t.IsValid() {
return true, errorsmod.Wrap(ErrInvalidUpgrade, "upgrade timeout cannot be empty")
}
// Elapsed returns true if either the provided height or timestamp is past the
// respective absolute timeout values.
func (t Timeout) Elapsed(height clienttypes.Height, timestamp uint64) bool {
return t.heightElapsed(height) || t.timestampElapsed(timestamp)
}

selfHeight, timeoutHeight := clienttypes.GetSelfHeight(ctx), t.Height
if selfHeight.GTE(timeoutHeight) && timeoutHeight.GT(clienttypes.ZeroHeight()) {
return true, errorsmod.Wrapf(ErrInvalidUpgrade, "block height >= upgrade timeout height (%s >= %s)", selfHeight, timeoutHeight)
// ErrTimeoutElapsed returns a timeout elapsed error indicating which timeout value
// has elapsed.
func (t Timeout) ErrTimeoutElapsed(height clienttypes.Height, timestamp uint64) error {
if t.heightElapsed(height) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

happy to add an additional statement if folks want to indicate both height and timestamp elapsed. I just opted for this for now

return errorsmod.Wrapf(ErrTimeoutElapsed, "current height: %s, timeout height %s", height, t.Height)
}

selfTime, timeoutTimestamp := uint64(ctx.BlockTime().UnixNano()), t.Timestamp
if selfTime >= timeoutTimestamp && timeoutTimestamp > 0 {
return true, errorsmod.Wrapf(ErrInvalidUpgrade, "block timestamp >= upgrade timeout timestamp (%s >= %s)", ctx.BlockTime(), time.Unix(0, int64(timeoutTimestamp)))
return errorsmod.Wrapf(ErrTimeoutElapsed, "current timestamp: %d, timeout timestamp %d", timestamp, t.Timestamp)
}

// ErrTimeoutNotReached returns a timeout not reached error indicating which timeout value
// has not been reached.
func (t Timeout) ErrTimeoutNotReached(height clienttypes.Height, timestamp uint64) error {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be ErrTimeoutNotElapsed if folks want

// only return height information if the height is set
// t.heightElapsed() will return false when it is empty
if !t.Height.IsZero() && !t.heightElapsed(height) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this just !t.heightElapsed(height)? Currently would expand to:

!t.Height.IsZero() && !t.Height.IsZero() && height.GTE(t.Height)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No it is: !t.Height.IsZero() && !(!t.Height.IsZero() && height.GTE(t.Height))

When t.Height.IsZero(), the second statement returns !false. This caught me while writing tests as I originally just had !t.heightElapsed(height). It's quite unfortunate and also odd. When height is 0, heightElapsed returning false makes sense, but in the case of ErrTimeoutNotReached it doesn't as the height is not being used to determine the timeout

I could pull out the !IsZero() checks out of heightElapsed and timestampElapsed or I could add a comment to explain that when asserting timeout has not been reached, we must filter out the height if it is empty

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aha! indeed. These are slightly tricky to read, thanks for adding a comment here!

return errorsmod.Wrapf(ErrTimeoutNotReached, "current height: %s, timeout height %s", height, t.Height)
}

return false, nil
return errorsmod.Wrapf(ErrTimeoutNotReached, "current timestamp: %d, timeout timestamp %d", timestamp, t.Timestamp)
}

// heightElapsed returns true if the timeout height is non empty
// and the timeout height is greater than or equal to the relative height.
func (t Timeout) heightElapsed(height clienttypes.Height) bool {
return !t.Height.IsZero() && height.GTE(t.Height)
}

// timestampElapsed returns true if the timeout timestamp is non empty
// and the timeout timestamp is greater than or equal to the relative timestamp.
func (t Timeout) timestampElapsed(timestamp uint64) bool {
return t.Timestamp != 0 && timestamp >= t.Timestamp
}
179 changes: 179 additions & 0 deletions modules/core/04-channel/types/timeout_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package types_test

import (
errorsmod "cosmossdk.io/errors"

clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types"
"github.com/cosmos/ibc-go/v8/modules/core/04-channel/types"
)
Expand Down Expand Up @@ -57,3 +59,180 @@ func (suite *TypesTestSuite) TestIsValid() {
})
}
}

func (suite *TypesTestSuite) TestElapsed() {
// elapsed is expected to be true when either timeout height or timestamp
// is greater than or equal to 2
var (
height = clienttypes.NewHeight(0, 2)
timestamp = uint64(2)
)

testCases := []struct {
name string
timeout types.Timeout
expElapsed bool
}{
{
"elapsed: both timeout with height and timestamp",
types.NewTimeout(height, timestamp),
true,
},
{
"elapsed: timeout with height and zero timestamp",
types.NewTimeout(height, 0),
true,
},
{
"elapsed: timeout with timestamp and zero height",
types.NewTimeout(clienttypes.ZeroHeight(), timestamp),
true,
},
{
"elapsed: height elapsed, timestamp did not",
types.NewTimeout(height, timestamp+1),
true,
},
{
"elapsed: timestamp elapsed, height did not",
types.NewTimeout(height.Increment().(clienttypes.Height), timestamp),
true,
},
{
"elapsed: timestamp elapsed when less than current timestamp",
types.NewTimeout(clienttypes.ZeroHeight(), timestamp-1),
true,
},
{
"elapsed: height elapsed when less than current height",
types.NewTimeout(clienttypes.NewHeight(0, 1), 0),
true,
},
{
"not elapsed: invalid timeout",
types.NewTimeout(clienttypes.ZeroHeight(), 0),
false,
},
{
"not elapsed: neither height nor timeout elapsed",
types.NewTimeout(height.Increment().(clienttypes.Height), timestamp+1),
false,
},
{
"not elapsed: timeout not reached with height and zero timestamp",
types.NewTimeout(height.Increment().(clienttypes.Height), 0),
false,
},
{
"elapsed: timeout not reached with timestamp and zero height",
types.NewTimeout(clienttypes.ZeroHeight(), timestamp+1),
false,
},
}

for _, tc := range testCases {
tc := tc
suite.Run(tc.name, func() {
elapsed := tc.timeout.Elapsed(height, timestamp)
suite.Require().Equal(tc.expElapsed, elapsed)
})
}
}

func (suite *TypesTestSuite) TestErrTimeoutElapsed() {
// elapsed is expected to be true when either timeout height or timestamp
// is greater than or equal to 2
var (
height = clienttypes.NewHeight(0, 2)
timestamp = uint64(2)
)

testCases := []struct {
name string
timeout types.Timeout
expError error
}{
{
"both timeout with height and timestamp",
types.NewTimeout(height, timestamp),
errorsmod.Wrapf(types.ErrTimeoutElapsed, "current height: %s, timeout height %s", height, height),
},
{
"timeout with height and zero timestamp",
types.NewTimeout(height, 0),
errorsmod.Wrapf(types.ErrTimeoutElapsed, "current height: %s, timeout height %s", height, height),
},
{
"timeout with timestamp and zero height",
types.NewTimeout(clienttypes.ZeroHeight(), timestamp),
errorsmod.Wrapf(types.ErrTimeoutElapsed, "current timestamp: %d, timeout timestamp %d", timestamp, timestamp),
},
{
"height elapsed, timestamp did not",
types.NewTimeout(height, timestamp+1),
errorsmod.Wrapf(types.ErrTimeoutElapsed, "current height: %s, timeout height %s", height, height),
},
{
"timestamp elapsed, height did not",
types.NewTimeout(height.Increment().(clienttypes.Height), timestamp),
errorsmod.Wrapf(types.ErrTimeoutElapsed, "current timestamp: %d, timeout timestamp %d", timestamp, timestamp),
},
{
"height elapsed when less than current height",
types.NewTimeout(clienttypes.NewHeight(0, 1), 0),
errorsmod.Wrapf(types.ErrTimeoutElapsed, "current height: %s, timeout height %s", height, clienttypes.NewHeight(0, 1)),
},
{
"timestamp elapsed when less than current timestamp",
types.NewTimeout(clienttypes.ZeroHeight(), timestamp-1),
errorsmod.Wrapf(types.ErrTimeoutElapsed, "current timestamp: %d, timeout timestamp %d", timestamp, timestamp-1),
},
}

for _, tc := range testCases {
tc := tc
suite.Run(tc.name, func() {
err := tc.timeout.ErrTimeoutElapsed(height, timestamp)
suite.Require().Equal(tc.expError.Error(), err.Error())
})
}
}

func (suite *TypesTestSuite) TestErrTimeoutNotReached() {
// elapsed is expected to be true when either timeout height or timestamp
// is greater than or equal to 2
var (
height = clienttypes.NewHeight(0, 2)
timestamp = uint64(2)
)

testCases := []struct {
name string
timeout types.Timeout
expError error
}{
{
"neither timeout reached with height and timestamp",
types.NewTimeout(height.Increment().(clienttypes.Height), timestamp+1),
errorsmod.Wrapf(types.ErrTimeoutNotReached, "current height: %s, timeout height %s", height, height.Increment().(clienttypes.Height)),
},
{
"timeout not reached with height and zero timestamp",
types.NewTimeout(height.Increment().(clienttypes.Height), 0),
errorsmod.Wrapf(types.ErrTimeoutNotReached, "current height: %s, timeout height %s", height, height.Increment().(clienttypes.Height)),
},
{
"timeout not reached with timestamp and zero height",
types.NewTimeout(clienttypes.ZeroHeight(), timestamp+1),
errorsmod.Wrapf(types.ErrTimeoutNotReached, "current timestamp: %d, timeout timestamp %d", timestamp, timestamp+1),
},
}

for _, tc := range testCases {
tc := tc
suite.Run(tc.name, func() {
err := tc.timeout.ErrTimeoutNotReached(height, timestamp)
suite.Require().Equal(tc.expError.Error(), err.Error())
})
}
}
4 changes: 2 additions & 2 deletions modules/core/keeper/msg_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1326,9 +1326,9 @@ func (suite *KeeperTestSuite) TestChannelUpgradeConfirm() {
{
"core handler returns error and writes upgrade error receipt",
func() {
// force an upgrade error by modifying the counterparty channel upgrade timeout to be no longer valid
// force an upgrade error by modifying the counterparty channel upgrade timeout to be elapsed
upgrade := path.EndpointA.GetChannelUpgrade()
upgrade.Timeout = channeltypes.NewTimeout(clienttypes.ZeroHeight(), 0)
upgrade.Timeout = channeltypes.NewTimeout(clienttypes.ZeroHeight(), uint64(path.EndpointB.Chain.CurrentHeader.Time.UnixNano()))

path.EndpointA.SetChannelUpgrade(upgrade)

Expand Down
Loading