diff --git a/CHANGELOG.md b/CHANGELOG.md index 0cd91992d..c0fad7bf6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ - [#640](https://github.com/cosmos/iavl/pull/640) commit `NodeDB` batch in `LoadVersionForOverwriting`. - [#636](https://github.com/cosmos/iavl/pull/636) Speed up rollback method: `LoadVersionForOverwriting`. +- [#654](https://github.com/cosmos/iavl/pull/654) Add API `TraverseStateChanges` to extract state changes from iavl versions. ## 0.19.4 (October 28, 2022) diff --git a/diff.go b/diff.go new file mode 100644 index 000000000..b16cd32fa --- /dev/null +++ b/diff.go @@ -0,0 +1,76 @@ +package iavl + +import ( + "bytes" + "sort" + + ibytes "github.com/cosmos/iavl/internal/bytes" +) + +// ChangeSet represents the state changes extracted from diffing iavl versions. +type ChangeSet struct { + Pairs []KVPair +} + +type KVPair struct { + Delete bool + Key []byte + Value []byte +} + +// extractStateChanges extracts the state changes by between two versions of the tree. +// it first traverse the `root` tree to find out the `newKeys` and `sharedNodes`, +// `newKeys` are the keys of the newly added leaf nodes, which represents the inserts and updates, +// `sharedNodes` are the referenced nodes that are created in previous versions, +// then we traverse the `prevRoot` tree to find out the deletion entries, we can skip the subtrees +// marked by the `sharedNodes`. +func (ndb *nodeDB) extractStateChanges(prevVersion int64, prevRoot []byte, root []byte) (*ChangeSet, error) { + curIter, err := NewNodeIterator(root, ndb) + if err != nil { + return nil, err + } + + prevIter, err := NewNodeIterator(prevRoot, ndb) + if err != nil { + return nil, err + } + + var changeSet []KVPair + sharedNodes := make(map[string]struct{}) + newKeys := make(map[string]struct{}) + for curIter.Valid() { + node := curIter.GetNode() + shared := node.version <= prevVersion + if shared { + sharedNodes[ibytes.UnsafeBytesToStr(node.hash)] = struct{}{} + } else if node.isLeaf() { + changeSet = append(changeSet, KVPair{Key: node.key, Value: node.value}) + newKeys[ibytes.UnsafeBytesToStr(node.key)] = struct{}{} + } + // skip subtree of shared nodes + curIter.Next(shared) + } + if err := curIter.Error(); err != nil { + return nil, err + } + + for prevIter.Valid() { + node := prevIter.GetNode() + _, shared := sharedNodes[ibytes.UnsafeBytesToStr(node.hash)] + if !shared && node.isLeaf() { + _, updated := newKeys[ibytes.UnsafeBytesToStr(node.key)] + if !updated { + changeSet = append(changeSet, KVPair{Delete: true, Key: node.key}) + } + } + prevIter.Next(shared) + } + if err := prevIter.Error(); err != nil { + return nil, err + } + + sort.Slice(changeSet, func(i, j int) bool { + return bytes.Compare(changeSet[i].Key, changeSet[j].Key) == -1 + }) + return &ChangeSet{Pairs: changeSet}, nil +} diff --git a/diff_test.go b/diff_test.go new file mode 100644 index 000000000..c4638d3dc --- /dev/null +++ b/diff_test.go @@ -0,0 +1,111 @@ +package iavl + +import ( + "encoding/binary" + "fmt" + "math" + "math/rand" + "sort" + "testing" + + db "github.com/tendermint/tm-db" + "github.com/stretchr/testify/require" +) + +// TestDiffRoundTrip generate random change sets, build an iavl tree versions, +// then extract state changes from the versions and compare with the original change sets. +func TestDiffRoundTrip(t *testing.T) { + changeSets := genChangeSets(rand.New(rand.NewSource(0)), 300) + + // apply changeSets to tree + db := db.NewMemDB() + tree, err := NewMutableTree(db, 0, true) + require.NoError(t, err) + for _, cs := range changeSets { + for _, pair := range cs.Pairs { + if pair.Delete { + _, removed, err := tree.Remove(pair.Key) + require.True(t, removed) + require.NoError(t, err) + } else { + _, err := tree.Set(pair.Key, pair.Value) + require.NoError(t, err) + } + } + _, _, err := tree.SaveVersion() + require.NoError(t, err) + } + + // extract change sets from db + var extractChangeSets []ChangeSet + tree2 := NewImmutableTree(db, 0, true) + err = tree2.TraverseStateChanges(0, math.MaxInt64, func(version int64, changeSet *ChangeSet) error { + extractChangeSets = append(extractChangeSets, *changeSet) + return nil + }) + require.NoError(t, err) + require.Equal(t, changeSets, extractChangeSets) +} + +func genChangeSets(r *rand.Rand, n int) []ChangeSet { + var changeSets []ChangeSet + + for i := 0; i < n; i++ { + items := make(map[string]KVPair) + start, count, step := r.Int63n(1000), r.Int63n(1000), r.Int63n(10) + for i := start; i < start+count*step; i += step { + value := make([]byte, 8) + binary.LittleEndian.PutUint64(value, uint64(i)) + + key := fmt.Sprintf("test-%d", i) + items[key] = KVPair{ + Key: []byte(key), + Value: value, + } + } + if len(changeSets) > 0 { + // pick some random keys to delete from the last version + lastChangeSet := changeSets[len(changeSets)-1] + count = r.Int63n(10) + for _, pair := range lastChangeSet.Pairs { + if count <= 0 { + break + } + if pair.Delete { + continue + } + items[string(pair.Key)] = KVPair{ + Key: pair.Key, + Delete: true, + } + count-- + } + + // Special case, set to identical value + if len(lastChangeSet.Pairs) > 0 { + i := r.Int63n(int64(len(lastChangeSet.Pairs))) + pair := lastChangeSet.Pairs[i] + if !pair.Delete { + items[string(pair.Key)] = KVPair{ + Key: pair.Key, + Value: pair.Value, + } + } + } + } + + var keys []string + for key := range items { + keys = append(keys, key) + } + sort.Strings(keys) + + var cs ChangeSet + for _, key := range keys { + cs.Pairs = append(cs.Pairs, items[key]) + } + + changeSets = append(changeSets, cs) + } + return changeSets +} diff --git a/immutable_tree.go b/immutable_tree.go index bc857e4ad..118d90b46 100644 --- a/immutable_tree.go +++ b/immutable_tree.go @@ -331,3 +331,9 @@ func (t *ImmutableTree) nodeSize() int { }) return size } + +// TraverseStateChanges iterate the range of versions, compare each version to it's predecessor to extract the state changes of it. +// endVersion is exclusive. +func (t *ImmutableTree) TraverseStateChanges(startVersion, endVersion int64, fn func(version int64, changeSet *ChangeSet) error) error { + return t.ndb.traverseStateChanges(startVersion, endVersion, fn) +} diff --git a/iterator.go b/iterator.go index a509ed1ad..384aeedfe 100644 --- a/iterator.go +++ b/iterator.go @@ -259,3 +259,79 @@ func (iter *Iterator) Error() error { func (iter *Iterator) IsFast() bool { return false } + +// NodeIterator is an iterator for nodeDB to traverse a tree in depth-first, preorder manner. +type NodeIterator struct { + nodesToVisit []*Node + ndb *nodeDB + err error +} + +// NewNodeIterator returns a new NodeIterator to traverse the tree of the root node. +func NewNodeIterator(root []byte, ndb *nodeDB) (*NodeIterator, error) { + if len(root) == 0 { + return &NodeIterator{ + nodesToVisit: []*Node{}, + ndb: ndb, + }, nil + } + + node, err := ndb.GetNode(root) + if err != nil { + return nil, err + } + + return &NodeIterator{ + nodesToVisit: []*Node{node}, + ndb: ndb, + }, nil +} + +// GetNode returns the current visiting node. +func (iter *NodeIterator) GetNode() *Node { + return iter.nodesToVisit[len(iter.nodesToVisit)-1] +} + +// Valid checks if the validator is valid. +func (iter *NodeIterator) Valid() bool { + if iter.err != nil { + return false + } + return len(iter.nodesToVisit) > 0 +} + +// Error returns an error if any errors. +func (iter *NodeIterator) Error() error { + return iter.err +} + +// Next moves forward the traversal. +// if isSkipped is true, the subtree under the current node is skipped. +func (iter *NodeIterator) Next(isSkipped bool) { + if !iter.Valid() { + return + } + node := iter.GetNode() + iter.nodesToVisit = iter.nodesToVisit[:len(iter.nodesToVisit)-1] + + if isSkipped { + return + } + + if node.isLeaf() { + return + } + + leftNode, err := iter.ndb.GetNode(node.leftHash) + if err != nil { + iter.err = err + return + } + iter.nodesToVisit = append(iter.nodesToVisit, leftNode) + rightNode, err := iter.ndb.GetNode(node.rightHash) + if err != nil { + iter.err = err + return + } + iter.nodesToVisit = append(iter.nodesToVisit, rightNode) +} diff --git a/nodedb.go b/nodedb.go index 646f953b3..61e7e55c0 100644 --- a/nodedb.go +++ b/nodedb.go @@ -1051,6 +1051,33 @@ func (ndb *nodeDB) traverseNodes(fn func(hash []byte, node *Node) error) error { return nil } +// traverseStateChanges iterate the range of versions, compare each version to it's predecessor to extract the state changes of it. +// endVersion is exclusive, set to `math.MaxInt64` to cover the latest version. +func (ndb *nodeDB) traverseStateChanges(startVersion, endVersion int64, fn func(version int64, changeSet *ChangeSet) error) error { + predecessor, err := ndb.getPreviousVersion(startVersion) + if err != nil { + return err + } + prevRoot, err := ndb.getRoot(predecessor) + if err != nil { + return err + } + return ndb.traverseRange(rootKeyFormat.Key(startVersion), rootKeyFormat.Key(endVersion), func(k, hash []byte) error { + var version int64 + rootKeyFormat.Scan(k, &version) + changeSet, err := ndb.extractStateChanges(predecessor, prevRoot, hash) + if err != nil { + return err + } + if err := fn(version, changeSet); err != nil { + return err + } + predecessor = version + prevRoot = hash + return nil + }) +} + func (ndb *nodeDB) String() (string, error) { buf := bufPool.Get().(*bytes.Buffer) defer bufPool.Put(buf)