Skip to content

Commit

Permalink
Optimize diff algorithm with insights from cosmos#646
Browse files Browse the repository at this point in the history
fix test
  • Loading branch information
yihuang committed Jan 11, 2023
1 parent 5d6c399 commit fee0d57
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 33 deletions.
115 changes: 88 additions & 27 deletions diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ package iavl

import (
"bytes"
"sort"

ibytes "github.com/cosmos/iavl/internal/bytes"
)

// ChangeSet represents the state changes extracted from diffing iavl versions.
Expand Down Expand Up @@ -36,41 +33,105 @@ func (ndb *nodeDB) extractStateChanges(prevVersion int64, prevRoot []byte, root
}

var changeSet []KVPair
sharedNodes := make(map[string]struct{})
newKeys := make(map[string]struct{})
for curIter.Valid() {
node := curIter.GetNode()
shared := node.version <= prevVersion
if shared {
sharedNodes[ibytes.UnsafeBytesToStr(node.hash)] = struct{}{}
} else if node.isLeaf() {
changeSet = append(changeSet, KVPair{Key: node.key, Value: node.value})
newKeys[ibytes.UnsafeBytesToStr(node.key)] = struct{}{}

var (
// current shared node between two versions
sharedNode *Node
// record all newly added leaf nodes in newer version, it represents all updates and insertions.
newLeafNodes []*Node
// orphaned leaf nodes in previous version, which represents all deletions and updates.
// both `newLeafNodes` and `orphanedLeafNodes` are ordered by key.
orphanedLeafNodes []*Node
)

advanceSharedNode := func() {
// Forward `curIter` until the next `sharedNode`.
// `sharedNode` will be `nil` if the new version is exhausted.
sharedNode = nil
for curIter.Valid() {
node := curIter.GetNode()
shared := node.version <= prevVersion
curIter.Next(shared)
if shared {
sharedNode = node
break
} else if node.isLeaf() {
newLeafNodes = append(newLeafNodes, node)
}
}
// skip subtree of shared nodes
curIter.Next(shared)
}
if err := curIter.Error(); err != nil {
return nil, err
}
advanceSharedNode()

// Traverse `prevIter` to find orphaned nodes in the previous version.
for prevIter.Valid() {
node := prevIter.GetNode()
_, shared := sharedNodes[ibytes.UnsafeBytesToStr(node.hash)]
if !shared && node.isLeaf() {
_, updated := newKeys[ibytes.UnsafeBytesToStr(node.key)]
if !updated {
changeSet = append(changeSet, KVPair{Delete: true, Key: node.key})
}
}
shared := sharedNode != nil && (node == sharedNode || bytes.Equal(node.hash, sharedNode.hash))
// skip sub-tree of shared nodes
prevIter.Next(shared)
if shared {
advanceSharedNode()
} else if node.isLeaf() {
orphanedLeafNodes = append(orphanedLeafNodes, node)
}
}

if err := curIter.Error(); err != nil {
return nil, err
}
if err := prevIter.Error(); err != nil {
return nil, err
}

sort.Slice(changeSet, func(i, j int) bool {
return bytes.Compare(changeSet[i].Key, changeSet[j].Key) == -1
findDeletedNodes(orphanedLeafNodes, newLeafNodes, func(node *Node, deleted bool) {
pair := KVPair{Key: node.key, Delete: deleted}
if !deleted {
pair.Value = node.value
}
changeSet = append(changeSet, pair)
})

return &ChangeSet{Pairs: changeSet}, nil
}

// findDeletedNodes merges two key-ordered node lists and reports each distinct
// key exactly once through `cb`:
//   - key present only in `nodes1`: cb(node from nodes1, true)  -- a deletion;
//   - key present only in `nodes2`: cb(node from nodes2, false) -- an insertion;
//   - key present in both:          cb(node from nodes2, false) -- an update.
//
// Callbacks are issued in ascending key order.
// Invariant: both `nodes1` and `nodes2` are ordered by key.
func findDeletedNodes(nodes1 []*Node, nodes2 []*Node, cb func(node *Node, deleted bool)) {
	oldIdx, newIdx := 0, 0

	// Walk both lists in lockstep while each still has entries to compare.
	for oldIdx < len(nodes1) && newIdx < len(nodes2) {
		oldNode, newNode := nodes1[oldIdx], nodes2[newIdx]
		order := bytes.Compare(oldNode.key, newNode.key)

		if order < 0 {
			// The old key sorts first, so it has no counterpart in `nodes2`.
			cb(oldNode, true)
			oldIdx++
			continue
		}

		// The new key sorts first (insertion) or matches the old one (update);
		// either way the node from `nodes2` is the one reported.
		cb(newNode, false)
		newIdx++
		if order == 0 {
			// Same key on both sides: the old entry is consumed as well.
			oldIdx++
		}
	}

	// At most one of the two tails below is non-empty.
	for _, leftover := range nodes1[oldIdx:] {
		cb(leftover, true)
	}
	for _, leftover := range nodes2[newIdx:] {
		cb(leftover, false)
	}
}
8 changes: 2 additions & 6 deletions iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,7 @@ func (iter *NodeIterator) GetNode() *Node {

// Valid reports whether the iterator is still usable: it is true only when no
// error has been recorded on the iterator and at least one node remains in
// `nodesToVisit`.
func (iter *NodeIterator) Valid() bool {
	return iter.err == nil && len(iter.nodesToVisit) > 0
}

// Error returns the iterator's error, if any.
Expand Down Expand Up @@ -327,11 +324,10 @@ func (iter *NodeIterator) Next(isSkipped bool) {
iter.err = err
return
}
iter.nodesToVisit = append(iter.nodesToVisit, leftNode)
rightNode, err := iter.ndb.GetNode(node.rightHash)
if err != nil {
iter.err = err
return
}
iter.nodesToVisit = append(iter.nodesToVisit, rightNode)
iter.nodesToVisit = append(iter.nodesToVisit, rightNode, leftNode)
}

0 comments on commit fee0d57

Please sign in to comment.