Skip to content

Commit

Permalink
refactor: acknowledgePacket handling flushing / flush complete state (#…
Browse files Browse the repository at this point in the history
…4412)

* refactor: handle flush complete channel state transition in acknowledgePacket

* wip: adding testcases for acknowledgePacket with flushing state

* lint: make lint-fix

* test: adding assertFn temporarily to tests to provide after test state checks

* set counterparty upgrade on write try fn
  • Loading branch information
damiannolan authored Aug 23, 2023
1 parent 5c12378 commit f1e8ae8
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 9 deletions.
2 changes: 1 addition & 1 deletion modules/core/04-channel/keeper/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ func emitChannelUpgradeTimeoutEvent(ctx sdk.Context, portID string, channelID st
func emitErrorReceiptEvent(ctx sdk.Context, portID string, channelID string, currentChannel types.Channel, upgradeFields types.UpgradeFields, err error) {
ctx.EventManager().EmitEvents(sdk.Events{
sdk.NewEvent(
types.EventTypeChannelUpgradeInit,
types.EventTypeChannelUpgradeInit, // TODO(bug): use correct const value
sdk.NewAttribute(types.AttributeKeyPortID, portID),
sdk.NewAttribute(types.AttributeKeyChannelID, channelID),
sdk.NewAttribute(types.AttributeCounterpartyPortID, currentChannel.Counterparty.PortId),
Expand Down
30 changes: 25 additions & 5 deletions modules/core/04-channel/keeper/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,11 +468,6 @@ func (k Keeper) AcknowledgePacket(
// Delete packet commitment, since the packet has been acknowledged, the commitement is no longer necessary
k.deletePacketCommitment(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), packet.GetSequence())

if channel.FlushStatus == types.FLUSHING && !k.HasInflightPackets(ctx, packet.GetSourcePort(), packet.GetSourceChannel()) {
channel.FlushStatus = types.FLUSHCOMPLETE
k.SetChannel(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), channel)
}

// log that a packet has been acknowledged
k.Logger(ctx).Info(
"packet acknowledged",
Expand All @@ -486,5 +481,30 @@ func (k Keeper) AcknowledgePacket(
// emit an event marking that we have processed the acknowledgement
emitAcknowledgePacketEvent(ctx, packet, channel)

// if an upgrade is in progress, handling packet flushing and update channel state appropriately
if channel.State == types.STATE_FLUSHING {
counterpartyUpgrade, found := k.GetCounterpartyUpgrade(ctx, packet.GetSourcePort(), packet.GetSourceChannel())
if !found {
return errorsmod.Wrapf(types.ErrUpgradeNotFound, "counterparty upgrade not found for channel: %s", packet.GetSourceChannel())
}

timeout := counterpartyUpgrade.Timeout
// if the timeout is valid then use it, otherwise it has not been set in the upgrade handshake yet.
if timeout.IsValid() {
if hasPassed, err := timeout.HasPassed(ctx); hasPassed {
// packet flushing timeout has expired, abort the upgrade and return nil,
// committing an error receipt to state, restoring the channel and successfully acknowledging the packet.
k.MustAbortUpgrade(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), err)
return nil
}

// set the channel state to flush complete if all packets have been acknowledged/flushed.
if !k.HasInflightPackets(ctx, packet.GetSourcePort(), packet.GetSourceChannel()) {
channel.State = types.STATE_FLUSHCOMPLETE
k.SetChannel(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), channel)
}
}
}

return nil
}
89 changes: 89 additions & 0 deletions modules/core/04-channel/keeper/packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,7 @@ func (suite *KeeperTestSuite) TestAcknowledgePacket() {
packet types.Packet
ack = ibcmock.MockAcknowledgement

assertFn func()
channelCap *capabilitytypes.Capability
expError *errorsmod.Error
)
Expand Down Expand Up @@ -784,6 +785,90 @@ func (suite *KeeperTestSuite) TestAcknowledgePacket() {
channel.State = types.STATE_FLUSHING

path.EndpointA.SetChannel(channel)

counterpartyUpgrade := types.Upgrade{
Timeout: types.NewTimeout(clienttypes.ZeroHeight(), 0),
}

path.EndpointA.SetChannelCounterpartyUpgrade(counterpartyUpgrade)

assertFn = func() {
channel := path.EndpointA.GetChannel()
suite.Require().Equal(types.STATE_FLUSHING, channel.State)
}
}, true},
{"success on channel in flushing state with valid timeout", func() {
// setup uses an UNORDERED channel
suite.coordinator.Setup(path)

// create packet commitment
sequence, err := path.EndpointA.SendPacket(defaultTimeoutHeight, disabledTimeoutTimestamp, ibctesting.MockPacketData)
suite.Require().NoError(err)

// create packet receipt and acknowledgement
packet = types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, defaultTimeoutHeight, disabledTimeoutTimestamp)
err = path.EndpointB.RecvPacket(packet)
suite.Require().NoError(err)

channelCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)

channel := path.EndpointA.GetChannel()
channel.State = types.STATE_FLUSHING

path.EndpointA.SetChannel(channel)

counterpartyUpgrade := types.Upgrade{
Timeout: types.NewTimeout(suite.chainB.GetTimeoutHeight(), 0),
}

path.EndpointA.SetChannelCounterpartyUpgrade(counterpartyUpgrade)

assertFn = func() {
channel := path.EndpointA.GetChannel()
suite.Require().Equal(types.STATE_FLUSHCOMPLETE, channel.State)
}
}, true},
{"success on channel in flushing state with timeout passed", func() {
// setup uses an UNORDERED channel
suite.coordinator.Setup(path)

// create packet commitment
sequence, err := path.EndpointA.SendPacket(defaultTimeoutHeight, disabledTimeoutTimestamp, ibctesting.MockPacketData)
suite.Require().NoError(err)

// create packet receipt and acknowledgement
packet = types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, defaultTimeoutHeight, disabledTimeoutTimestamp)
err = path.EndpointB.RecvPacket(packet)
suite.Require().NoError(err)

channelCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)

channel := path.EndpointA.GetChannel()
channel.State = types.STATE_FLUSHING

path.EndpointA.SetChannel(channel)

upgrade := types.Upgrade{
Fields: types.NewUpgradeFields(types.UNORDERED, []string{ibctesting.FirstConnectionID}, ibcmock.UpgradeVersion),
Timeout: types.NewTimeout(clienttypes.ZeroHeight(), 1),
}

counterpartyUpgrade := types.Upgrade{
Fields: types.NewUpgradeFields(types.UNORDERED, []string{ibctesting.FirstConnectionID}, ibcmock.UpgradeVersion),
Timeout: types.NewTimeout(clienttypes.ZeroHeight(), 1),
}

path.EndpointA.SetChannelUpgrade(upgrade)
path.EndpointA.SetChannelCounterpartyUpgrade(counterpartyUpgrade)

assertFn = func() {
channel := path.EndpointA.GetChannel()
suite.Require().Equal(types.OPEN, channel.State)

errorReceipt, found := path.EndpointA.Chain.App.GetIBCKeeper().ChannelKeeper.GetUpgradeErrorReceipt(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.Require().True(found)
suite.Require().NotEmpty(errorReceipt)
}
}, true},
{"packet already acknowledged ordered channel (no-op)", func() {
expError = types.ErrNoOpMsg
Expand Down Expand Up @@ -1040,6 +1125,10 @@ func (suite *KeeperTestSuite) TestAcknowledgePacket() {
} else {
suite.Require().Equal(uint64(1), sequenceAck, "sequence incremented for UNORDERED channel")
}

if assertFn != nil {
assertFn()
}
} else {
suite.Error(err)
// only check if expError is set, since not all error codes can be known
Expand Down
5 changes: 4 additions & 1 deletion modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ func (k Keeper) ChanUpgradeTry(

// WriteUpgradeTryChannel writes the channel end and upgrade to state after successfully passing the UpgradeTry handshake step.
// An event is emitted for the handshake step.
func (k Keeper) WriteUpgradeTryChannel(ctx sdk.Context, portID, channelID string, upgrade types.Upgrade, upgradeVersion string, counterpartyLastSequenceSend uint64) (types.Channel, types.Upgrade) {
func (k Keeper) WriteUpgradeTryChannel(ctx sdk.Context, portID, channelID string, upgrade types.Upgrade, upgradeVersion string, counterpartyUpgradeFields types.UpgradeFields) (types.Channel, types.Upgrade) {
defer telemetry.IncrCounter(1, "ibc", "channel", "upgrade-try")

channel, found := k.GetChannel(ctx, portID, channelID)
Expand All @@ -197,6 +197,9 @@ func (k Keeper) WriteUpgradeTryChannel(ctx sdk.Context, portID, channelID string
upgrade.Fields.Version = upgradeVersion
k.SetUpgrade(ctx, portID, channelID, upgrade)

counterpartyUpgrade := types.Upgrade{Fields: counterpartyUpgradeFields}
k.SetCounterpartyUpgrade(ctx, portID, channelID, counterpartyUpgrade)

k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", types.OPEN, "new-state", channel.State)
emitChannelUpgradeTryEvent(ctx, portID, channelID, channel, upgrade)

Expand Down
3 changes: 2 additions & 1 deletion modules/core/04-channel/keeper/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ func (suite *KeeperTestSuite) TestWriteUpgradeTry() {
path = ibctesting.NewPath(suite.chainA, suite.chainB)
suite.coordinator.Setup(path)

path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion
path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion
proposedUpgrade = path.EndpointB.GetProposedUpgrade()

Expand All @@ -361,7 +362,7 @@ func (suite *KeeperTestSuite) TestWriteUpgradeTry() {
path.EndpointB.ChannelID,
proposedUpgrade,
proposedUpgrade.Fields.Version,
proposedUpgrade.LatestSequenceSend,
path.EndpointA.GetProposedUpgrade().Fields,
)

channel := path.EndpointB.GetChannel()
Expand Down
2 changes: 1 addition & 1 deletion modules/core/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ func (k Keeper) ChannelUpgradeTry(goCtx context.Context, msg *channeltypes.MsgCh
return nil, err
}

channel, upgrade := k.ChannelKeeper.WriteUpgradeTryChannel(ctx, msg.PortId, msg.ChannelId, upgrade, upgradeVersion, upgrade.LatestSequenceSend)
channel, upgrade := k.ChannelKeeper.WriteUpgradeTryChannel(ctx, msg.PortId, msg.ChannelId, upgrade, upgradeVersion, msg.CounterpartyUpgradeFields)

ctx.Logger().Info("channel upgrade try succeeded", "port-id", msg.PortId, "channel-id", msg.ChannelId)

Expand Down
5 changes: 5 additions & 0 deletions testing/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,11 @@ func (endpoint *Endpoint) SetChannelUpgrade(upgrade channeltypes.Upgrade) {
endpoint.Chain.App.GetIBCKeeper().ChannelKeeper.SetUpgrade(endpoint.Chain.GetContext(), endpoint.ChannelConfig.PortID, endpoint.ChannelID, upgrade)
}

// SetChannelCounterpartyUpgrade sets the channel counterparty upgrade for this endpoint.
func (endpoint *Endpoint) SetChannelCounterpartyUpgrade(upgrade channeltypes.Upgrade) {
endpoint.Chain.App.GetIBCKeeper().ChannelKeeper.SetCounterpartyUpgrade(endpoint.Chain.GetContext(), endpoint.ChannelConfig.PortID, endpoint.ChannelID, upgrade)
}

// QueryClientStateProof performs and abci query for a client stat associated
// with this endpoint and returns the ClientState along with the proof.
func (endpoint *Endpoint) QueryClientStateProof() (exported.ClientState, []byte) {
Expand Down

0 comments on commit f1e8ae8

Please sign in to comment.