Skip to content

Commit

Permalink
btcec/schnorr/musig2: optimize signing+verification
Browse files Browse the repository at this point in the history
In this commit, we optimize signing+verification mainly by only
computing values once, and reducing allocations when possible.

The following optimizations have been implemented:
  * Use a single buffer allocation in keyHashFingerprint to avoid
    dynamic buffer growth+re-sizing
  * Remove the isSecondKey computation and replace that with a single
    routine that computes the index of the second unique key.
  * Optimize keyHashFingerprint usage by only computing it once during
    signing +verification.

A further optimization is possible: use the x coordinate of a key for
comparisons instead of computing the full sexualision. We need to do
the latter atm, as the X() method of the public key struct will allocate
more memory as it allocate and sets the buffer in place.

The final benchmarks of before and after this commit:
benchmark                                                             old ns/op     new ns/op     delta
BenchmarkPartialSign/num_signers=10/fast_sign=true/sort=true-8        1227374       1194047       -2.72%
BenchmarkPartialSign/num_signers=10/fast_sign=true/sort=false-8       1217743       1191468       -2.16%
BenchmarkPartialSign/num_signers=10/fast_sign=false/sort=true-8       2755544       2698827       -2.06%
BenchmarkPartialSign/num_signers=10/fast_sign=false/sort=false-8      2754749       2694547       -2.19%
BenchmarkPartialSign/num_signers=100/fast_sign=true/sort=true-8       12382654      10561204      -14.71%
BenchmarkPartialSign/num_signers=100/fast_sign=true/sort=false-8      12260134      10315376      -15.86%
BenchmarkPartialSign/num_signers=100/fast_sign=false/sort=true-8      24832061      22009935      -11.36%
BenchmarkPartialSign/num_signers=100/fast_sign=false/sort=false-8     24650086      21022833      -14.71%
BenchmarkPartialVerify/sort_keys=true/num_signers=10-8                1485787       1473377       -0.84%
BenchmarkPartialVerify/sort_keys=false/num_signers=10-8               1447275       1465139       +1.23%
BenchmarkPartialVerify/sort_keys=true/num_signers=100-8               12503482      10672618      -14.64%
BenchmarkPartialVerify/sort_keys=false/num_signers=100-8              12388289      10581398      -14.59%
BenchmarkCombineSigs/num_signers=10-8                                 0.00          0.00          +0.00%
BenchmarkCombineSigs/num_signers=100-8                                0.00          0.00          -1.95%
BenchmarkAggregateNonces/num_signers=10-8                             0.00          0.00          -0.76%
BenchmarkAggregateNonces/num_signers=100-8                            0.00          0.00          +1.13%
BenchmarkAggregateKeys/num_signers=10/sort_keys=true-8                0.00          0.00          -0.09%
BenchmarkAggregateKeys/num_signers=10/sort_keys=false-8               0.00          0.01          +559.94%
BenchmarkAggregateKeys/num_signers=100/sort_keys=true-8               0.01          0.01          -11.30%
BenchmarkAggregateKeys/num_signers=100/sort_keys=false-8              0.01          0.01          -11.66%

benchmark                                                             old allocs     new allocs     delta
BenchmarkPartialSign/num_signers=10/fast_sign=true/sort=true-8        458            269            -41.27%
BenchmarkPartialSign/num_signers=10/fast_sign=true/sort=false-8       409            222            -45.72%
BenchmarkPartialSign/num_signers=10/fast_sign=false/sort=true-8       892            524            -41.26%
BenchmarkPartialSign/num_signers=10/fast_sign=false/sort=false-8      841            467            -44.47%
BenchmarkPartialSign/num_signers=100/fast_sign=true/sort=true-8       14366          3089           -78.50%
BenchmarkPartialSign/num_signers=100/fast_sign=true/sort=false-8      13143          1842           -85.98%
BenchmarkPartialSign/num_signers=100/fast_sign=false/sort=true-8      27596          4964           -82.01%
BenchmarkPartialSign/num_signers=100/fast_sign=false/sort=false-8     26309          3707           -85.91%
BenchmarkPartialVerify/sort_keys=true/num_signers=10-8                430            243            -43.49%
BenchmarkPartialVerify/sort_keys=false/num_signers=10-8               430            243            -43.49%
BenchmarkPartialVerify/sort_keys=true/num_signers=100-8               13164          1863           -85.85%
BenchmarkPartialVerify/sort_keys=false/num_signers=100-8              13164          1863           -85.85%
BenchmarkCombineSigs/num_signers=10-8                                 0              0              +0.00%
BenchmarkCombineSigs/num_signers=100-8                                0              0              +0.00%
BenchmarkAggregateNonces/num_signers=10-8                             0              0              +0.00%
BenchmarkAggregateNonces/num_signers=100-8                            0              0              +0.00%
BenchmarkAggregateKeys/num_signers=10/sort_keys=true-8                0              0              +0.00%
BenchmarkAggregateKeys/num_signers=10/sort_keys=false-8               0              0              +0.00%
BenchmarkAggregateKeys/num_signers=100/sort_keys=true-8               0              0              +0.00%
BenchmarkAggregateKeys/num_signers=100/sort_keys=false-8              0              0              +0.00%

benchmark                                                             old bytes     new bytes     delta
BenchmarkPartialSign/num_signers=10/fast_sign=true/sort=true-8        27854         14878         -46.59%
BenchmarkPartialSign/num_signers=10/fast_sign=true/sort=false-8       25508         12605         -50.58%
BenchmarkPartialSign/num_signers=10/fast_sign=false/sort=true-8       54982         29476         -46.39%
BenchmarkPartialSign/num_signers=10/fast_sign=false/sort=false-8      52581         26805         -49.02%
BenchmarkPartialSign/num_signers=100/fast_sign=true/sort=true-8       1880138       166996        -91.12%
BenchmarkPartialSign/num_signers=100/fast_sign=true/sort=false-8      1820561       106295        -94.16%
BenchmarkPartialSign/num_signers=100/fast_sign=false/sort=true-8      3706291       275344        -92.57%
BenchmarkPartialSign/num_signers=100/fast_sign=false/sort=false-8     3642725       214122        -94.12%
BenchmarkPartialVerify/sort_keys=true/num_signers=10-8                26995         14078         -47.85%
BenchmarkPartialVerify/sort_keys=false/num_signers=10-8               26980         14078         -47.82%
BenchmarkPartialVerify/sort_keys=true/num_signers=100-8               1822043       107767        -94.09%
BenchmarkPartialVerify/sort_keys=false/num_signers=100-8              1822046       107752        -94.09%
BenchmarkCombineSigs/num_signers=10-8                                 0             0             +0.00%
BenchmarkCombineSigs/num_signers=100-8                                0             0             +0.00%
BenchmarkAggregateNonces/num_signers=10-8                             0             0             +0.00%
BenchmarkAggregateNonces/num_signers=100-8                            0             0             +0.00%
BenchmarkAggregateKeys/num_signers=10/sort_keys=true-8                0             0             +0.00%
BenchmarkAggregateKeys/num_signers=10/sort_keys=false-8               0             0             +0.00%
BenchmarkAggregateKeys/num_signers=100/sort_keys=true-8               0             0             +0.00%
BenchmarkAggregateKeys/num_signers=100/sort_keys=false-8              0             0             +0.00%
  • Loading branch information
Roasbeef committed Mar 15, 2022
1 parent d57f4d3 commit 1186a89
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 44 deletions.
7 changes: 6 additions & 1 deletion btcec/schnorr/musig2/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,11 +301,16 @@ func BenchmarkAggregateKeys(b *testing.B) {
name := fmt.Sprintf("num_signers=%v/sort_keys=%v",
numSigners, sortKeys)

uniqueKeyIndex := secondUniqueKeyIndex(signerKeys)

b.Run(name, func(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()

aggKey := AggregateKeys(signerKeys, sortKeys)
aggKey := AggregateKeys(
signerKeys, sortKeys,
WithUniqueKeyIndex(uniqueKeyIndex),
)

testKey = aggKey
})
Expand Down
134 changes: 99 additions & 35 deletions btcec/schnorr/musig2/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type sortableKeys []*btcec.PublicKey
// Less reports whether the element with index i must sort before the element
// with index j.
func (s sortableKeys) Less(i, j int) bool {
// TODO(roasbeef): more efficient way to compare...
keyIBytes := schnorr.SerializePubKey(s[i])
keyJBytes := schnorr.SerializePubKey(s[j])

Expand Down Expand Up @@ -62,12 +63,14 @@ func sortKeys(keys []*btcec.PublicKey) []*btcec.PublicKey {
// for each key. The final computation is:
// * H(tag=KeyAgg list, pk1 || pk2..)
func keyHashFingerprint(keys []*btcec.PublicKey, sort bool) []byte {
var keyBytes bytes.Buffer

if sort {
keys = sortKeys(keys)
}

// We'll create a single buffer and slice into that so the bytes buffer
// doesn't continually need to grow the underlying buffer.
keyAggBuf := make([]byte, 32*len(keys))
keyBytes := bytes.NewBuffer(keyAggBuf[0:0])
for _, key := range keys {
keyBytes.Write(schnorr.SerializePubKey(key))
}
Expand All @@ -76,69 +79,128 @@ func keyHashFingerprint(keys []*btcec.PublicKey, sort bool) []byte {
return h[:]
}

// isSecondKey returns true if the passed public key is the second key in the
// (sorted) keySet passed.
func isSecondKey(keySet []*btcec.PublicKey, targetKey *btcec.PublicKey) bool {
// For this comparison, we want to compare the raw serialized version
// instead of the full pubkey, as it's possible we're dealing with a
// pubkey that _actually_ has an odd y coordinate.
equalBytes := func(a, b *btcec.PublicKey) bool {
return bytes.Equal(
schnorr.SerializePubKey(a),
schnorr.SerializePubKey(b),
)
}

for i := range keySet {
if !equalBytes(keySet[i], keySet[0]) {
return equalBytes(keySet[i], targetKey)
}
}

return false
// keyBytesEqual returns true if two keys are the same from the PoV of BIP
// 340's 32-byte x-only public keys.
func keyBytesEqual(a, b *btcec.PublicKey) bool {
return bytes.Equal(
schnorr.SerializePubKey(a),
schnorr.SerializePubKey(b),
)
}

// aggregationCoefficient computes the key aggregation coefficient for the
// specified target key. The coefficient is computed as:
// * H(tag=KeyAgg coefficient, keyHashFingerprint(pks) || pk)
func aggregationCoefficient(keySet []*btcec.PublicKey,
targetKey *btcec.PublicKey, sort bool) *btcec.ModNScalar {
targetKey *btcec.PublicKey, keysHash []byte,
secondKeyIdx int) *btcec.ModNScalar {

var mu btcec.ModNScalar

// If this is the second key, then this coefficient is just one.
//
// TODO(roasbeef): use intermediate cache to keep track of the second
// key, can just store an index, otherwise this is O(n^2)
if isSecondKey(keySet, targetKey) {
if secondKeyIdx != -1 && keyBytesEqual(keySet[secondKeyIdx], targetKey) {
return mu.SetInt(1)
}

// Otherwise, we'll compute the full finger print hash for this given
// key and then use that to compute the coefficient tagged hash:
// * H(tag=KeyAgg coefficient, keyHashFingerprint(pks, pk) || pk)
var coefficientBytes bytes.Buffer
coefficientBytes.Write(keyHashFingerprint(keySet, sort))
coefficientBytes.Write(schnorr.SerializePubKey(targetKey))
var coefficientBytes [64]byte
copy(coefficientBytes[:], keysHash[:])
copy(coefficientBytes[32:], schnorr.SerializePubKey(targetKey))

muHash := chainhash.TaggedHash(KeyAggTagCoeff, coefficientBytes.Bytes())
muHash := chainhash.TaggedHash(KeyAggTagCoeff, coefficientBytes[:])

mu.SetByteSlice(muHash[:])

return &mu
}

// TODO(roasbeef): make proper IsEven func
// secondUniqueKeyIndex returns the index of the second unique key. If all keys
// are the same, then a value of -1 is returned.
func secondUniqueKeyIndex(keySet []*btcec.PublicKey) int {
// Find the first key that isn't the same as the very first key (second
// unique key).
for i := range keySet {
if !keyBytesEqual(keySet[i], keySet[0]) {
return i
}
}

// A value of negative one is used to indicate that all the keys in the
// sign set are actually equal, which in practice actually makes musig2
// useless, but we need a value to distinguish this case.
return -1
}

// KeyAggOption is a functional option argument that allows callers to specify
// more or less information that has been pre-computed to the main routine.
type KeyAggOption func(*keyAggOption)

// keyAggOption houses the set of functional options that modify key
// aggregation.
type keyAggOption struct {
// keyHash is the output of keyHashFingerprint for a given set of keys.
keyHash []byte

// uniqueKeyIndex is the pre-computed index of the second unique key.
uniqueKeyIndex *int
}

// WithKeysHash allows key aggregation to be optimize, by allowing the caller
// to specify the hash of all the keys.
func WithKeysHash(keyHash []byte) KeyAggOption {
return func(o *keyAggOption) {
o.keyHash = keyHash
}
}

// WithUniqueKeyIndex allows the caller to specify the index of the second
// unique key.
func WithUniqueKeyIndex(idx int) KeyAggOption {
return func(o *keyAggOption) {
i := idx
o.uniqueKeyIndex = &i
}
}

// defaultKeyAggOptions returns the set of default arguments for key
// aggregation.
func defaultKeyAggOptions() *keyAggOption {
return &keyAggOption{}
}

// AggregateKeys takes a list of possibly unsorted keys and returns a single
// aggregated key as specified by the musig2 key aggregation algorithm.
func AggregateKeys(keys []*btcec.PublicKey, sort bool) *btcec.PublicKey {
// aggregated key as specified by the musig2 key aggregation algorithm. A nil
// value can be passed for keyHash, which causes this function to re-derive it.
func AggregateKeys(keys []*btcec.PublicKey, sort bool,
keyOpts ...KeyAggOption) *btcec.PublicKey {

// First, parse the set of optional signing options.
opts := defaultKeyAggOptions()
for _, option := range keyOpts {
option(opts)
}

// Sort the set of public key so we know we're working with them in
// sorted order for all the routines below.
if sort {
keys = sortKeys(keys)
}

// The caller may provide the hash of all the keys as an optimization
// during signing, as it already needs to be computed.
if opts.keyHash == nil {
opts.keyHash = keyHashFingerprint(keys, sort)
}

// A caller may also specify the unique key index themselves so we
// don't need to re-compute it.
if opts.uniqueKeyIndex == nil {
idx := secondUniqueKeyIndex(keys)
opts.uniqueKeyIndex = &idx
}

// For each key, we'll compute the intermediate blinded key: a_i*P_i,
// where a_i is the aggregation coefficient for that key, and P_i is
// the key itself, then accumulate that (addition) into the main final
Expand All @@ -153,7 +215,9 @@ func AggregateKeys(keys []*btcec.PublicKey, sort bool) *btcec.PublicKey {
// Compute the aggregation coefficient for the key, then
// multiply it by the key itself: P_i' = a_i*P_i.
var tweakedKeyJ btcec.JacobianPoint
a := aggregationCoefficient(keys, key, sort)
a := aggregationCoefficient(
keys, key, opts.keyHash, *opts.uniqueKeyIndex,
)
btcec.ScalarMultNonConst(a, &keyJ, &tweakedKeyJ)

// Finally accumulate this into the final key in an incremental
Expand Down
10 changes: 8 additions & 2 deletions btcec/schnorr/musig2/musig2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ func TestMuSig2KeyAggTestVectors(t *testing.T) {
keys = append(keys, pub)
}

combinedKey := AggregateKeys(keys, false)
uniqueKeyIndex := secondUniqueKeyIndex(keys)
combinedKey := AggregateKeys(
keys, false, WithUniqueKeyIndex(uniqueKeyIndex),
)
combinedKeyBytes := schnorr.SerializePubKey(combinedKey)
if !bytes.Equal(combinedKeyBytes, testCase.expectedKey) {
t.Fatalf("case: #%v, invalid aggregation: "+
Expand Down Expand Up @@ -248,7 +251,10 @@ func (s signerSet) pubNonces() [][PubNonceSize]byte {
}

func (s signerSet) combinedKey() *btcec.PublicKey {
return AggregateKeys(s.keys(), false)
uniqueKeyIndex := secondUniqueKeyIndex(s.keys())
return AggregateKeys(
s.keys(), false, WithUniqueKeyIndex(uniqueKeyIndex),
)
}

// TestMuSigMultiParty tests that for a given set of 100 signers, we're able to
Expand Down
27 changes: 21 additions & 6 deletions btcec/schnorr/musig2/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,17 @@ func Sign(secNonce [SecNonceSize]byte, privKey *btcec.PrivateKey,
return nil, err
}

// Compute the hash of all the keys here as we'll need it do aggregrate
// the keys and also at the final step of signing.
keysHash := keyHashFingerprint(pubKeys, opts.sortKeys)

// Next we'll construct the aggregated public key based on the set of
// signers.
combinedKey := AggregateKeys(pubKeys, opts.sortKeys)
uniqueKeyIndex := secondUniqueKeyIndex(pubKeys)
combinedKey := AggregateKeys(
pubKeys, opts.sortKeys, WithKeysHash(keysHash),
WithUniqueKeyIndex(uniqueKeyIndex),
)

// Next we'll compute the value b, that blinds our second public
// nonce:
Expand Down Expand Up @@ -220,7 +228,7 @@ func Sign(secNonce [SecNonceSize]byte, privKey *btcec.PrivateKey,

// Next, we'll compute mu, our aggregation coefficient for the key that
// we're signing with.
mu := aggregationCoefficient(pubKeys, pubKey, opts.sortKeys)
mu := aggregationCoefficient(pubKeys, pubKey, keysHash, uniqueKeyIndex)

// With mu constructed, we can finally generate our partial signature
// as: s = (k1_1 + b*k_2 + e*mu*d) mod n.
Expand Down Expand Up @@ -291,9 +299,18 @@ func verifyPartialSig(partialSig *PartialSignature, pubNonce [PubNonceSize]byte,
return err
}

// Compute the hash of all the keys here as we'll need it do aggregrate
// the keys and also at the final step of verification.
keysHash := keyHashFingerprint(keySet, opts.sortKeys)

uniqueKeyIndex := secondUniqueKeyIndex(keySet)

// Next we'll construct the aggregated public key based on the set of
// signers.
combinedKey := AggregateKeys(keySet, opts.sortKeys)
combinedKey := AggregateKeys(
keySet, opts.sortKeys,
WithKeysHash(keysHash), WithUniqueKeyIndex(uniqueKeyIndex),
)

// Next we'll compute the value b, that blinds our second public
// nonce:
Expand Down Expand Up @@ -345,8 +362,6 @@ func verifyPartialSig(partialSig *PartialSignature, pubNonce [PubNonceSize]byte,

// If the combined nonce used in the challenge hash has an odd y
// coordinate, then we'll negate our final public nonce.
//
// TODO(roasbeef): make into func
if nonce.Y.IsOdd() {
pubNonceJ.ToAffine()
pubNonceJ.Y.Negate(1)
Expand Down Expand Up @@ -375,7 +390,7 @@ func verifyPartialSig(partialSig *PartialSignature, pubNonce [PubNonceSize]byte,

// Next, we'll compute mu, our aggregation coefficient for the key that
// we're signing with.
mu := aggregationCoefficient(keySet, signingKey, opts.sortKeys)
mu := aggregationCoefficient(keySet, signingKey, keysHash, uniqueKeyIndex)

// If the combined key has an odd y coordinate, then we'll negate the
// signer key.
Expand Down

0 comments on commit 1186a89

Please sign in to comment.