diff --git a/internal/doc.go b/internal/doc.go deleted file mode 100644 index 2cc9678..0000000 --- a/internal/doc.go +++ /dev/null @@ -1,10 +0,0 @@ -/* -Package internal contains slightly modified versions of these structures from the -NebolousLabs merkletree implementation as well as some internally used abstractions. - -The only diff to the NebolousLabs types: They take in a TreeHasher instead of a hash.Hash. - -This is an internal package s.t. there types can't be exposed to the publicly visible API. -see: https://dave.cheney.net/2019/10/06/use-internal-packages-to-reduce-your-public-api-surface -*/ -package internal diff --git a/internal/subtree_hasher.go b/internal/subtree_hasher.go deleted file mode 100644 index 6d080af..0000000 --- a/internal/subtree_hasher.go +++ /dev/null @@ -1,52 +0,0 @@ -package internal - -import ( - "io" - - "github.com/celestiaorg/merkletree" -) - -type TreeHasher interface { - merkletree.TreeHasher - Size() int -} - -// CachedSubtreeHasher implements SubtreeHasher using a set of precomputed -// leaf hashes. -type CachedSubtreeHasher struct { - leafHashes [][]byte - TreeHasher -} - -// NextSubtreeRoot implements SubtreeHasher. -func (csh *CachedSubtreeHasher) NextSubtreeRoot(subtreeSize int) ([]byte, error) { - if len(csh.leafHashes) == 0 { - return nil, io.EOF - } - tree := merkletree.NewFromTreehasher(csh.TreeHasher) - for i := 0; i < subtreeSize && len(csh.leafHashes) > 0; i++ { - if err := tree.PushSubTree(0, csh.leafHashes[0]); err != nil { - return nil, err - } - csh.leafHashes = csh.leafHashes[1:] - } - return tree.Root(), nil -} - -// Skip implements SubtreeHasher. -func (csh *CachedSubtreeHasher) Skip(n int) error { - if n > len(csh.leafHashes) { - return io.ErrUnexpectedEOF - } - csh.leafHashes = csh.leafHashes[n:] - return nil -} - -// NewCachedSubtreeHasher creates a CachedSubtreeHasher using the specified -// leaf hashes and hash function. -func NewCachedSubtreeHasher(leafHashes [][]byte, h TreeHasher) *CachedSubtreeHasher { - return &CachedSubtreeHasher{ - leafHashes: leafHashes, - TreeHasher: h, - } -} diff --git a/nmt.go b/nmt.go index e304a65..a96b648 100644 --- a/nmt.go +++ b/nmt.go @@ -8,12 +8,11 @@ import ( "hash" "math/bits" - "github.com/celestiaorg/merkletree" - "github.com/celestiaorg/nmt/internal" "github.com/celestiaorg/nmt/namespace" ) var ( + ErrInvalidRange = errors.New("invalid proof range") ErrMismatchedNamespaceSize = errors.New("mismatching namespace sizes") ErrInvalidPushOrder = errors.New("pushed data has to be lexicographically ordered by namespace IDs") noOp = func(hash []byte, children ...[]byte) {} @@ -127,12 +126,11 @@ func (n NamespacedMerkleTree) Prove(index int) (Proof, error) { func (n NamespacedMerkleTree) ProveRange(start, end int) (Proof, error) { isMaxNsIgnored := n.treeHasher.IsMaxNamespaceIDIgnored() n.computeLeafHashesIfNecessary() - subTreeHasher := internal.NewCachedSubtreeHasher(n.leafHashes, n.treeHasher) // TODO: store nodes and re-use the hashes instead recomputing parts of the tree here - proof, err := merkletree.BuildRangeProof(start, end, subTreeHasher) - if err != nil { - return NewEmptyRangeProof(isMaxNsIgnored), err + if start < 0 || start >= end || end > len(n.leafHashes) { + return NewEmptyRangeProof(isMaxNsIgnored), ErrInvalidRange } + proof := n.buildRangeProof(start, end) return NewInclusionProof(start, end, proof, isMaxNsIgnored), nil } @@ -171,20 +169,7 @@ func (n NamespacedMerkleTree) ProveNamespace(nID namespace.ID) (Proof, error) { // the range it would be in (to generate a proof of absence and to return // the corresponding leaf hashes). n.computeLeafHashesIfNecessary() - subTreeHasher := internal.NewCachedSubtreeHasher(n.leafHashes, n.treeHasher) - var err error - proof, err := merkletree.BuildRangeProof(proofStart, proofEnd, subTreeHasher) - if err != nil { - // This should never happen. - // TODO would be good to back this by more tests and fuzzing. - return Proof{}, fmt.Errorf( - "unexpected err: %w on nID: %v, range: [%v, %v)", - err, - nID, - proofStart, - proofEnd, - ) - } + proof := n.buildRangeProof(proofStart, proofEnd) if found { return NewInclusionProof(proofStart, proofEnd, proof, isMaxNsIgnored), nil @@ -192,6 +177,58 @@ func (n NamespacedMerkleTree) ProveNamespace(nID namespace.ID) (Proof, error) { return NewAbsenceProof(proofStart, proofEnd, proof, n.leafHashes[proofStart], isMaxNsIgnored), nil } +func (n NamespacedMerkleTree) buildRangeProof(proofStart, proofEnd int) [][]byte { + proof := [][]byte{} + var recurse func(start, end int, includeNode bool) []byte + recurse = func(start, end int, includeNode bool) []byte { + if start >= len(n.leafHashes) { + return nil + } + + // reached a leaf + if end-start == 1 { + leafHash := n.leafHashes[start] + // if current range does not overlap with proof range, add a node to proofs + if (start < proofStart || start >= proofEnd) && includeNode { + proof = append(proof, leafHash) + } + return leafHash + } + + // recursively get left and right subtree + newIncludeNode := includeNode + if (end <= proofStart || start >= proofEnd) && includeNode { + newIncludeNode = false + } + + k := getSplitPoint(end - start) + left := recurse(start, start+k, newIncludeNode) + right := recurse(start+k, end, newIncludeNode) + + // only right leaf/subtree can be non-existent + var hash []byte + if right == nil { + hash = left + } else { + hash = n.treeHasher.HashNode(left, right) + } + + // highest node in subtree that lies outside proof range + if includeNode && !newIncludeNode { + proof = append(proof, hash) + } + + return hash + } + + fullTreeSize := getSplitPoint(len(n.leafHashes)) * 2 + if fullTreeSize < 1 { + fullTreeSize = 1 + } + recurse(0, fullTreeSize, true) + return proof +} + // Get returns leaves for the given namespace.ID. func (n NamespacedMerkleTree) Get(nID namespace.ID) [][]byte { _, start, end := n.foundInRange(nID) diff --git a/nmt_test.go b/nmt_test.go index 7c32e1c..ba8a6dd 100644 --- a/nmt_test.go +++ b/nmt_test.go @@ -509,15 +509,14 @@ func TestNodeVisitor(t *testing.T) { func TestNamespacedMerkleTree_ProveErrors(t *testing.T) { tests := []struct { - name string - nidLen int - index int - pushData []namespaceDataPair - wantErr bool - wantPanic bool + name string + nidLen int + index int + pushData []namespaceDataPair + wantErr bool }{ - {"negative index", 1, -1, generateLeafData(1, 0, 10, []byte("_data")), false, true}, - {"too large index", 1, 11, generateLeafData(1, 0, 10, []byte("_data")), true, false}, + {"negative index", 1, -1, generateLeafData(1, 0, 10, []byte("_data")), true}, + {"too large index", 1, 11, generateLeafData(1, 0, 10, []byte("_data")), true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -534,20 +533,10 @@ func TestNamespacedMerkleTree_ProveErrors(t *testing.T) { t.Fatalf("Prove() failed on valid index: %v, err: %v", i, err) } } - if tt.wantPanic { - shouldPanic(t, func() { - _, err := n.Prove(tt.index) - if (err != nil) != tt.wantErr { - t.Errorf("Prove() error = %v, wantErr %v", err, tt.wantErr) - return - } - }) - } else { - _, err := n.Prove(tt.index) - if (err != nil) != tt.wantErr { - t.Errorf("Prove() error = %v, wantErr %v", err, tt.wantErr) - return - } + _, err := n.Prove(tt.index) + if (err != nil) != tt.wantErr { + t.Errorf("Prove() error = %v, wantErr %v", err, tt.wantErr) + return } }) } diff --git a/proof_test.go b/proof_test.go index dbdf3be..023bcd1 100644 --- a/proof_test.go +++ b/proof_test.go @@ -3,14 +3,58 @@ package nmt import ( "bytes" "crypto/sha256" + "io" "testing" "github.com/celestiaorg/merkletree" - - "github.com/celestiaorg/nmt/internal" "github.com/celestiaorg/nmt/namespace" ) +type treeHasher interface { + merkletree.TreeHasher + Size() int +} + +// CachedSubtreeHasher implements SubtreeHasher using a set of precomputed +// leaf hashes. +type cachedSubtreeHasher struct { + leafHashes [][]byte + treeHasher +} + +// NextSubtreeRoot implements SubtreeHasher. +func (csh *cachedSubtreeHasher) NextSubtreeRoot(subtreeSize int) ([]byte, error) { + if len(csh.leafHashes) == 0 { + return nil, io.EOF + } + tree := merkletree.NewFromTreehasher(csh.treeHasher) + for i := 0; i < subtreeSize && len(csh.leafHashes) > 0; i++ { + if err := tree.PushSubTree(0, csh.leafHashes[0]); err != nil { + return nil, err + } + csh.leafHashes = csh.leafHashes[1:] + } + return tree.Root(), nil +} + +// Skip implements SubtreeHasher. +func (csh *cachedSubtreeHasher) Skip(n int) error { + if n > len(csh.leafHashes) { + return io.ErrUnexpectedEOF + } + csh.leafHashes = csh.leafHashes[n:] + return nil +} + +// newCachedSubtreeHasher creates a CachedSubtreeHasher using the specified +// leaf hashes and hash function. +func newCachedSubtreeHasher(leafHashes [][]byte, h treeHasher) *cachedSubtreeHasher { + return &cachedSubtreeHasher{ + leafHashes: leafHashes, + treeHasher: h, + } +} + func TestProof_VerifyNamespace_False(t *testing.T) { const testNidLen = 3 @@ -98,7 +142,7 @@ func TestProof_VerifyNamespace_False(t *testing.T) { func rangeProof(t *testing.T, n *NamespacedMerkleTree, start, end int) [][]byte { n.computeLeafHashesIfNecessary() - subTreeHasher := internal.NewCachedSubtreeHasher(n.leafHashes, n.treeHasher) + subTreeHasher := newCachedSubtreeHasher(n.leafHashes, n.treeHasher) incompleteRange, err := merkletree.BuildRangeProof(start, end, subTreeHasher) if err != nil { t.Fatalf("Could not create range proof: %v", err)