Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add API TraverseStateChanges to extract state changes from iavl versions #654

Merged
merged 9 commits into from
Jan 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- [#586](https://github.com/cosmos/iavl/pull/586) Remove the `RangeProof` and refactor the ics23_proof to use the internal methods.
- [#640](https://github.com/cosmos/iavl/pull/640) commit `NodeDB` batch in `LoadVersionForOverwriting`.
- [#636](https://github.com/cosmos/iavl/pull/636) Speed up rollback method: `LoadVersionForOverwriting`.
- [#654](https://github.com/cosmos/iavl/pull/654) Add API `TraverseStateChanges` to extract state changes from iavl versions.

## 0.19.4 (October 28, 2022)

Expand Down
76 changes: 76 additions & 0 deletions diff.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package iavl

import (
"bytes"
"sort"

ibytes "github.com/cosmos/iavl/internal/bytes"
)

// ChangeSet represents the state changes extracted from diffing iavl versions.
// When produced by extractStateChanges, Pairs is sorted by key.
type ChangeSet struct {
// Pairs lists the individual key changes: sets (inserts and updates) and deletions.
Pairs []KVPair
}

// KVPair is a single entry of a ChangeSet: either a set (insert or update) of
// Key to Value, or a deletion of Key.
type KVPair struct {
// Delete is true when the entry records the removal of Key; deletion entries
// are built without a Value.
Delete bool
// Key is the key that was set or deleted.
Key []byte
// Value is the value stored under Key for inserts and updates.
Value []byte
}

// extractStateChanges computes the state changes between two versions of the tree.
//
// It works in two passes:
//  1. Walk the tree under `root`. A node whose version is not newer than
//     `prevVersion` was created by an earlier version; it is recorded in
//     `sharedNodes` and its subtree is not descended into. Every other leaf is
//     a newly written key (an insert or an update) and is remembered in `newKeys`.
//  2. Walk the tree under `prevRoot`, skipping the subtrees rooted at
//     `sharedNodes`. A leaf reached this way whose key is not in `newKeys` is
//     reported as a deletion.
//
// The returned pairs are sorted by key.
func (ndb *nodeDB) extractStateChanges(prevVersion int64, prevRoot []byte, root []byte) (*ChangeSet, error) {
	newTreeIter, err := NewNodeIterator(root, ndb)
	if err != nil {
		return nil, err
	}

	oldTreeIter, err := NewNodeIterator(prevRoot, ndb)
	if err != nil {
		return nil, err
	}

	var (
		pairs       []KVPair
		sharedNodes = make(map[string]struct{})
		newKeys     = make(map[string]struct{})
	)

	// Pass 1: collect inserts/updates and the nodes reused from older versions.
	for newTreeIter.Valid() {
		n := newTreeIter.GetNode()
		reused := n.version <= prevVersion
		switch {
		case reused:
			sharedNodes[ibytes.UnsafeBytesToStr(n.hash)] = struct{}{}
		case n.isLeaf():
			pairs = append(pairs, KVPair{Key: n.key, Value: n.value})
			newKeys[ibytes.UnsafeBytesToStr(n.key)] = struct{}{}
		}
		// The subtree of a reused node is skipped.
		newTreeIter.Next(reused)
	}
	if err := newTreeIter.Error(); err != nil {
		return nil, err
	}

	// Pass 2: leaves only reachable in the previous tree that were not
	// rewritten in the new one are deletions.
	for oldTreeIter.Valid() {
		n := oldTreeIter.GetNode()
		_, reused := sharedNodes[ibytes.UnsafeBytesToStr(n.hash)]
		if !reused && n.isLeaf() {
			if _, rewritten := newKeys[ibytes.UnsafeBytesToStr(n.key)]; !rewritten {
				pairs = append(pairs, KVPair{Delete: true, Key: n.key})
			}
		}
		oldTreeIter.Next(reused)
	}
	if err := oldTreeIter.Error(); err != nil {
		return nil, err
	}

	sort.Slice(pairs, func(a, b int) bool {
		return bytes.Compare(pairs[a].Key, pairs[b].Key) < 0
	})
	return &ChangeSet{Pairs: pairs}, nil
}
111 changes: 111 additions & 0 deletions diff_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package iavl

import (
"encoding/binary"
"fmt"
"math"
"math/rand"
"sort"
"testing"

db "github.com/cosmos/cosmos-db"
"github.com/stretchr/testify/require"
)

// TestDiffRoundTrip generates random change sets, applies them to an iavl tree
// saving one version per change set, then extracts the state changes back from
// the stored versions and checks that they equal the original change sets.
func TestDiffRoundTrip(t *testing.T) {
	changeSets := genChangeSets(rand.New(rand.NewSource(0)), 300)

	// Apply every change set to the tree, one saved version per change set.
	// The local is deliberately not named `db` so it does not shadow the
	// imported `db` package.
	memDB := db.NewMemDB()
	tree, err := NewMutableTree(memDB, 0, true)
	require.NoError(t, err)
	for _, cs := range changeSets {
		for _, pair := range cs.Pairs {
			if pair.Delete {
				_, removed, err := tree.Remove(pair.Key)
				// Check the error first: `removed` is only meaningful when
				// Remove succeeded.
				require.NoError(t, err)
				require.True(t, removed)
				continue
			}
			_, err := tree.Set(pair.Key, pair.Value)
			require.NoError(t, err)
		}
		_, _, err := tree.SaveVersion()
		require.NoError(t, err)
	}

	// Extract the change sets back from the database.
	var extractChangeSets []ChangeSet
	tree2 := NewImmutableTree(memDB, 0, true)
	err = tree2.TraverseStateChanges(0, math.MaxInt64, func(version int64, changeSet *ChangeSet) error {
		extractChangeSets = append(extractChangeSets, *changeSet)
		return nil
	})
	require.NoError(t, err)
	require.Equal(t, changeSets, extractChangeSets)
}

// genChangeSets builds n pseudo-random change sets drawn from r.
//
// Each change set sets a random arithmetic progression of "test-<i>" keys to
// the 8-byte little-endian encoding of i. From the second change set on it
// also deletes up to nine keys written by the previous change set and, as a
// special case, may write one previous key again with its identical value.
// The pairs of every change set are ordered by key.
func genChangeSets(r *rand.Rand, n int) []ChangeSet {
	var result []ChangeSet

	for round := 0; round < n; round++ {
		pending := make(map[string]KVPair)

		start, count, step := r.Int63n(1000), r.Int63n(1000), r.Int63n(10)
		// A zero step makes the range empty, so the loop body never runs.
		for num, end := start, start+count*step; num < end; num += step {
			val := make([]byte, 8)
			binary.LittleEndian.PutUint64(val, uint64(num))

			name := fmt.Sprintf("test-%d", num)
			pending[name] = KVPair{Key: []byte(name), Value: val}
		}

		if len(result) > 0 {
			prev := result[len(result)-1].Pairs

			// Pick some keys written by the previous change set and delete them.
			remaining := r.Int63n(10)
			for _, p := range prev {
				if remaining <= 0 {
					break
				}
				if p.Delete {
					continue
				}
				pending[string(p.Key)] = KVPair{Key: p.Key, Delete: true}
				remaining--
			}

			// Special case: set a previous key to its identical value.
			if len(prev) > 0 {
				if p := prev[r.Int63n(int64(len(prev)))]; !p.Delete {
					pending[string(p.Key)] = KVPair{Key: p.Key, Value: p.Value}
				}
			}
		}

		// Map iteration order is random; sort the keys for a stable result.
		var names []string
		for name := range pending {
			names = append(names, name)
		}
		sort.Strings(names)

		var cs ChangeSet
		for _, name := range names {
			cs.Pairs = append(cs.Pairs, pending[name])
		}
		result = append(result, cs)
	}
	return result
}
6 changes: 6 additions & 0 deletions immutable_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,3 +331,9 @@ func (t *ImmutableTree) nodeSize() int {
})
return size
}

// TraverseStateChanges iterates over the range of versions, compares each version to its predecessor
// to extract its state changes, and passes them to fn together with the version number.
// endVersion is exclusive.
func (t *ImmutableTree) TraverseStateChanges(startVersion, endVersion int64, fn func(version int64, changeSet *ChangeSet) error) error {
return t.ndb.traverseStateChanges(startVersion, endVersion, fn)
}
76 changes: 76 additions & 0 deletions iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,79 @@ func (iter *Iterator) Error() error {
func (iter *Iterator) IsFast() bool {
// This iterator type unconditionally reports false.
return false
}

// NodeIterator is an iterator for nodeDB to traverse a tree in depth-first, preorder manner.
type NodeIterator struct {
// nodesToVisit is used as a stack: its last element is the current node.
nodesToVisit []*Node
// ndb is where the child nodes are loaded from during the traversal.
ndb *nodeDB
// err holds the error hit while loading a child node; once set, Valid reports false.
err error
}

// NewNodeIterator returns a new NodeIterator that traverses the tree rooted at
// the node identified by root. An empty root yields an iterator with nothing
// to visit. If the root node cannot be loaded from ndb, the error is returned
// and no iterator is created.
func NewNodeIterator(root []byte, ndb *nodeDB) (*NodeIterator, error) {
	iter := &NodeIterator{
		nodesToVisit: []*Node{},
		ndb:          ndb,
	}
	if len(root) == 0 {
		return iter, nil
	}

	rootNode, err := ndb.GetNode(root)
	if err != nil {
		return nil, err
	}
	iter.nodesToVisit = append(iter.nodesToVisit, rootNode)
	return iter, nil
}

// GetNode returns the node the iterator is currently positioned on, which is
// the last element of the pending-node stack. It indexes the stack without a
// bounds guard, so it must only be called while the stack is non-empty.
func (iter *NodeIterator) GetNode() *Node {
	top := len(iter.nodesToVisit) - 1
	return iter.nodesToVisit[top]
}

// Valid reports whether the iterator is valid: no error has been recorded and
// there is still at least one node left to visit.
func (iter *NodeIterator) Valid() bool {
	return iter.err == nil && len(iter.nodesToVisit) > 0
}

// Error returns the error recorded by the iterator while loading nodes during
// the traversal, or nil if none was recorded.
func (iter *NodeIterator) Error() error {
return iter.err
}

// Next moves the traversal forward by popping the current node and, unless
// told otherwise, scheduling its children.
// If isSkipped is true, the subtree under the current node is skipped.
// The left child is pushed before the right one, so the right subtree is
// visited first. A failure to load a child is recorded in the iterator, which
// then stops being valid. Calling Next on an invalid iterator is a no-op.
func (iter *NodeIterator) Next(isSkipped bool) {
	if !iter.Valid() {
		return
	}

	// Pop the current node off the stack.
	last := len(iter.nodesToVisit) - 1
	current := iter.nodesToVisit[last]
	iter.nodesToVisit = iter.nodesToVisit[:last]

	if isSkipped || current.isLeaf() {
		return
	}

	for _, childHash := range [][]byte{current.leftHash, current.rightHash} {
		child, err := iter.ndb.GetNode(childHash)
		if err != nil {
			iter.err = err
			return
		}
		iter.nodesToVisit = append(iter.nodesToVisit, child)
	}
}
27 changes: 27 additions & 0 deletions nodedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,33 @@ func (ndb *nodeDB) traverseNodes(fn func(hash []byte, node *Node) error) error {
return nil
}

// traverseStateChanges iterate the range of versions, compare each version to it's predecessor to extract the state changes of it.
// endVersion is exclusive, set to `math.MaxInt64` to cover the latest version.
func (ndb *nodeDB) traverseStateChanges(startVersion, endVersion int64, fn func(version int64, changeSet *ChangeSet) error) error {
tac0turtle marked this conversation as resolved.
Show resolved Hide resolved
predecessor, err := ndb.getPreviousVersion(startVersion)
if err != nil {
return err
}
prevRoot, err := ndb.getRoot(predecessor)
if err != nil {
return err
}
return ndb.traverseRange(rootKeyFormat.Key(startVersion), rootKeyFormat.Key(endVersion), func(k, hash []byte) error {
var version int64
rootKeyFormat.Scan(k, &version)
changeSet, err := ndb.extractStateChanges(predecessor, prevRoot, hash)
if err != nil {
return err
}
if err := fn(version, changeSet); err != nil {
return err
}
predecessor = version
prevRoot = hash
return nil
})
}

func (ndb *nodeDB) String() (string, error) {
buf := bufPool.Get().(*bytes.Buffer)
defer bufPool.Put(buf)
Expand Down