Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove dependency on nebulous in proof generation #68

Merged
merged 9 commits into from
Oct 12, 2022
52 changes: 0 additions & 52 deletions internal/subtree_hasher.go

This file was deleted.

77 changes: 57 additions & 20 deletions nmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@ import (
"hash"
"math/bits"

"github.com/celestiaorg/merkletree"
"github.com/celestiaorg/nmt/internal"
"github.com/celestiaorg/nmt/namespace"
)

var (
ErrInvalidRange = errors.New("invalid proof range")
ErrMismatchedNamespaceSize = errors.New("mismatching namespace sizes")
ErrInvalidPushOrder = errors.New("pushed data has to be lexicographically ordered by namespace IDs")
noOp = func(hash []byte, children ...[]byte) {}
Expand Down Expand Up @@ -127,12 +126,11 @@ func (n NamespacedMerkleTree) Prove(index int) (Proof, error) {
func (n NamespacedMerkleTree) ProveRange(start, end int) (Proof, error) {
isMaxNsIgnored := n.treeHasher.IsMaxNamespaceIDIgnored()
n.computeLeafHashesIfNecessary()
subTreeHasher := internal.NewCachedSubtreeHasher(n.leafHashes, n.treeHasher)
// TODO: store nodes and re-use the hashes instead recomputing parts of the tree here
proof, err := merkletree.BuildRangeProof(start, end, subTreeHasher)
if err != nil {
return NewEmptyRangeProof(isMaxNsIgnored), err
if start < 0 || start > end || start == end || end > len(n.leafHashes) {
rahulghangas marked this conversation as resolved.
Show resolved Hide resolved
return NewEmptyRangeProof(isMaxNsIgnored), ErrInvalidRange
}
proof := n.buildRangeProof(start, end)

return NewInclusionProof(start, end, proof, isMaxNsIgnored), nil
}
Expand Down Expand Up @@ -171,27 +169,66 @@ func (n NamespacedMerkleTree) ProveNamespace(nID namespace.ID) (Proof, error) {
// the range it would be in (to generate a proof of absence and to return
// the corresponding leaf hashes).
n.computeLeafHashesIfNecessary()
subTreeHasher := internal.NewCachedSubtreeHasher(n.leafHashes, n.treeHasher)
var err error
proof, err := merkletree.BuildRangeProof(proofStart, proofEnd, subTreeHasher)
if err != nil {
// This should never happen.
// TODO would be good to back this by more tests and fuzzing.
return Proof{}, fmt.Errorf(
"unexpected err: %w on nID: %v, range: [%v, %v)",
err,
nID,
proofStart,
proofEnd,
)
}
proof := n.buildRangeProof(proofStart, proofEnd)

if found {
return NewInclusionProof(proofStart, proofEnd, proof, isMaxNsIgnored), nil
}
return NewAbsenceProof(proofStart, proofEnd, proof, n.leafHashes[proofStart], isMaxNsIgnored), nil
}

func (n NamespacedMerkleTree) buildRangeProof(proofStart, proofEnd int) [][]byte {
proof := [][]byte{}
var recurse func(start, end int, includeNode bool) []byte
recurse = func(start, end int, includeNode bool) []byte {
if start >= len(n.leafHashes) {
return nil
}

// reached a leaf
if end-start == 1 {
adlerjohn marked this conversation as resolved.
Show resolved Hide resolved
leafHash := n.leafHashes[start]
// if current range does not overlap with proof range, add a node to proofs
if (start < proofStart || start >= proofEnd) && includeNode {
proof = append(proof, leafHash)
}
return leafHash
}

// recursively get left and right subtree
newIncludeNode := includeNode
if (end <= proofStart || start >= proofEnd) && includeNode {
newIncludeNode = false
}

k := getSplitPoint(end - start)
left := recurse(start, start+k, newIncludeNode)
right := recurse(start+k, end, newIncludeNode)

// only right leaf/subtree can be non-existent
var hash []byte
if right == nil {
hash = left
} else {
hash = n.treeHasher.HashNode(left, right)
}

// highest node in subtree that lies outside proof range
if includeNode && !newIncludeNode {
proof = append(proof, hash)
}

return hash
}

fullTreeSize := getSplitPoint(len(n.leafHashes)) * 2
if fullTreeSize < 1 {
fullTreeSize = 1
}
recurse(0, fullTreeSize, true)
return proof
}

// Get returns leaves for the given namespace.ID.
func (n NamespacedMerkleTree) Get(nID namespace.ID) [][]byte {
_, start, end := n.foundInRange(nID)
Expand Down
33 changes: 11 additions & 22 deletions nmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,15 +509,14 @@ func TestNodeVisitor(t *testing.T) {

func TestNamespacedMerkleTree_ProveErrors(t *testing.T) {
tests := []struct {
name string
nidLen int
index int
pushData []namespaceDataPair
wantErr bool
wantPanic bool
name string
nidLen int
index int
pushData []namespaceDataPair
wantErr bool
}{
{"negative index", 1, -1, generateLeafData(1, 0, 10, []byte("_data")), false, true},
{"too large index", 1, 11, generateLeafData(1, 0, 10, []byte("_data")), true, false},
{"negative index", 1, -1, generateLeafData(1, 0, 10, []byte("_data")), true},
{"too large index", 1, 11, generateLeafData(1, 0, 10, []byte("_data")), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -534,20 +533,10 @@ func TestNamespacedMerkleTree_ProveErrors(t *testing.T) {
t.Fatalf("Prove() failed on valid index: %v, err: %v", i, err)
}
}
if tt.wantPanic {
shouldPanic(t, func() {
_, err := n.Prove(tt.index)
if (err != nil) != tt.wantErr {
t.Errorf("Prove() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
} else {
_, err := n.Prove(tt.index)
if (err != nil) != tt.wantErr {
t.Errorf("Prove() error = %v, wantErr %v", err, tt.wantErr)
return
}
_, err := n.Prove(tt.index)
if (err != nil) != tt.wantErr {
t.Errorf("Prove() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
Expand Down
50 changes: 47 additions & 3 deletions proof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,58 @@ package nmt
import (
"bytes"
"crypto/sha256"
"io"
"testing"

"github.com/celestiaorg/merkletree"
adlerjohn marked this conversation as resolved.
Show resolved Hide resolved

"github.com/celestiaorg/nmt/internal"
"github.com/celestiaorg/nmt/namespace"
)

type treeHasher interface {
merkletree.TreeHasher
Size() int
}

// CachedSubtreeHasher implements SubtreeHasher using a set of precomputed
// leaf hashes.
type cachedSubtreeHasher struct {
leafHashes [][]byte
treeHasher
}

// NextSubtreeRoot implements SubtreeHasher.
func (csh *cachedSubtreeHasher) NextSubtreeRoot(subtreeSize int) ([]byte, error) {
if len(csh.leafHashes) == 0 {
return nil, io.EOF
}
tree := merkletree.NewFromTreehasher(csh.treeHasher)
for i := 0; i < subtreeSize && len(csh.leafHashes) > 0; i++ {
if err := tree.PushSubTree(0, csh.leafHashes[0]); err != nil {
return nil, err
}
csh.leafHashes = csh.leafHashes[1:]
}
return tree.Root(), nil
}

// Skip implements SubtreeHasher.
func (csh *cachedSubtreeHasher) Skip(n int) error {
if n > len(csh.leafHashes) {
return io.ErrUnexpectedEOF
}
csh.leafHashes = csh.leafHashes[n:]
return nil
}

// newCachedSubtreeHasher creates a CachedSubtreeHasher using the specified
// leaf hashes and hash function.
func newCachedSubtreeHasher(leafHashes [][]byte, h treeHasher) *cachedSubtreeHasher {
return &cachedSubtreeHasher{
leafHashes: leafHashes,
treeHasher: h,
}
}

func TestProof_VerifyNamespace_False(t *testing.T) {
const testNidLen = 3

Expand Down Expand Up @@ -98,7 +142,7 @@ func TestProof_VerifyNamespace_False(t *testing.T) {

func rangeProof(t *testing.T, n *NamespacedMerkleTree, start, end int) [][]byte {
n.computeLeafHashesIfNecessary()
subTreeHasher := internal.NewCachedSubtreeHasher(n.leafHashes, n.treeHasher)
subTreeHasher := newCachedSubtreeHasher(n.leafHashes, n.treeHasher)
incompleteRange, err := merkletree.BuildRangeProof(start, end, subTreeHasher)
if err != nil {
t.Fatalf("Could not create range proof: %v", err)
Expand Down