Skip to content

Commit

Permalink
fix: corrected mistake when retrieving the reward address (#848)
Browse files Browse the repository at this point in the history
Co-authored-by: Amir Babazadeh <[email protected]>
  • Loading branch information
b00f and amirvalhalla authored Dec 11, 2023
1 parent 0076a02 commit 50d3af1
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 110 deletions.
53 changes: 34 additions & 19 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,33 +353,37 @@ func StartNode(workingDir string, passwordFetcher func(*wallet.Wallet) (string,
if err != nil {
return nil, nil, err
}
allValAddrs := walletInstance.AllValidatorAddresses()

if len(allValAddrs) < 1 || len(allValAddrs) > 32 {
return nil, nil, fmt.Errorf("number of validators must be between 1 and 32, but it's %d",
len(allValAddrs))
valAddrsInfo := walletInstance.AllValidatorAddresses()
if len(valAddrsInfo) == 0 {
return nil, nil, fmt.Errorf("no validator addresses found in the wallet")
}

if len(valAddrsInfo) > 32 {
PrintWarnMsgf("wallet has more than 32 validator addresses, only the first 32 will be used")
valAddrsInfo = valAddrsInfo[:32]
}

if len(conf.Node.RewardAddresses) > 0 &&
len(conf.Node.RewardAddresses) != len(allValAddrs) {
return nil, nil, fmt.Errorf("reward addresses should be %v", len(allValAddrs))
len(conf.Node.RewardAddresses) != len(valAddrsInfo) {
return nil, nil, fmt.Errorf("reward addresses should be %v", len(valAddrsInfo))
}

validatorAddrs := make([]string, len(allValAddrs))
for i := 0; i < len(validatorAddrs); i++ {
valAddr, _ := crypto.AddressFromString(allValAddrs[i].Address)
valAddrs := make([]string, len(valAddrsInfo))
for i := 0; i < len(valAddrs); i++ {
valAddr, _ := crypto.AddressFromString(valAddrsInfo[i].Address)
if !valAddr.IsValidatorAddress() {
return nil, nil, fmt.Errorf("invalid validator address: %s", allValAddrs[i].Address)
return nil, nil, fmt.Errorf("invalid validator address: %s", valAddrsInfo[i].Address)
}
validatorAddrs[i] = valAddr.String()
valAddrs[i] = valAddr.String()
}

valKeys := make([]*bls.ValidatorKey, len(allValAddrs))
valKeys := make([]*bls.ValidatorKey, len(valAddrsInfo))
password, ok := passwordFetcher(walletInstance)
if !ok {
return nil, nil, fmt.Errorf("aborted")
}
prvKeys, err := walletInstance.PrivateKeys(password, validatorAddrs)
prvKeys, err := walletInstance.PrivateKeys(password, valAddrs)
if err != nil {
return nil, nil, err
}
Expand All @@ -388,28 +392,39 @@ func StartNode(workingDir string, passwordFetcher func(*wallet.Wallet) (string,
}

// Create reward addresses
rewardAddrs := make([]crypto.Address, 0, len(allValAddrs))
rewardAddrs := make([]crypto.Address, 0, len(valAddrsInfo))
if len(conf.Node.RewardAddresses) != 0 {
for _, addrStr := range conf.Node.RewardAddresses {
addr, _ := crypto.AddressFromString(addrStr)
rewardAddrs = append(rewardAddrs, addr)
}
} else {
for i := 0; i < len(allValAddrs); i++ {
valAddrPath, _ := addresspath.NewPathFromString(allValAddrs[i].Path)
accAddrPath := addresspath.NewPath(valAddrPath.Purpose(), valAddrPath.CoinType(),
uint32(crypto.AddressTypeValidator)+hdkeychain.HardenedKeyStart, valAddrPath.AddressIndex())
for i := 0; i < len(valAddrsInfo); i++ {
valAddrPath, _ := addresspath.NewPathFromString(valAddrsInfo[i].Path)
accAddrPath := addresspath.NewPath(
valAddrPath.Purpose(),
valAddrPath.CoinType(),
uint32(crypto.AddressTypeBLSAccount)+hdkeychain.HardenedKeyStart,
valAddrPath.AddressIndex())

addrInfo := walletInstance.AddressFromPath(accAddrPath.String())
if addrInfo == nil {
return nil, nil, fmt.Errorf("unable to find reward address for: %s", allValAddrs[i].Address)
return nil, nil, fmt.Errorf("unable to find reward address for: %s [%s]",
valAddrsInfo[i].Address, accAddrPath)
}

addr, _ := crypto.AddressFromString(addrInfo.Address)
rewardAddrs = append(rewardAddrs, addr)
}
}

// Check if reward addresses are account address
for _, addr := range rewardAddrs {
if !addr.IsAccountAddress() {
return nil, nil, fmt.Errorf("reward address is not an account address: %s", addr)
}
}

nodeInstance, err := node.NewNode(gen, conf, valKeys, rewardAddrs)
if err != nil {
return nil, nil, err
Expand Down
3 changes: 3 additions & 0 deletions wallet/addresspath/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func NewPath(indexes ...uint32) Path {
return p
}

// TODO: check the path should exactly 4 levels.
func NewPathFromString(str string) (Path, error) {
sub := strings.Split(str, "/")
if sub[0] != "m" {
Expand Down Expand Up @@ -52,6 +53,8 @@ func (p Path) String() string {
return builder.String()
}

// TODO: we can add IsBLSPurpose or IsImportedPurpose functions

func (p Path) Purpose() uint32 {
return p[0]
}
Expand Down
7 changes: 7 additions & 0 deletions wallet/vault/utils.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package vault

import (
"github.com/pactus-project/pactus/crypto/bls/hdkeychain"
"github.com/tyler-smith/go-bip39"
"golang.org/x/exp/constraints"
)

// GenerateMnemonic generates a new mnemonic (seed phrase) based on BIP-39
Expand All @@ -22,3 +24,8 @@ func CheckMnemonic(mnemonic string) error {
_, err := bip39.EntropyFromMnemonic(mnemonic)
return err
}

// H hardens the value 'i' by adding it to 0x80000000 (2^31).
func H[T constraints.Integer](i T) uint32 {
return uint32(i) + hdkeychain.HardenedKeyStart
}
72 changes: 27 additions & 45 deletions wallet/vault/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,8 @@ type AddressInfo struct {
}

const (
PurposeBLS12381 = uint32(12381)
HardenedPurposeBLS12381 = PurposeBLS12381 + hdkeychain.HardenedKeyStart
PurposeImportPrivateKey = uint32(65535)
HardenedPurposeImportPrivateKey = PurposeImportPrivateKey + hdkeychain.HardenedKeyStart
HardenedAddressTypeBLSAccount = uint32(crypto.AddressTypeBLSAccount) + hdkeychain.HardenedKeyStart
HardenedAddressTypeValidator = uint32(crypto.AddressTypeValidator) + hdkeychain.HardenedKeyStart
PurposeBLS12381 = uint32(12381)
PurposeImportPrivateKey = uint32(65535)
)

type Vault struct {
Expand Down Expand Up @@ -97,18 +93,18 @@ func CreateVaultFromMnemonic(mnemonic string, coinType uint32) (*Vault, error) {
enc := encrypter.NopeEncrypter()

xPubValidator, err := masterKey.DerivePath([]uint32{
12381 + hdkeychain.HardenedKeyStart,
coinType + hdkeychain.HardenedKeyStart,
uint32(crypto.AddressTypeValidator) + hdkeychain.HardenedKeyStart,
H(PurposeBLS12381),
H(coinType),
H(crypto.AddressTypeValidator),
})
if err != nil {
return nil, err
}

xPubAccount, err := masterKey.DerivePath([]uint32{
12381 + hdkeychain.HardenedKeyStart,
coinType + hdkeychain.HardenedKeyStart,
uint32(crypto.AddressTypeBLSAccount) + hdkeychain.HardenedKeyStart,
H(PurposeBLS12381),
H(coinType),
H(crypto.AddressTypeBLSAccount),
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -222,22 +218,7 @@ func (v *Vault) AllValidatorAddresses() []AddressInfo {
addrs := make([]AddressInfo, 0, v.AddressCount()/2)
for _, addrInfo := range v.Addresses {
addrPath, _ := addresspath.NewPathFromString(addrInfo.Path)
if addrPath.AddressType()-hdkeychain.HardenedKeyStart == uint32(crypto.AddressTypeValidator) {
addrs = append(addrs, addrInfo)
}
}

v.SortAddressesByAddressIndex(addrs...)
v.SortAddressesByPurpose(addrs...)

return addrs
}

func (v *Vault) AllBLSAccountAddresses() []AddressInfo {
addrs := make([]AddressInfo, 0, v.AddressCount()/2)
for _, addrInfo := range v.Addresses {
addrPath, _ := addresspath.NewPathFromString(addrInfo.Path)
if addrPath.AddressType()-hdkeychain.HardenedKeyStart == uint32(crypto.AddressTypeBLSAccount) {
if addrPath.AddressType() == H(crypto.AddressTypeValidator) {
addrs = append(addrs, addrInfo)
}
}
Expand All @@ -252,7 +233,7 @@ func (v *Vault) AllImportedPrivateKeysAddresses() []AddressInfo {
addrs := make([]AddressInfo, 0, v.AddressCount()/2)
for _, addrInfo := range v.Addresses {
addrPath, _ := addresspath.NewPathFromString(addrInfo.Path)
if addrPath.Purpose() == HardenedPurposeImportPrivateKey {
if addrPath.Purpose() == H(PurposeImportPrivateKey) {
addrs = append(addrs, addrInfo)
}
}
Expand Down Expand Up @@ -331,17 +312,17 @@ func (v *Vault) ImportPrivateKey(password string, prv *bls.PrivateKey) error {
return ErrAddressExists
}

blsAccPathStr := addresspath.NewPath(HardenedPurposeImportPrivateKey,
v.CoinType+hdkeychain.HardenedKeyStart,
HardenedAddressTypeBLSAccount,
uint32(addressIndex)+hdkeychain.HardenedKeyStart).
String()
blsAccPathStr := addresspath.NewPath(
H(PurposeImportPrivateKey),
H(v.CoinType),
H(crypto.AddressTypeBLSAccount),
H(addressIndex)).String()

blsValidatorPathStr := addresspath.NewPath(HardenedPurposeImportPrivateKey,
v.CoinType+addresspath.HardenedKeyStart,
HardenedAddressTypeValidator,
uint32(addressIndex)+hdkeychain.HardenedKeyStart).
String()
blsValidatorPathStr := addresspath.NewPath(
H(PurposeImportPrivateKey),
H(v.CoinType),
H(crypto.AddressTypeValidator),
H(addressIndex)).String()

importedPrvLabelCounter := (len(v.AllImportedPrivateKeysAddresses()) / 2) + 1
v.Addresses[accAddr.String()] = AddressInfo{
Expand Down Expand Up @@ -390,12 +371,12 @@ func (v *Vault) PrivateKeys(password string, addrs []string) ([]crypto.PrivateKe
return nil, err
}

if path.CoinType()-hdkeychain.HardenedKeyStart != v.CoinType {
if path.CoinType() != H(v.CoinType) {
return nil, ErrInvalidCoinType
}

switch path.Purpose() {
case HardenedPurposeBLS12381:
case H(PurposeBLS12381):
seed, err := bip39.NewSeedWithErrorChecking(keyStore.MasterNode.Mnemonic, "")
if err != nil {
return nil, err
Expand All @@ -419,8 +400,9 @@ func (v *Vault) PrivateKeys(password string, addrs []string) ([]crypto.PrivateKe
}

keys[i] = prvKey
case HardenedPurposeImportPrivateKey:
case H(PurposeImportPrivateKey):
index := path.AddressIndex() - hdkeychain.HardenedKeyStart
// TODO: index out of range check
str := keyStore.ImportedKeys[index]
prv, err := bls.PrivateKeyFromString(str)
if err != nil {
Expand Down Expand Up @@ -504,12 +486,12 @@ func (v *Vault) AddressInfo(addr string) *AddressInfo {
}

// TODO it would be better to return error in future
if path.CoinType()-hdkeychain.HardenedKeyStart != v.CoinType {
if path.CoinType() != H(v.CoinType) {
return nil
}

switch path.Purpose() {
case HardenedPurposeBLS12381:
case H(PurposeBLS12381):
addr, err := crypto.AddressFromString(info.Address)
if err != nil {
return nil
Expand Down Expand Up @@ -543,7 +525,7 @@ func (v *Vault) AddressInfo(addr string) *AddressInfo {
}

info.PublicKey = blsPubKey.String()
case HardenedPurposeImportPrivateKey:
case H(PurposeImportPrivateKey):
default:
return nil
}
Expand Down
42 changes: 4 additions & 38 deletions wallet/vault/vault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func TestAddressInfo(t *testing.T) {
path, _ := addresspath.NewPathFromString(info.Path)

switch path.Purpose() {
case HardenedPurposeBLS12381:
case H(PurposeBLS12381):
if addr.IsValidatorAddress() {
assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/1'/%d",
PurposeBLS12381, td.vault.CoinType, path.AddressIndex()))
Expand All @@ -94,7 +94,7 @@ func TestAddressInfo(t *testing.T) {
assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/2'/%d",
PurposeBLS12381, td.vault.CoinType, path.AddressIndex()))
}
case HardenedPurposeImportPrivateKey:
case H(PurposeImportPrivateKey):
if addr.IsValidatorAddress() {
assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/1'/%d'",
PurposeImportPrivateKey, td.vault.CoinType, path.AddressIndex()-hdkeychain.HardenedKeyStart))
Expand Down Expand Up @@ -135,10 +135,10 @@ func TestAllValidatorAddresses(t *testing.T) {
path, _ := addresspath.NewPathFromString(info.Path)

switch path.Purpose() {
case HardenedPurposeBLS12381:
case H(PurposeBLS12381):
assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/1'/%d",
PurposeBLS12381, td.vault.CoinType, path.AddressIndex()))
case HardenedPurposeImportPrivateKey:
case H(PurposeImportPrivateKey):
assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/1'/%d'",
PurposeImportPrivateKey, td.vault.CoinType, path.AddressIndex()-hdkeychain.HardenedKeyStart))
}
Expand All @@ -155,40 +155,6 @@ func TestSortAllValidatorAddresses(t *testing.T) {
assert.Equal(t, "m/65535'/21888'/1'/0'", validatorAddrs[len(validatorAddrs)-1].Path)
}

func TestAllBLSAccountAddresses(t *testing.T) {
td := setup(t)

assert.Equal(t, td.vault.AddressCount(), 6)

blsAccountAddrs := td.vault.AllBLSAccountAddresses()
for _, i := range blsAccountAddrs {
info := td.vault.AddressInfo(i.Address)
assert.Equal(t, i.Address, info.Address)

path, _ := addresspath.NewPathFromString(info.Path)

switch path.Purpose() {
case HardenedPurposeBLS12381:
assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/2'/%d",
PurposeBLS12381, td.vault.CoinType, path.AddressIndex()))
case HardenedPurposeImportPrivateKey:
assert.Equal(t, info.Path, fmt.Sprintf("m/%d'/%d'/2'/%d'",
PurposeImportPrivateKey, td.vault.CoinType, path.AddressIndex()-hdkeychain.HardenedKeyStart))
}
}
}

func TestSortAllBLSAccountAddresses(t *testing.T) {
td := setup(t)

assert.Equal(t, td.vault.AddressCount(), 6)

blsAccountAddrs := td.vault.AllBLSAccountAddresses()

assert.Equal(t, "m/12381'/21888'/2'/0", blsAccountAddrs[0].Path)
assert.Equal(t, "m/65535'/21888'/2'/0'", blsAccountAddrs[len(blsAccountAddrs)-1].Path)
}

func TestAddressFromPath(t *testing.T) {
td := setup(t)
assert.Equal(t, td.vault.AddressCount(), 6)
Expand Down
25 changes: 17 additions & 8 deletions wallet/wallet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,25 @@ func TestInvalidAddress(t *testing.T) {
assert.Error(t, err)
}

// TODO: Fix later.
// func TestImportPrivateKey(t *testing.T) {
// td := setup(t)
func TestImportPrivateKey(t *testing.T) {
td := setup(t)

_, prv := td.RandBLSKeyPair()
assert.NoError(t, td.wallet.ImportPrivateKey(td.password, prv))

pub := prv.PublicKeyNative()
accAddr := pub.AccountAddress().String()
valAddr := pub.AccountAddress().String()

// _, prv := td.RandBLSKeyPair()
// assert.NoError(t, td.wallet.ImportPrivateKey(td.password, prv))
assert.True(t, td.wallet.Contains(accAddr))
assert.True(t, td.wallet.Contains(valAddr))

// addr := prv.PublicKeyNative().AccountAddress().String()
// assert.True(t, td.wallet.Contains(addr))
// }
accAddrInfo := td.wallet.AddressInfo(accAddr)
valAddrInfo := td.wallet.AddressInfo(accAddr)

assert.Equal(t, pub.String(), accAddrInfo.PublicKey)
assert.Equal(t, pub.String(), valAddrInfo.PublicKey)
}

func TestKeyInfo(t *testing.T) {
td := setup(t)
Expand Down

0 comments on commit 50d3af1

Please sign in to comment.