Skip to content

Commit

Permalink
Fixes staking hooks safety issues (#12578)
Browse files Browse the repository at this point in the history
  • Loading branch information
danwt authored Jul 17, 2022
1 parent 63238a3 commit 7ec4e9e
Show file tree
Hide file tree
Showing 8 changed files with 834 additions and 781 deletions.
3 changes: 2 additions & 1 deletion docs/core/proto-docs.md
Original file line number Diff line number Diff line change
Expand Up @@ -6613,7 +6613,8 @@ multiplied by exchange rate.
| `unbonding_time` | [google.protobuf.Timestamp](#google.protobuf.Timestamp) | | unbonding_time defines, if unbonding, the min time for the validator to complete unbonding. |
| `commission` | [Commission](#cosmos.staking.v1beta1.Commission) | | commission defines the commission parameters. |
| `min_self_delegation` | [string](#string) | | min_self_delegation is the validator's self declared minimum self delegation. |
| `unbonding_on_hold` | [bool](#bool) | | True if this validator's unbonding has been stopped by an external module |
| `unbonding_on_hold` | [bool](#bool) | | false iff unbonding is allowed to complete in staking EndBlock |
| `unbonding_id` | [uint64](#uint64) | | unique id, used to distinguish unbond->rebond->unbond executions |



Expand Down
5 changes: 3 additions & 2 deletions proto/cosmos/staking/v1beta1/staking.proto
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,10 @@ message Validator {
(gogoproto.customtype) = "github.com/cosmos/cosmos-sdk/types.Int",
(gogoproto.nullable) = false
];

// True if this validator's unbonding has been stopped by an external module
// false iff unbonding is allowed to complete in staking EndBlock
bool unbonding_on_hold = 12;
// unique id, used to distinguish unbond->rebond->unbond executions
uint64 unbonding_id = 13;
}

// BondStatus is the status of a validator.
Expand Down
6 changes: 3 additions & 3 deletions x/staking/keeper/delegation.go
Original file line number Diff line number Diff line change
Expand Up @@ -703,9 +703,9 @@ func (k Keeper) Unbond(
return amount, nil
}

// getBeginInfo returns the completion time and height of a redelegation, along
// with a boolean signaling if the redelegation is complete based on the source
// validator.
// getBeginInfo returns the completion time and creation height of a
// redelegation, along with a boolean signaling if the redelegation is
// complete based on the source validator.
func (k Keeper) getBeginInfo(
ctx sdk.Context, valSrcAddr sdk.ValAddress,
) (completionTime time.Time, height int64, completeNow bool) {
Expand Down
19 changes: 15 additions & 4 deletions x/staking/keeper/unbonding.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,22 +323,33 @@ func (k Keeper) redelegationEntryCanComplete(ctx sdk.Context, id uint64) (found
return true, nil
}

// WARNING: precondition:
// Safety only guaranteed if this method is called AFTER staking.EndBlock
func (k Keeper) validatorUnbondingCanComplete(ctx sdk.Context, id uint64) (found bool, err error) {
val, found := k.GetValidatorByUnbondingId(ctx, id)

if !found {
// validator can never be deleted before unbonding
// even if it is slashed to 0, so we always expect
// to find it.
return false, nil
}

if !val.IsMature(ctx.BlockTime(), ctx.BlockHeight()) {
if val.UnbondingId != id {
// validator already rebonded
return true, nil
}

if val.UnbondingTime.After(ctx.BlockTime()) {
// validator cannot have already been dequeued by EndBlock
val.UnbondingOnHold = false
k.SetValidator(ctx, val)
} else {
// If unbonding is mature complete it
val = k.UnbondingToUnbonded(ctx, val)
// Validator is mature. Unbond it.
if val.GetDelegatorShares().IsZero() {
k.RemoveValidator(ctx, val.GetOperator())
}

val = k.UnbondingToUnbonded(ctx, val)
k.DeleteUnbondingIndex(ctx, id)
}

Expand Down
31 changes: 21 additions & 10 deletions x/staking/keeper/val_state_change.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"sort"
"time"

gogotypes "github.com/gogo/protobuf/types"
abci "github.com/tendermint/tendermint/abci/types"
Expand Down Expand Up @@ -273,7 +274,7 @@ func (k Keeper) unbondedToBonded(ctx sdk.Context, validator types.Validator) (ty
// UnbondingToUnbonded switches a validator from unbonding state to unbonded state
func (k Keeper) UnbondingToUnbonded(ctx sdk.Context, validator types.Validator) types.Validator {
if !validator.IsUnbonding() {
panic(fmt.Sprintf("bad state transition unbondingToBonded, validator: %v\n", validator))
panic(fmt.Sprintf("bad state transition unbondingToUnbonded, validator: %v\n", validator))
}

return k.CompleteUnbondingValidator(ctx, validator)
Expand Down Expand Up @@ -306,15 +307,19 @@ func (k Keeper) bondValidator(ctx sdk.Context, validator types.Validator) (types
// delete the validator by power index, as the key will change
k.DeleteValidatorByPowerIndex(ctx, validator)

// delete from queue if present
k.DeleteValidatorQueue(ctx, validator, validator.UnbondingTime, validator.UnbondingHeight)

validator = validator.UpdateStatus(types.Bonded)
validator.UnbondingHeight = 0
validator.UnbondingTime = time.Time{}
validator.UnbondingOnHold = false
validator.UnbondingId = 0

// save the now bonded validator record to the two referenced stores
k.SetValidator(ctx, validator)
k.SetValidatorByPowerIndex(ctx, validator)

// delete from queue if present
k.DeleteValidatorQueue(ctx, validator)

// trigger hook
consAddr, err := validator.GetConsAddr()
if err != nil {
Expand All @@ -337,27 +342,29 @@ func (k Keeper) BeginUnbondingValidator(ctx sdk.Context, validator types.Validat
panic(fmt.Sprintf("should not already be unbonded or unbonding, validator: %v\n", validator))
}

consAddr, err := validator.GetConsAddr()
if err != nil {
return validator, err
}

validator = validator.UpdateStatus(types.Unbonding)

// set the unbonding completion time and completion height appropriately
validator.UnbondingTime = ctx.BlockHeader().Time.Add(params.UnbondingTime)
validator.UnbondingHeight = ctx.BlockHeader().Height

id := k.IncrementUnbondingId(ctx)
validator.UnbondingId = id

// save the now unbonded validator record and power index
k.SetValidator(ctx, validator)
k.SetValidatorByPowerIndex(ctx, validator)

// Adds to unbonding validator queue
k.InsertUnbondingValidatorQueue(ctx, validator)

// trigger hook
consAddr, err := validator.GetConsAddr()
if err != nil {
return validator, err
}
k.AfterValidatorBeginUnbonding(ctx, consAddr, validator.GetOperator())

id := k.IncrementUnbondingId(ctx)
k.SetValidatorByUnbondingIndex(ctx, validator, id)

k.AfterUnbondingInitiated(ctx, id)
Expand All @@ -368,6 +375,10 @@ func (k Keeper) BeginUnbondingValidator(ctx sdk.Context, validator types.Validat
// perform all the store operations for when a validator status becomes unbonded
func (k Keeper) CompleteUnbondingValidator(ctx sdk.Context, validator types.Validator) types.Validator {
validator = validator.UpdateStatus(types.Unbonded)
validator.UnbondingHeight = 0
validator.UnbondingTime = time.Time{}
validator.UnbondingOnHold = false
validator.UnbondingId = 0
k.SetValidator(ctx, validator)

return validator
Expand Down
17 changes: 10 additions & 7 deletions x/staking/keeper/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,9 @@ func (k Keeper) DeleteValidatorQueueTimeSlice(ctx sdk.Context, endTime time.Time

// DeleteValidatorQueue removes a validator by address from the unbonding queue
// indexed by a given height and time.
func (k Keeper) DeleteValidatorQueue(ctx sdk.Context, val types.Validator) {
addrs := k.GetUnbondingValidators(ctx, val.UnbondingTime, val.UnbondingHeight)
func (k Keeper) DeleteValidatorQueue(ctx sdk.Context, val types.Validator,
unbondingTime time.Time, unbondingHeight int64) {
addrs := k.GetUnbondingValidators(ctx, unbondingTime, unbondingHeight)
newAddrs := []string{}

for _, addr := range addrs {
Expand All @@ -378,14 +379,14 @@ func (k Keeper) DeleteValidatorQueue(ctx sdk.Context, val types.Validator) {
}

if len(newAddrs) == 0 {
k.DeleteValidatorQueueTimeSlice(ctx, val.UnbondingTime, val.UnbondingHeight)
k.DeleteValidatorQueueTimeSlice(ctx, unbondingTime, unbondingHeight)
} else {
k.SetUnbondingValidatorsQueue(ctx, val.UnbondingTime, val.UnbondingHeight, newAddrs)
k.SetUnbondingValidatorsQueue(ctx, unbondingTime, unbondingHeight, newAddrs)
}
}

// ValidatorQueueIterator returns an interator ranging over validators that are
// unbonding whose unbonding completion occurs at the given height and time.
// unbonding whose unbonding completion occurs not before a given time.
func (k Keeper) ValidatorQueueIterator(ctx sdk.Context, endTime time.Time, endHeight int64) sdk.Iterator {
store := ctx.KVStore(k.storeKey)
return store.Iterator(types.ValidatorQueueKey, sdk.InclusiveEndBytes(types.GetValidatorQueueKey(endTime, endHeight)))
Expand All @@ -409,15 +410,15 @@ func (k Keeper) UnbondAllMatureValidators(ctx sdk.Context) {

for ; unbondingValIterator.Valid(); unbondingValIterator.Next() {
key := unbondingValIterator.Key()
keyTime, keyHeight, err := types.ParseValidatorQueueKey(key)
keyTime, _, err := types.ParseValidatorQueueKey(key)
if err != nil {
panic(fmt.Errorf("failed to parse unbonding key: %w", err))
}

// All addresses for the given key have the same unbonding height and time.
// We only unbond if the height and time are less than the current height
// and time.
if keyHeight <= blockHeight && (keyTime.Before(blockTime) || keyTime.Equal(blockTime)) {
if keyTime.Before(blockTime) || keyTime.Equal(blockTime) {
addrs := types.ValAddresses{}
k.cdc.MustUnmarshal(unbondingValIterator.Value(), &addrs)

Expand All @@ -432,6 +433,8 @@ func (k Keeper) UnbondAllMatureValidators(ctx sdk.Context) {
}

if !val.IsUnbonding() {
fmt.Println("status is ", val.Status)

panic("unexpected validator in unbonding queue; status was not unbonding")
}
if !val.UnbondingOnHold {
Expand Down
Loading

0 comments on commit 7ec4e9e

Please sign in to comment.