From cbee1b3ea1feb1b785b286e71f617603ef26d540 Mon Sep 17 00:00:00 2001 From: yihuang Date: Fri, 16 Dec 2022 10:52:44 +0800 Subject: [PATCH] perf: optimize iteration on nested cache context (#13881) Co-authored-by: Aleksandr Bezobchuk Co-authored-by: Marko Closes https://github.com/cosmos/cosmos-sdk/issues/10310 --- CHANGELOG.md | 1 + go.mod | 1 + go.sum | 2 + simapp/go.mod | 1 + simapp/go.sum | 2 + store/cachekv/benchmark_test.go | 161 ++++++++++++++ store/cachekv/internal/btree.go | 80 +++++++ store/cachekv/internal/btree_test.go | 202 ++++++++++++++++++ store/cachekv/internal/memiterator.go | 137 ++++++++++++ store/cachekv/{ => internal}/mergeiterator.go | 53 +++-- store/cachekv/memiterator.go | 57 ----- store/cachekv/search_benchmark_test.go | 4 +- store/cachekv/store.go | 22 +- store/cachekv/store_test.go | 38 +++- tests/go.mod | 1 + tests/go.sum | 2 + 16 files changed, 663 insertions(+), 101 deletions(-) create mode 100644 store/cachekv/benchmark_test.go create mode 100644 store/cachekv/internal/btree.go create mode 100644 store/cachekv/internal/btree_test.go create mode 100644 store/cachekv/internal/memiterator.go rename store/cachekv/{ => internal}/mergeiterator.go (86%) delete mode 100644 store/cachekv/memiterator.go diff --git a/CHANGELOG.md b/CHANGELOG.md index a3669eaac020..0b7748d60203 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -99,6 +99,7 @@ Ref: https://keepachangelog.com/en/1.0.0/ * [#13794](https://github.com/cosmos/cosmos-sdk/pull/13794) `types/module.Manager` now supports the `cosmossdk.io/core/appmodule.AppModule` API via the new `NewManagerFromMap` constructor. * [#14019](https://github.com/cosmos/cosmos-sdk/issues/14019) Remove the interface casting to allow other implementations of a `CommitMultiStore`. +* [#13881](https://github.com/cosmos/cosmos-sdk/pull/13881) Optimize iteration on nested cached KV stores and other operations in general. 
### State Machine Breaking diff --git a/go.mod b/go.mod index b7fc1ba4fb46..78bcf92301c0 100644 --- a/go.mod +++ b/go.mod @@ -53,6 +53,7 @@ require ( github.com/tendermint/go-amino v0.16.0 github.com/tendermint/tendermint v0.37.0-rc2 github.com/tendermint/tm-db v0.6.7 + github.com/tidwall/btree v1.5.2 golang.org/x/crypto v0.4.0 golang.org/x/exp v0.0.0-20221019170559-20944726eadf google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 diff --git a/go.sum b/go.sum index fd4ac931bb90..9417709852de 100644 --- a/go.sum +++ b/go.sum @@ -754,6 +754,8 @@ github.com/tendermint/tendermint v0.37.0-rc2 h1:2n1em+jfbhSv6QnBj8F6KHCpbIzZCB8K github.com/tendermint/tendermint v0.37.0-rc2/go.mod h1:uYQO9DRNPeZROa9X3hJOZpYcVREDC2/HST+EiU5g2+A= github.com/tendermint/tm-db v0.6.7 h1:fE00Cbl0jayAoqlExN6oyQJ7fR/ZtoVOmvPJ//+shu8= github.com/tendermint/tm-db v0.6.7/go.mod h1:byQDzFkZV1syXr/ReXS808NxA2xvyuuVgXOJ/088L6I= +github.com/tidwall/btree v1.5.2 h1:5eA83Gfki799V3d3bJo9sWk+yL2LRoTEah3O/SA6/8w= +github.com/tidwall/btree v1.5.2/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= diff --git a/simapp/go.mod b/simapp/go.mod index 84a509e39803..f19e83af3ee9 100644 --- a/simapp/go.mod +++ b/simapp/go.mod @@ -136,6 +136,7 @@ require ( github.com/tendermint/btcd v0.1.1 // indirect github.com/tendermint/crypto v0.0.0-20191022145703-50d29ede1e15 // indirect github.com/tendermint/go-amino v0.16.0 // indirect + github.com/tidwall/btree v1.5.2 // indirect github.com/ulikunitz/xz v0.5.8 // indirect github.com/zondax/hid v0.9.1 // indirect github.com/zondax/ledger-go v0.14.0 // indirect diff --git a/simapp/go.sum b/simapp/go.sum index 4ad3c85baa8c..f9ae97d4e0d5 100644 --- a/simapp/go.sum +++ b/simapp/go.sum @@ -753,6 +753,8 @@ github.com/tendermint/tendermint v0.37.0-rc2 h1:2n1em+jfbhSv6QnBj8F6KHCpbIzZCB8K github.com/tendermint/tendermint v0.37.0-rc2/go.mod h1:uYQO9DRNPeZROa9X3hJOZpYcVREDC2/HST+EiU5g2+A= github.com/tendermint/tm-db v0.6.7 h1:fE00Cbl0jayAoqlExN6oyQJ7fR/ZtoVOmvPJ//+shu8= github.com/tendermint/tm-db v0.6.7/go.mod h1:byQDzFkZV1syXr/ReXS808NxA2xvyuuVgXOJ/088L6I= +github.com/tidwall/btree v1.5.2 h1:5eA83Gfki799V3d3bJo9sWk+yL2LRoTEah3O/SA6/8w= +github.com/tidwall/btree v1.5.2/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= diff --git a/store/cachekv/benchmark_test.go b/store/cachekv/benchmark_test.go new file mode 100644 index 000000000000..2db62ba5d6c6 --- /dev/null +++ b/store/cachekv/benchmark_test.go @@ -0,0 +1,161 @@ +package cachekv_test + +import ( + fmt "fmt" + "testing" + + "github.com/cosmos/cosmos-sdk/store" + storetypes "github.com/cosmos/cosmos-sdk/store/types" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/tendermint/tendermint/libs/log" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" + dbm "github.com/tendermint/tm-db" +) + +func DoBenchmarkDeepContextStack(b *testing.B, depth int) { + begin := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 0x00, 0x00} + end := []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} + key := storetypes.NewKVStoreKey("test") + + db := dbm.NewMemDB() + cms := store.NewCommitMultiStore(db) + cms.MountStoreWithDB(key, storetypes.StoreTypeIAVL, db) + cms.LoadLatestVersion() + ctx := sdk.NewContext(cms, tmproto.Header{}, false, log.NewNopLogger()) + + var stack ContextStack + stack.Reset(ctx) + + for i := 0; i < depth; i++ { + stack.Snapshot() + + store := stack.CurrentContext().KVStore(key) + store.Set(begin, []byte("value")) + } + + store := stack.CurrentContext().KVStore(key) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + it := store.Iterator(begin, end) + it.Valid() + it.Key() + it.Value() + it.Next() + it.Close() + } +} + +func BenchmarkDeepContextStack1(b *testing.B) { + DoBenchmarkDeepContextStack(b, 1) +} + +func BenchmarkDeepContextStack3(b *testing.B) { + DoBenchmarkDeepContextStack(b, 3) +} +func BenchmarkDeepContextStack10(b *testing.B) { + DoBenchmarkDeepContextStack(b, 10) +} + +func BenchmarkDeepContextStack13(b *testing.B) { + DoBenchmarkDeepContextStack(b, 13) +} + +// cachedContext is a pair of cache context and its corresponding commit method. +// They are obtained from the return value of `context.CacheContext()`. +type cachedContext struct { + ctx sdk.Context + commit func() +} + +// ContextStack manages the initial context and a stack of cached contexts, +// to support the `StateDB.Snapshot` and `StateDB.RevertToSnapshot` methods. +// +// Copied from an old version of ethermint +type ContextStack struct { + // Context of the initial state before transaction execution. + // It's the context used by `StateDB.CommitedState`. + initialCtx sdk.Context + cachedContexts []cachedContext +} + +// CurrentContext returns the top context of cached stack, +// if the stack is empty, returns the initial context. +func (cs *ContextStack) CurrentContext() sdk.Context { + l := len(cs.cachedContexts) + if l == 0 { + return cs.initialCtx + } + return cs.cachedContexts[l-1].ctx +} + +// Reset sets the initial context and clear the cache context stack. +func (cs *ContextStack) Reset(ctx sdk.Context) { + cs.initialCtx = ctx + if len(cs.cachedContexts) > 0 { + cs.cachedContexts = []cachedContext{} + } +} + +// IsEmpty returns true if the cache context stack is empty. +func (cs *ContextStack) IsEmpty() bool { + return len(cs.cachedContexts) == 0 +} + +// Commit commits all the cached contexts from top to bottom in order and clears the stack by setting an empty slice of cache contexts. +func (cs *ContextStack) Commit() { + // commit in order from top to bottom + for i := len(cs.cachedContexts) - 1; i >= 0; i-- { + if cs.cachedContexts[i].commit == nil { + panic(fmt.Sprintf("commit function at index %d should not be nil", i)) + } else { + cs.cachedContexts[i].commit() + } + } + cs.cachedContexts = []cachedContext{} +} + +// CommitToRevision commit the cache after the target revision, +// to improve efficiency of db operations. 
+func (cs *ContextStack) CommitToRevision(target int) error { + if target < 0 || target >= len(cs.cachedContexts) { + return fmt.Errorf("snapshot index %d out of bound [%d..%d)", target, 0, len(cs.cachedContexts)) + } + + // commit in order from top to bottom + for i := len(cs.cachedContexts) - 1; i > target; i-- { + if cs.cachedContexts[i].commit == nil { + return fmt.Errorf("commit function at index %d should not be nil", i) + } + cs.cachedContexts[i].commit() + } + cs.cachedContexts = cs.cachedContexts[0 : target+1] + + return nil +} + +// Snapshot pushes a new cached context to the stack, +// and returns the index of it. +func (cs *ContextStack) Snapshot() int { + i := len(cs.cachedContexts) + ctx, commit := cs.CurrentContext().CacheContext() + cs.cachedContexts = append(cs.cachedContexts, cachedContext{ctx: ctx, commit: commit}) + return i +} + +// RevertToSnapshot pops all the cached contexts after the target index (inclusive). +// the target should be snapshot index returned by `Snapshot`. +// This function panics if the index is out of bounds. +func (cs *ContextStack) RevertToSnapshot(target int) { + if target < 0 || target >= len(cs.cachedContexts) { + panic(fmt.Errorf("snapshot index %d out of bound [%d..%d)", target, 0, len(cs.cachedContexts))) + } + cs.cachedContexts = cs.cachedContexts[:target] +} + +// RevertAll discards all the cache contexts. +func (cs *ContextStack) RevertAll() { + if len(cs.cachedContexts) > 0 { + cs.RevertToSnapshot(0) + } +} diff --git a/store/cachekv/internal/btree.go b/store/cachekv/internal/btree.go new file mode 100644 index 000000000000..142f754bbd38 --- /dev/null +++ b/store/cachekv/internal/btree.go @@ -0,0 +1,80 @@ +package internal + +import ( + "bytes" + "errors" + + "github.com/tidwall/btree" +) + +const ( + // The approximate number of items and children per B-tree node. Tuned with benchmarks. + // copied from memdb. + bTreeDegree = 32 +) + +var errKeyEmpty = errors.New("key cannot be empty") + +// BTree implements the sorted cache for cachekv store, +// we don't use MemDB here because cachekv is used extensively in sdk core path, +// we need it to be as fast as possible, while `MemDB` is mainly used as a mocking db in unit tests. +// +// We choose tidwall/btree over google/btree here because it provides API to implement step iterator directly. +type BTree struct { + tree btree.BTreeG[item] +} + +// NewBTree creates a wrapper around `btree.BTreeG`. 
+func NewBTree() *BTree { + return &BTree{tree: *btree.NewBTreeGOptions(byKeys, btree.Options{ + Degree: bTreeDegree, + // Contract: cachekv store must not be called concurrently + NoLocks: true, + })} +} + +func (bt *BTree) Set(key, value []byte) { + bt.tree.Set(newItem(key, value)) +} + +func (bt *BTree) Get(key []byte) []byte { + i, found := bt.tree.Get(newItem(key, nil)) + if !found { + return nil + } + return i.value +} + +func (bt *BTree) Delete(key []byte) { + bt.tree.Delete(newItem(key, nil)) +} + +func (bt *BTree) Iterator(start, end []byte) (*memIterator, error) { + if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) { + return nil, errKeyEmpty + } + return NewMemIterator(start, end, bt, make(map[string]struct{}), true), nil +} + +func (bt *BTree) ReverseIterator(start, end []byte) (*memIterator, error) { + if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) { + return nil, errKeyEmpty + } + return NewMemIterator(start, end, bt, make(map[string]struct{}), false), nil +} + +// item is a btree item with byte slices as keys and values +type item struct { + key []byte + value []byte +} + +// byKeys compares the items by key +func byKeys(a, b item) bool { + return bytes.Compare(a.key, b.key) == -1 +} + +// newItem creates a new pair item. +func newItem(key, value []byte) item { + return item{key: key, value: value} +} diff --git a/store/cachekv/internal/btree_test.go b/store/cachekv/internal/btree_test.go new file mode 100644 index 000000000000..f85a8bbaf109 --- /dev/null +++ b/store/cachekv/internal/btree_test.go @@ -0,0 +1,202 @@ +package internal + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" +) + +func TestGetSetDelete(t *testing.T) { + db := NewBTree() + + // A nonexistent key should return nil. + value := db.Get([]byte("a")) + require.Nil(t, value) + + // Set and get a value. + db.Set([]byte("a"), []byte{0x01}) + db.Set([]byte("b"), []byte{0x02}) + value = db.Get([]byte("a")) + require.Equal(t, []byte{0x01}, value) + + value = db.Get([]byte("b")) + require.Equal(t, []byte{0x02}, value) + + // Deleting a non-existent value is fine. + db.Delete([]byte("x")) + + // Delete a value. + db.Delete([]byte("a")) + + value = db.Get([]byte("a")) + require.Nil(t, value) + + db.Delete([]byte("b")) + + value = db.Get([]byte("b")) + require.Nil(t, value) +} + +func TestDBIterator(t *testing.T) { + db := NewBTree() + + for i := 0; i < 10; i++ { + if i != 6 { // but skip 6. 
+ db.Set(int642Bytes(int64(i)), []byte{}) + } + } + + // Blank iterator keys should error + _, err := db.ReverseIterator([]byte{}, nil) + require.Equal(t, errKeyEmpty, err) + _, err = db.ReverseIterator(nil, []byte{}) + require.Equal(t, errKeyEmpty, err) + + itr, err := db.Iterator(nil, nil) + require.NoError(t, err) + verifyIterator(t, itr, []int64{0, 1, 2, 3, 4, 5, 7, 8, 9}, "forward iterator") + + ritr, err := db.ReverseIterator(nil, nil) + require.NoError(t, err) + verifyIterator(t, ritr, []int64{9, 8, 7, 5, 4, 3, 2, 1, 0}, "reverse iterator") + + itr, err = db.Iterator(nil, int642Bytes(0)) + require.NoError(t, err) + verifyIterator(t, itr, []int64(nil), "forward iterator to 0") + + ritr, err = db.ReverseIterator(int642Bytes(10), nil) + require.NoError(t, err) + verifyIterator(t, ritr, []int64(nil), "reverse iterator from 10 (ex)") + + itr, err = db.Iterator(int642Bytes(0), nil) + require.NoError(t, err) + verifyIterator(t, itr, []int64{0, 1, 2, 3, 4, 5, 7, 8, 9}, "forward iterator from 0") + + itr, err = db.Iterator(int642Bytes(1), nil) + require.NoError(t, err) + verifyIterator(t, itr, []int64{1, 2, 3, 4, 5, 7, 8, 9}, "forward iterator from 1") + + ritr, err = db.ReverseIterator(nil, int642Bytes(10)) + require.NoError(t, err) + verifyIterator(t, ritr, + []int64{9, 8, 7, 5, 4, 3, 2, 1, 0}, "reverse iterator from 10 (ex)") + + ritr, err = db.ReverseIterator(nil, int642Bytes(9)) + require.NoError(t, err) + verifyIterator(t, ritr, + []int64{8, 7, 5, 4, 3, 2, 1, 0}, "reverse iterator from 9 (ex)") + + ritr, err = db.ReverseIterator(nil, int642Bytes(8)) + require.NoError(t, err) + verifyIterator(t, ritr, + []int64{7, 5, 4, 3, 2, 1, 0}, "reverse iterator from 8 (ex)") + + itr, err = db.Iterator(int642Bytes(5), int642Bytes(6)) + require.NoError(t, err) + verifyIterator(t, itr, []int64{5}, "forward iterator from 5 to 6") + + itr, err = db.Iterator(int642Bytes(5), int642Bytes(7)) + require.NoError(t, err) + verifyIterator(t, itr, []int64{5}, "forward iterator from 5 to 7") + + itr, err = db.Iterator(int642Bytes(5), int642Bytes(8)) + require.NoError(t, err) + verifyIterator(t, itr, []int64{5, 7}, "forward iterator from 5 to 8") + + itr, err = db.Iterator(int642Bytes(6), int642Bytes(7)) + require.NoError(t, err) + verifyIterator(t, itr, []int64(nil), "forward iterator from 6 to 7") + + itr, err = db.Iterator(int642Bytes(6), int642Bytes(8)) + require.NoError(t, err) + verifyIterator(t, itr, []int64{7}, "forward iterator from 6 to 8") + + itr, err = db.Iterator(int642Bytes(7), int642Bytes(8)) + require.NoError(t, err) + verifyIterator(t, itr, []int64{7}, "forward iterator from 7 to 8") + + ritr, err = db.ReverseIterator(int642Bytes(4), int642Bytes(5)) + require.NoError(t, err) + verifyIterator(t, ritr, []int64{4}, "reverse iterator from 5 (ex) to 4") + + ritr, err = db.ReverseIterator(int642Bytes(4), int642Bytes(6)) + require.NoError(t, err) + verifyIterator(t, ritr, + []int64{5, 4}, "reverse iterator from 6 (ex) to 4") + + ritr, err = db.ReverseIterator(int642Bytes(4), int642Bytes(7)) + require.NoError(t, err) + verifyIterator(t, ritr, + []int64{5, 4}, "reverse iterator from 7 (ex) to 4") + + ritr, err = db.ReverseIterator(int642Bytes(5), int642Bytes(6)) + require.NoError(t, err) + verifyIterator(t, ritr, []int64{5}, "reverse iterator from 6 (ex) to 5") + + ritr, err = db.ReverseIterator(int642Bytes(5), int642Bytes(7)) + require.NoError(t, err) + verifyIterator(t, ritr, []int64{5}, "reverse iterator from 7 (ex) to 5") + + ritr, err = db.ReverseIterator(int642Bytes(6), int642Bytes(7)) + 
require.NoError(t, err) + verifyIterator(t, ritr, + []int64(nil), "reverse iterator from 7 (ex) to 6") + + ritr, err = db.ReverseIterator(int642Bytes(10), nil) + require.NoError(t, err) + verifyIterator(t, ritr, []int64(nil), "reverse iterator to 10") + + ritr, err = db.ReverseIterator(int642Bytes(6), nil) + require.NoError(t, err) + verifyIterator(t, ritr, []int64{9, 8, 7}, "reverse iterator to 6") + + ritr, err = db.ReverseIterator(int642Bytes(5), nil) + require.NoError(t, err) + verifyIterator(t, ritr, []int64{9, 8, 7, 5}, "reverse iterator to 5") + + ritr, err = db.ReverseIterator(int642Bytes(8), int642Bytes(9)) + require.NoError(t, err) + verifyIterator(t, ritr, []int64{8}, "reverse iterator from 9 (ex) to 8") + + ritr, err = db.ReverseIterator(int642Bytes(2), int642Bytes(4)) + require.NoError(t, err) + verifyIterator(t, ritr, + []int64{3, 2}, "reverse iterator from 4 (ex) to 2") + + ritr, err = db.ReverseIterator(int642Bytes(4), int642Bytes(2)) + require.NoError(t, err) + verifyIterator(t, ritr, + []int64(nil), "reverse iterator from 2 (ex) to 4") + + // Ensure that the iterators don't panic with an empty database. + db2 := NewBTree() + + itr, err = db2.Iterator(nil, nil) + require.NoError(t, err) + verifyIterator(t, itr, nil, "forward iterator with empty db") + + ritr, err = db2.ReverseIterator(nil, nil) + require.NoError(t, err) + verifyIterator(t, ritr, nil, "reverse iterator with empty db") +} + +func verifyIterator(t *testing.T, itr *memIterator, expected []int64, msg string) { + i := 0 + for itr.Valid() { + key := itr.Key() + require.Equal(t, expected[i], bytes2Int64(key), "iterator: %d mismatches", i) + itr.Next() + i++ + } + require.Equal(t, i, len(expected), "expected to have fully iterated over all the elements in iter") + require.NoError(t, itr.Close()) +} + +func int642Bytes(i int64) []byte { + return sdk.Uint64ToBigEndian(uint64(i)) +} + +func bytes2Int64(buf []byte) int64 { + return int64(sdk.BigEndianToUint64(buf)) +} diff --git a/store/cachekv/internal/memiterator.go b/store/cachekv/internal/memiterator.go new file mode 100644 index 000000000000..2bceb8bc77df --- /dev/null +++ b/store/cachekv/internal/memiterator.go @@ -0,0 +1,137 @@ +package internal + +import ( + "bytes" + "errors" + + "github.com/cosmos/cosmos-sdk/store/types" + "github.com/tidwall/btree" +) + +var _ types.Iterator = (*memIterator)(nil) + +// memIterator iterates over iterKVCache items. +// if key is nil, means it was deleted. +// Implements Iterator. 
+type memIterator struct { + iter btree.GenericIter[item] + + start []byte + end []byte + ascending bool + lastKey []byte + deleted map[string]struct{} + valid bool +} + +func NewMemIterator(start, end []byte, items *BTree, deleted map[string]struct{}, ascending bool) *memIterator { + iter := items.tree.Iter() + var valid bool + if ascending { + if start != nil { + valid = iter.Seek(newItem(start, nil)) + } else { + valid = iter.First() + } + } else { + if end != nil { + valid = iter.Seek(newItem(end, nil)) + if !valid { + valid = iter.Last() + } else { + // end is exclusive + valid = iter.Prev() + } + } else { + valid = iter.Last() + } + } + + mi := &memIterator{ + iter: iter, + start: start, + end: end, + ascending: ascending, + lastKey: nil, + deleted: deleted, + valid: valid, + } + + if mi.valid { + mi.valid = mi.keyInRange(mi.Key()) + } + + return mi +} + +func (mi *memIterator) Domain() (start []byte, end []byte) { + return mi.start, mi.end +} + +func (mi *memIterator) Close() error { + mi.iter.Release() + return nil +} + +func (mi *memIterator) Error() error { + if !mi.Valid() { + return errors.New("invalid memIterator") + } + return nil +} + +func (mi *memIterator) Valid() bool { + return mi.valid +} + +func (mi *memIterator) Next() { + mi.assertValid() + + if mi.ascending { + mi.valid = mi.iter.Next() + } else { + mi.valid = mi.iter.Prev() + } + + if mi.valid { + mi.valid = mi.keyInRange(mi.Key()) + } +} + +func (mi *memIterator) keyInRange(key []byte) bool { + if mi.ascending && mi.end != nil && bytes.Compare(key, mi.end) >= 0 { + return false + } + if !mi.ascending && mi.start != nil && bytes.Compare(key, mi.start) < 0 { + return false + } + return true +} + +func (mi *memIterator) Key() []byte { + return mi.iter.Item().key +} + +func (mi *memIterator) Value() []byte { + item := mi.iter.Item() + key := item.key + // We need to handle the case where deleted is modified and includes our current key + // We handle this by maintaining a lastKey object in the iterator. + // If the current key is the same as the last key (and last key is not nil / the start) + // then we are calling value on the same thing as last time. + // Therefore we don't check the mi.deleted to see if this key is included in there. 
+ if _, ok := mi.deleted[string(key)]; ok { + if mi.lastKey == nil || !bytes.Equal(key, mi.lastKey) { + // not re-calling on old last key + return nil + } + } + mi.lastKey = key + return item.value +} + +func (mi *memIterator) assertValid() { + if err := mi.Error(); err != nil { + panic(err) + } +} diff --git a/store/cachekv/mergeiterator.go b/store/cachekv/internal/mergeiterator.go similarity index 86% rename from store/cachekv/mergeiterator.go rename to store/cachekv/internal/mergeiterator.go index a6c7a035aba0..4186a178a863 100644 --- a/store/cachekv/mergeiterator.go +++ b/store/cachekv/internal/mergeiterator.go @@ -1,4 +1,4 @@ -package cachekv +package internal import ( "bytes" @@ -18,17 +18,20 @@ type cacheMergeIterator struct { parent types.Iterator cache types.Iterator ascending bool + + valid bool } var _ types.Iterator = (*cacheMergeIterator)(nil) -func newCacheMergeIterator(parent, cache types.Iterator, ascending bool) *cacheMergeIterator { +func NewCacheMergeIterator(parent, cache types.Iterator, ascending bool) *cacheMergeIterator { iter := &cacheMergeIterator{ parent: parent, cache: cache, ascending: ascending, } + iter.valid = iter.skipUntilExistsOrInvalid() return iter } @@ -40,42 +43,38 @@ func (iter *cacheMergeIterator) Domain() (start, end []byte) { // Valid implements Iterator. func (iter *cacheMergeIterator) Valid() bool { - return iter.skipUntilExistsOrInvalid() + return iter.valid } // Next implements Iterator func (iter *cacheMergeIterator) Next() { - iter.skipUntilExistsOrInvalid() iter.assertValid() - // If parent is invalid, get the next cache item. - if !iter.parent.Valid() { + switch { + case !iter.parent.Valid(): + // If parent is invalid, get the next cache item. iter.cache.Next() - return - } - - // If cache is invalid, get the next parent item. - if !iter.cache.Valid() { + case !iter.cache.Valid(): + // If cache is invalid, get the next parent item. iter.parent.Next() - return - } - - // Both are valid. Compare keys. - keyP, keyC := iter.parent.Key(), iter.cache.Key() - switch iter.compare(keyP, keyC) { - case -1: // parent < cache - iter.parent.Next() - case 0: // parent == cache - iter.parent.Next() - iter.cache.Next() - case 1: // parent > cache - iter.cache.Next() + default: + // Both are valid. Compare keys. + keyP, keyC := iter.parent.Key(), iter.cache.Key() + switch iter.compare(keyP, keyC) { + case -1: // parent < cache + iter.parent.Next() + case 0: // parent == cache + iter.parent.Next() + iter.cache.Next() + case 1: // parent > cache + iter.cache.Next() + } } + iter.valid = iter.skipUntilExistsOrInvalid() } // Key implements Iterator func (iter *cacheMergeIterator) Key() []byte { - iter.skipUntilExistsOrInvalid() iter.assertValid() // If parent is invalid, get the cache key. @@ -106,7 +105,6 @@ func (iter *cacheMergeIterator) Key() []byte { // Value implements Iterator func (iter *cacheMergeIterator) Value() []byte { - iter.skipUntilExistsOrInvalid() iter.assertValid() // If parent is invalid, get the cache value. 
@@ -137,11 +135,12 @@ func (iter *cacheMergeIterator) Value() []byte { // Close implements Iterator func (iter *cacheMergeIterator) Close() error { + err1 := iter.cache.Close() if err := iter.parent.Close(); err != nil { return err } - return iter.cache.Close() + return err1 } // Error returns an error if the cacheMergeIterator is invalid defined by the diff --git a/store/cachekv/memiterator.go b/store/cachekv/memiterator.go deleted file mode 100644 index a12ff9acfd11..000000000000 --- a/store/cachekv/memiterator.go +++ /dev/null @@ -1,57 +0,0 @@ -package cachekv - -import ( - "bytes" - - dbm "github.com/tendermint/tm-db" - - "github.com/cosmos/cosmos-sdk/store/types" -) - -// memIterator iterates over iterKVCache items. -// if key is nil, means it was deleted. -// Implements Iterator. -type memIterator struct { - types.Iterator - - lastKey []byte - deleted map[string]struct{} -} - -func newMemIterator(start, end []byte, items *dbm.MemDB, deleted map[string]struct{}, ascending bool) *memIterator { - var ( - iter types.Iterator - err error - ) - - if ascending { - iter, err = items.Iterator(start, end) - } else { - iter, err = items.ReverseIterator(start, end) - } - - if err != nil { - panic(err) - } - - return &memIterator{ - Iterator: iter, - lastKey: nil, - deleted: deleted, - } -} - -func (mi *memIterator) Value() []byte { - key := mi.Iterator.Key() - // We need to handle the case where deleted is modified and includes our current key - // We handle this by maintaining a lastKey object in the iterator. - // If the current key is the same as the last key (and last key is not nil / the start) - // then we are calling value on the same thing as last time. - // Therefore we don't check the mi.deleted to see if this key is included in there. - reCallingOnOldLastKey := (mi.lastKey != nil) && bytes.Equal(key, mi.lastKey) - if _, ok := mi.deleted[string(key)]; ok && !reCallingOnOldLastKey { - return nil - } - mi.lastKey = key - return mi.Iterator.Value() -} diff --git a/store/cachekv/search_benchmark_test.go b/store/cachekv/search_benchmark_test.go index 921bff4e3864..4007c7cda202 100644 --- a/store/cachekv/search_benchmark_test.go +++ b/store/cachekv/search_benchmark_test.go @@ -4,7 +4,7 @@ import ( "strconv" "testing" - db "github.com/tendermint/tm-db" + "github.com/cosmos/cosmos-sdk/store/cachekv/internal" ) func BenchmarkLargeUnsortedMisses(b *testing.B) { @@ -39,6 +39,6 @@ func generateStore() *Store { return &Store{ cache: cache, unsortedCache: unsorted, - sortedCache: db.NewMemDB(), + sortedCache: internal.NewBTree(), } } diff --git a/store/cachekv/store.go b/store/cachekv/store.go index ebf51bdba07f..e9f9d7195f4c 100644 --- a/store/cachekv/store.go +++ b/store/cachekv/store.go @@ -10,6 +10,7 @@ import ( dbm "github.com/tendermint/tm-db" "github.com/cosmos/cosmos-sdk/internal/conv" + "github.com/cosmos/cosmos-sdk/store/cachekv/internal" "github.com/cosmos/cosmos-sdk/store/internal/kv" "github.com/cosmos/cosmos-sdk/store/tracekv" "github.com/cosmos/cosmos-sdk/store/types" @@ -30,7 +31,7 @@ type Store struct { cache map[string]*cValue deleted map[string]struct{} unsortedCache map[string]struct{} - sortedCache *dbm.MemDB // always ascending sorted + sortedCache *internal.BTree // always ascending sorted parent types.KVStore } @@ -42,7 +43,7 @@ func NewStore(parent types.KVStore) *Store { cache: make(map[string]*cValue), deleted: make(map[string]struct{}), unsortedCache: make(map[string]struct{}), - sortedCache: dbm.NewMemDB(), + sortedCache: internal.NewBTree(), parent: parent, } } @@ 
-102,7 +103,7 @@ func (store *Store) Write() { defer store.mtx.Unlock() if len(store.cache) == 0 && len(store.deleted) == 0 && len(store.unsortedCache) == 0 { - store.sortedCache = dbm.NewMemDB() + store.sortedCache = internal.NewBTree() return } @@ -149,7 +150,7 @@ func (store *Store) Write() { for key := range store.unsortedCache { delete(store.unsortedCache, key) } - store.sortedCache = dbm.NewMemDB() + store.sortedCache = internal.NewBTree() } // CacheWrap implements CacheWrapper. @@ -188,9 +189,9 @@ func (store *Store) iterator(start, end []byte, ascending bool) types.Iterator { } store.dirtyItems(start, end) - cache = newMemIterator(start, end, store.sortedCache, store.deleted, ascending) + cache = internal.NewMemIterator(start, end, store.sortedCache, store.deleted, ascending) - return newCacheMergeIterator(parent, cache, ascending) + return internal.NewCacheMergeIterator(parent, cache, ascending) } func findStartIndex(strL []string, startQ string) int { @@ -372,16 +373,11 @@ func (store *Store) clearUnsortedCacheSubset(unsorted []*kv.Pair, sortState sort if item.Value == nil { // deleted element, tracked by store.deleted // setting arbitrary value - if err := store.sortedCache.Set(item.Key, []byte{}); err != nil { - panic(err) - } - + store.sortedCache.Set(item.Key, []byte{}) continue } - if err := store.sortedCache.Set(item.Key, item.Value); err != nil { - panic(err) - } + store.sortedCache.Set(item.Key, item.Value) } } diff --git a/store/cachekv/store_test.go b/store/cachekv/store_test.go index f86522d5d41e..3ef99fd6f144 100644 --- a/store/cachekv/store_test.go +++ b/store/cachekv/store_test.go @@ -120,6 +120,7 @@ func TestCacheKVIteratorBounds(t *testing.T) { i++ } require.Equal(t, nItems, i) + require.NoError(t, itr.Close()) // iterate over none itr = st.Iterator(bz("money"), nil) @@ -128,6 +129,7 @@ func TestCacheKVIteratorBounds(t *testing.T) { i++ } require.Equal(t, 0, i) + require.NoError(t, itr.Close()) // iterate over lower itr = st.Iterator(keyFmt(0), keyFmt(3)) @@ -139,6 +141,7 @@ func TestCacheKVIteratorBounds(t *testing.T) { i++ } require.Equal(t, 3, i) + require.NoError(t, itr.Close()) // iterate over upper itr = st.Iterator(keyFmt(2), keyFmt(4)) @@ -150,6 +153,7 @@ func TestCacheKVIteratorBounds(t *testing.T) { i++ } require.Equal(t, 4, i) + require.NoError(t, itr.Close()) } func TestCacheKVReverseIteratorBounds(t *testing.T) { @@ -171,6 +175,7 @@ func TestCacheKVReverseIteratorBounds(t *testing.T) { i++ } require.Equal(t, nItems, i) + require.NoError(t, itr.Close()) // iterate over none itr = st.ReverseIterator(bz("money"), nil) @@ -179,6 +184,7 @@ func TestCacheKVReverseIteratorBounds(t *testing.T) { i++ } require.Equal(t, 0, i) + require.NoError(t, itr.Close()) // iterate over lower end := 3 @@ -191,6 +197,7 @@ func TestCacheKVReverseIteratorBounds(t *testing.T) { require.Equal(t, valFmt(end-i), v) } require.Equal(t, 3, i) + require.NoError(t, itr.Close()) // iterate over upper end = 4 @@ -203,6 +210,7 @@ func TestCacheKVReverseIteratorBounds(t *testing.T) { require.Equal(t, valFmt(end-i), v) } require.Equal(t, 2, i) + require.NoError(t, itr.Close()) } func TestCacheKVMergeIteratorBasics(t *testing.T) { @@ -347,12 +355,16 @@ func TestCacheKVMergeIteratorChunks(t *testing.T) { func TestCacheKVMergeIteratorDomain(t *testing.T) { st := newCacheKVStore() - start, end := st.Iterator(nil, nil).Domain() + itr := st.Iterator(nil, nil) + start, end := itr.Domain() require.Equal(t, start, end) + require.NoError(t, itr.Close()) - start, end = st.Iterator(keyFmt(40), 
keyFmt(60)).Domain() + itr = st.Iterator(keyFmt(40), keyFmt(60)) + start, end = itr.Domain() require.Equal(t, keyFmt(40), start) require.Equal(t, keyFmt(60), end) + require.NoError(t, itr.Close()) start, end = st.ReverseIterator(keyFmt(0), keyFmt(80)).Domain() require.Equal(t, keyFmt(0), start) @@ -414,10 +426,27 @@ func TestNilEndIterator(t *testing.T) { } require.Equal(t, SIZE-tt.startIndex, j) + require.NoError(t, itr.Close()) }) } } +// TestIteratorDeadlock demonstrate the deadlock issue in cache store. +func TestIteratorDeadlock(t *testing.T) { + mem := dbadapter.Store{DB: dbm.NewMemDB()} + store := cachekv.NewStore(mem) + // the channel buffer is 64 and received once, so put at least 66 elements. + for i := 0; i < 66; i++ { + store.Set([]byte(fmt.Sprintf("key%d", i)), []byte{1}) + } + it := store.Iterator(nil, nil) + defer it.Close() + store.Set([]byte("key20"), []byte{1}) + // it'll be blocked here with previous version, or enable lock on btree. + it2 := store.Iterator(nil, nil) + defer it2.Close() +} + //------------------------------------------------------------------------------------------- // do some random ops @@ -500,6 +529,7 @@ func assertIterateDomain(t *testing.T, st types.KVStore, expectedN int) { i++ } require.Equal(t, expectedN, i) + require.NoError(t, itr.Close()) } func assertIterateDomainCheck(t *testing.T, st types.KVStore, mem dbm.DB, r []keyRange) { @@ -531,6 +561,8 @@ func assertIterateDomainCheck(t *testing.T, st types.KVStore, mem dbm.DB, r []ke require.False(t, itr.Valid()) require.False(t, itr2.Valid()) + require.NoError(t, itr.Close()) + require.NoError(t, itr2.Close()) } func assertIterateDomainCompare(t *testing.T, st types.KVStore, mem dbm.DB) { @@ -540,6 +572,8 @@ func assertIterateDomainCompare(t *testing.T, st types.KVStore, mem dbm.DB) { require.NoError(t, err) checkIterators(t, itr, itr2) checkIterators(t, itr2, itr) + require.NoError(t, itr.Close()) + require.NoError(t, itr2.Close()) } func checkIterators(t *testing.T, itr, itr2 types.Iterator) { diff --git a/tests/go.mod b/tests/go.mod index 3ac1703c8cc9..c88464e4eb57 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -135,6 +135,7 @@ require ( github.com/tendermint/btcd v0.1.1 // indirect github.com/tendermint/crypto v0.0.0-20191022145703-50d29ede1e15 // indirect github.com/tendermint/go-amino v0.16.0 // indirect + github.com/tidwall/btree v1.5.2 // indirect github.com/ulikunitz/xz v0.5.8 // indirect github.com/zondax/hid v0.9.1 // indirect github.com/zondax/ledger-go v0.14.0 // indirect diff --git a/tests/go.sum b/tests/go.sum index fdc7ce8724c1..d1147092ce04 100644 --- a/tests/go.sum +++ b/tests/go.sum @@ -741,6 +741,8 @@ github.com/tendermint/tendermint v0.37.0-rc2 h1:2n1em+jfbhSv6QnBj8F6KHCpbIzZCB8K github.com/tendermint/tendermint v0.37.0-rc2/go.mod h1:uYQO9DRNPeZROa9X3hJOZpYcVREDC2/HST+EiU5g2+A= github.com/tendermint/tm-db v0.6.7 h1:fE00Cbl0jayAoqlExN6oyQJ7fR/ZtoVOmvPJ//+shu8= github.com/tendermint/tm-db v0.6.7/go.mod h1:byQDzFkZV1syXr/ReXS808NxA2xvyuuVgXOJ/088L6I= +github.com/tidwall/btree v1.5.2 h1:5eA83Gfki799V3d3bJo9sWk+yL2LRoTEah3O/SA6/8w= +github.com/tidwall/btree v1.5.2/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo=
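
Note (illustration only, not part of the patch): the sketch below shows the nested cache-store iteration pattern this change targets, using only store APIs exercised in the diff above (cachekv.NewStore, dbadapter.Store, dbm.NewMemDB, Set, Iterator, Write, Close). The standalone main package, key/value literals, and print loop are assumptions added for the sake of a runnable example; the behaviour mirrors TestIteratorDeadlock and the DeepContextStack benchmarks introduced here.

package main

import (
	"fmt"

	"github.com/cosmos/cosmos-sdk/store/cachekv"
	"github.com/cosmos/cosmos-sdk/store/dbadapter"
	dbm "github.com/tendermint/tm-db"
)

func main() {
	// Parent store backed by an in-memory DB, as in store_test.go.
	mem := dbadapter.Store{DB: dbm.NewMemDB()}

	// Stack cache stores on top of each other, mimicking nested CacheContext
	// calls; iterating the outer layer walks a merge iterator per level,
	// which is the path this patch optimizes.
	inner := cachekv.NewStore(mem)
	inner.Set([]byte("a"), []byte("1"))

	outer := cachekv.NewStore(inner)
	outer.Set([]byte("b"), []byte("2"))

	// With the btree-backed sorted cache, writing to the store while an
	// iterator is open and then opening a second iterator no longer
	// deadlocks (cf. TestIteratorDeadlock above).
	it := outer.Iterator(nil, nil)
	outer.Set([]byte("c"), []byte("3"))
	it2 := outer.Iterator(nil, nil)

	// The second iterator sees a, b and c through the merged view of the
	// parent store and both cache levels.
	for ; it2.Valid(); it2.Next() {
		fmt.Printf("%s => %s\n", it2.Key(), it2.Value())
	}
	_ = it.Close()
	_ = it2.Close()

	// Flush the nested writes down to the parent store.
	outer.Write()
	inner.Write()
}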