diff --git a/CHANGELOG.md b/CHANGELOG.md
index b303a9c90..c44adc479 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,10 @@
 
 ## Unreleased
 
+### Improvements
+
+- [memdb] [\#53](https://github.com/tendermint/tm-db/pull/53) Use a B-tree for storage, which significantly improves range scan performance
+
 ## 0.4.1
 
 **2020-2-26**
diff --git a/backend_test.go b/backend_test.go
index 9bba5f290..1d530ecb0 100644
--- a/backend_test.go
+++ b/backend_test.go
@@ -204,6 +204,7 @@ func TestDBIterator(t *testing.T) {
 	for dbType := range backends {
 		t.Run(fmt.Sprintf("%v", dbType), func(t *testing.T) {
 			testDBIterator(t, dbType)
+			testDBIteratorBlankKey(t, dbType)
 		})
 	}
 }
@@ -311,6 +312,18 @@ func testDBIterator(t *testing.T, backend BackendType) {
 
 	verifyIterator(t, ritr, []int64(nil), "reverse iterator from 7 (ex) to 6")
+	ritr, err = db.ReverseIterator(int642Bytes(10), nil)
+	require.NoError(t, err)
+	verifyIterator(t, ritr, []int64(nil), "reverse iterator to 10")
+
+	ritr, err = db.ReverseIterator(int642Bytes(6), nil)
+	require.NoError(t, err)
+	verifyIterator(t, ritr, []int64{9, 8, 7}, "reverse iterator to 6")
+
+	ritr, err = db.ReverseIterator(int642Bytes(5), nil)
+	require.NoError(t, err)
+	verifyIterator(t, ritr, []int64{9, 8, 7, 5}, "reverse iterator to 5")
+
 	// verifyIterator(t, db.Iterator(int642Bytes(0), int642Bytes(1)), []int64{0}, "forward iterator from 0 to 1")
 
 	ritr, err = db.ReverseIterator(int642Bytes(8), int642Bytes(9))
@@ -329,7 +342,56 @@ func testDBIterator(t *testing.T, backend BackendType) {
 
 	require.NoError(t, err)
 	verifyIterator(t, ritr, []int64(nil), "reverse iterator from 2 (ex) to 4")
+}
+
+func testDBIteratorBlankKey(t *testing.T, backend BackendType) {
+	name := fmt.Sprintf("test_%x", randStr(12))
+	dir := os.TempDir()
+	db := NewDB(name, backend, dir)
+	defer cleanupDBDir(dir, name)
+
+	err := db.Set([]byte(""), []byte{0})
+	require.NoError(t, err)
+	err = db.Set([]byte("a"), []byte{1})
+	require.NoError(t, err)
+	err = db.Set([]byte("b"), []byte{2})
+	require.NoError(t, err)
+
+	value, err := db.Get([]byte(""))
+	require.NoError(t, err)
+	assert.Equal(t, []byte{0}, value)
+
+	i, err := db.Iterator(nil, nil)
+	require.NoError(t, err)
+	verifyIteratorStrings(t, i, []string{"", "a", "b"}, "forward")
+	i, err = db.Iterator([]byte(""), nil)
+	require.NoError(t, err)
+	verifyIteratorStrings(t, i, []string{"", "a", "b"}, "forward from blank")
+
+	i, err = db.Iterator([]byte("a"), nil)
+	require.NoError(t, err)
+	verifyIteratorStrings(t, i, []string{"a", "b"}, "forward from a")
+
+	i, err = db.Iterator([]byte(""), []byte("b"))
+	require.NoError(t, err)
+	verifyIteratorStrings(t, i, []string{"", "a"}, "forward from blank to b")
+
+	i, err = db.ReverseIterator(nil, nil)
+	require.NoError(t, err)
+	verifyIteratorStrings(t, i, []string{"b", "a", ""}, "reverse")
+
+	i, err = db.ReverseIterator([]byte(""), nil)
+	require.NoError(t, err)
+	verifyIteratorStrings(t, i, []string{"b", "a", ""}, "reverse to blank")
+
+	i, err = db.ReverseIterator([]byte(""), []byte("a"))
+	require.NoError(t, err)
+	verifyIteratorStrings(t, i, []string{""}, "reverse to blank from a")
+
+	i, err = db.ReverseIterator([]byte("a"), nil)
+	require.NoError(t, err)
+	verifyIteratorStrings(t, i, []string{"b", "a"}, "reverse to a")
 }
 
 func verifyIterator(t *testing.T, itr Iterator, expected []int64, msg string) {
@@ -341,3 +403,13 @@ func verifyIterator(t *testing.T, itr Iterator, expected []int64, msg string) {
 	}
 	assert.Equal(t, expected, list, msg)
 }
+
+func verifyIteratorStrings(t *testing.T, itr Iterator, expected []string, msg string) {
+	var list []string
+	for itr.Valid() {
+		key := itr.Key()
+		list = append(list, string(key))
+		itr.Next()
+	}
+	assert.Equal(t, expected, list, msg)
+}
diff --git a/common_test.go b/common_test.go
index d17c49e00..895acb198 100644
--- a/common_test.go
+++ b/common_test.go
@@ -201,6 +201,38 @@ func (mockIterator) Error() error {
 func (mockIterator) Close() {
 }
 
+func benchmarkRangeScans(b *testing.B, db DB, dbSize int64) {
+	b.StopTimer()
+
+	rangeSize := int64(10000)
+	if dbSize < rangeSize {
+		b.Errorf("db size %v cannot be less than range size %v", dbSize, rangeSize)
+	}
+
+	for i := int64(0); i < dbSize; i++ {
+		bytes := int642Bytes(i)
+		err := db.Set(bytes, bytes)
+		if err != nil {
+			// require.NoError() is very expensive (according to profiler), so check manually
+			b.Fatal(b, err)
+		}
+	}
+	b.StartTimer()
+
+	for i := 0; i < b.N; i++ {
+		start := rand.Int63n(dbSize - rangeSize)
+		end := start + rangeSize
+		iter, err := db.Iterator(int642Bytes(start), int642Bytes(end))
+		require.NoError(b, err)
+		count := 0
+		for ; iter.Valid(); iter.Next() {
+			count++
+		}
+		iter.Close()
+		require.EqualValues(b, rangeSize, count)
+	}
+}
+
 func benchmarkRandomReadsWrites(b *testing.B, db DB) {
 	b.StopTimer()
 
@@ -217,23 +249,29 @@ func benchmarkRandomReadsWrites(b *testing.B, db DB) {
 
 	for i := 0; i < b.N; i++ {
 		// Write something
 		{
-			idx := int64(rand.Int()) % numItems // nolint:gosec testing file, so accepting weak random number generator
+			idx := rand.Int63n(numItems) // nolint:gosec testing file, so accepting weak random number generator
 			internal[idx]++
 			val := internal[idx]
 			idxBytes := int642Bytes(idx)
 			valBytes := int642Bytes(val)
 			//fmt.Printf("Set %X -> %X\n", idxBytes, valBytes)
 			err := db.Set(idxBytes, valBytes)
-			b.Error(err)
+			if err != nil {
+				// require.NoError() is very expensive (according to profiler), so check manually
+				b.Fatal(b, err)
+			}
 		}
 		// Read something
 		{
-			idx := int64(rand.Int()) % numItems // nolint:gosec testing file, so accepting weak random number generator
+			idx := rand.Int63n(numItems) // nolint:gosec testing file, so accepting weak random number generator
 			valExp := internal[idx]
 			idxBytes := int642Bytes(idx)
 			valBytes, err := db.Get(idxBytes)
-			b.Error(err)
+			if err != nil {
+				// require.NoError() is very expensive (according to profiler), so check manually
+				b.Fatal(b, err)
+			}
 			//fmt.Printf("Get %X -> %X\n", idxBytes, valBytes)
 			if valExp == 0 {
 				if !bytes.Equal(valBytes, nil) {
diff --git a/go.mod b/go.mod
index 6d3062630..56e87e583 100644
--- a/go.mod
+++ b/go.mod
@@ -9,6 +9,7 @@ require (
 	github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 // indirect
 	github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870 // indirect
 	github.com/gogo/protobuf v1.3.1
+	github.com/google/btree v1.0.0
 	github.com/jmhodges/levigo v1.0.0
 	github.com/pkg/errors v0.9.1
 	github.com/stretchr/testify v1.5.1
diff --git a/go.sum b/go.sum
index 3c3d2e5ce..5b70eeea6 100644
--- a/go.sum
+++ b/go.sum
@@ -28,6 +28,8 @@ github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs
 github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
 github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
+github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
+github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
 github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=
 github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
 github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
diff --git a/mem_db.go b/mem_db.go
index a88e764ef..286145810 100644
--- a/mem_db.go
+++ b/mem_db.go
@@ -1,11 +1,47 @@
 package db
 
 import (
+	"bytes"
+	"context"
 	"fmt"
-	"sort"
 	"sync"
+
+	"github.com/google/btree"
+)
+
+const (
+	// The approximate number of items and children per B-tree node. Tuned with benchmarks.
+	bTreeDegree = 32
+
+	// Size of the channel buffer between traversal goroutine and iterator. Using an unbuffered
+	// channel causes two context switches per item sent, while buffering allows more work per
+	// context switch. Tuned with benchmarks.
+	chBufferSize = 64
 )
 
+// item is a btree.Item with byte slices as keys and values
+type item struct {
+	key   []byte
+	value []byte
+}
+
+// Less implements btree.Item.
+func (i *item) Less(other btree.Item) bool {
+	// this considers nil == []byte{}, but that's ok since we handle nil endpoints
+	// in iterators specially anyway
+	return bytes.Compare(i.key, other.(*item).key) == -1
+}
+
+// newKey creates a new key item
+func newKey(key []byte) *item {
+	return &item{key: key}
+}
+
+// newPair creates a new pair item
+func newPair(key, value []byte) *item {
+	return &item{key: key, value: value}
+}
+
 func init() {
 	registerDBCreator(MemDBBackend, func(name, dir string) (DB, error) {
 		return NewMemDB(), nil
@@ -15,13 +51,13 @@ func init() {
 var _ DB = (*MemDB)(nil)
 
 type MemDB struct {
-	mtx sync.Mutex
-	db  map[string][]byte
+	mtx   sync.Mutex
+	btree *btree.BTree
 }
 
 func NewMemDB() *MemDB {
 	database := &MemDB{
-		db: make(map[string][]byte),
+		btree: btree.New(bTreeDegree),
 	}
 	return database
 }
@@ -37,8 +73,11 @@ func (db *MemDB) Get(key []byte) ([]byte, error) {
 	defer db.mtx.Unlock()
 	key = nonNilBytes(key)
 
-	value := db.db[string(key)]
-	return value, nil
+	i := db.btree.Get(newKey(key))
+	if i != nil {
+		return i.(*item).value, nil
+	}
+	return nil, nil
 }
 
 // Implements DB.
@@ -47,8 +86,7 @@ func (db *MemDB) Has(key []byte) (bool, error) {
 	defer db.mtx.Unlock()
 	key = nonNilBytes(key)
 
-	_, ok := db.db[string(key)]
-	return ok, nil
+	return db.btree.Has(newKey(key)), nil
 }
 
 // Implements DB.
@@ -79,7 +117,7 @@ func (db *MemDB) SetNoLockSync(key []byte, value []byte) {
 	key = nonNilBytes(key)
 	value = nonNilBytes(value)
 
-	db.db[string(key)] = value
+	db.btree.ReplaceOrInsert(newPair(key, value))
 }
 
 // Implements DB.
@@ -109,7 +147,7 @@ func (db *MemDB) DeleteNoLock(key []byte) {
 func (db *MemDB) DeleteNoLockSync(key []byte) {
 	key = nonNilBytes(key)
 
-	delete(db.db, string(key))
+	db.btree.Delete(newKey(key))
 }
 
 // Implements DB.
@@ -127,9 +165,11 @@ func (db *MemDB) Print() error {
 	db.mtx.Lock()
 	defer db.mtx.Unlock()
 
-	for key, value := range db.db {
-		fmt.Printf("[%X]:\t[%X]\n", []byte(key), value)
-	}
+	db.btree.Ascend(func(i btree.Item) bool {
+		item := i.(*item)
+		fmt.Printf("[%X]:\t[%X]\n", item.key, item.value)
+		return true
+	})
 	return nil
 }
 
@@ -140,15 +180,12 @@ func (db *MemDB) Stats() map[string]string {
 
 	stats := make(map[string]string)
 	stats["database.type"] = "memDB"
-	stats["database.size"] = fmt.Sprintf("%d", len(db.db))
+	stats["database.size"] = fmt.Sprintf("%d", db.btree.Len())
 	return stats
 }
 
 // Implements DB.
 func (db *MemDB) NewBatch() Batch {
-	db.mtx.Lock()
-	defer db.mtx.Unlock()
-
 	return &memBatch{db, nil}
 }
 
@@ -160,8 +197,7 @@ func (db *MemDB) Iterator(start, end []byte) (Iterator, error) {
 	db.mtx.Lock()
 	defer db.mtx.Unlock()
 
-	keys := db.getSortedKeys(start, end, false)
-	return newMemDBIterator(db, keys, start, end), nil
+	return newMemDBIterator(db.btree, start, end, false), nil
 }
 
 // Implements DB.
@@ -169,101 +205,133 @@ func (db *MemDB) ReverseIterator(start, end []byte) (Iterator, error) {
 	db.mtx.Lock()
 	defer db.mtx.Unlock()
 
-	keys := db.getSortedKeys(start, end, true)
-	return newMemDBIterator(db, keys, start, end), nil
+	return newMemDBIterator(db.btree, start, end, true), nil
 }
 
-// We need a copy of all of the keys.
-// Not the best, but probably not a bottleneck depending.
 type memDBIterator struct {
-	db    DB
-	cur   int
-	keys  []string
-	start []byte
-	end   []byte
+	ch     <-chan *item
+	cancel context.CancelFunc
+	item   *item
+	start  []byte
+	end    []byte
 }
 
 var _ Iterator = (*memDBIterator)(nil)
 
-// Keys is expected to be in reverse order for reverse iterators.
-func newMemDBIterator(db DB, keys []string, start, end []byte) *memDBIterator {
-	return &memDBIterator{
-		db:    db,
-		cur:   0,
-		keys:  keys,
-		start: start,
-		end:   end,
+func newMemDBIterator(bt *btree.BTree, start []byte, end []byte, reverse bool) *memDBIterator {
+	ctx, cancel := context.WithCancel(context.Background())
+	ch := make(chan *item, chBufferSize)
+	iter := &memDBIterator{
+		ch:     ch,
+		cancel: cancel,
+		start:  start,
+		end:    end,
 	}
-}
 
-// Implements Iterator.
-func (itr *memDBIterator) Domain() ([]byte, []byte) {
-	return itr.start, itr.end
-}
+	go func() {
+		// Because we use [start, end) for reverse ranges, while btree uses (start, end], we need
+		// the following variables to handle some reverse iteration conditions ourselves.
+		var (
+			skipEqual     []byte
+			abortLessThan []byte
+		)
+		visitor := func(i btree.Item) bool {
+			item := i.(*item)
+			if skipEqual != nil && bytes.Equal(item.key, skipEqual) {
+				skipEqual = nil
+				return true
+			}
+			if abortLessThan != nil && bytes.Compare(item.key, abortLessThan) == -1 {
+				return false
+			}
+			select {
+			case <-ctx.Done():
+				return false
+			case ch <- item:
+				return true
+			}
+		}
+		s := newKey(start)
+		e := newKey(end)
+		switch {
+		case start == nil && end == nil && !reverse:
+			bt.Ascend(visitor)
+		case start == nil && end == nil && reverse:
+			bt.Descend(visitor)
+		case end == nil && !reverse:
+			// must handle this specially, since nil is considered less than anything else
+			bt.AscendGreaterOrEqual(s, visitor)
+		case !reverse:
+			bt.AscendRange(s, e, visitor)
+		case end == nil:
+			// abort after start, since we use [start, end) while btree uses (start, end]
+			abortLessThan = s.key
+			bt.Descend(visitor)
+		default:
+			// skip end and abort after start, since we use [start, end) while btree uses (start, end]
+			skipEqual = e.key
+			abortLessThan = s.key
+			bt.DescendLessOrEqual(e, visitor)
+		}
+		close(ch)
+	}()
+
+	// prime the iterator with the first value, if any
+	if item, ok := <-ch; ok {
+		iter.item = item
+	}
 
-// Implements Iterator.
-func (itr *memDBIterator) Valid() bool {
-	return 0 <= itr.cur && itr.cur < len(itr.keys)
+	return iter
 }
 
-// Implements Iterator.
-func (itr *memDBIterator) Next() {
-	itr.assertIsValid()
-	itr.cur++
+// Close implements Iterator.
+func (i *memDBIterator) Close() {
+	i.cancel()
+	for range i.ch { // drain channel
+	}
+	i.item = nil
 }
 
-// Implements Iterator.
-func (itr *memDBIterator) Key() []byte {
-	itr.assertIsValid()
-	return []byte(itr.keys[itr.cur])
+// Domain implements Iterator.
+func (i *memDBIterator) Domain() ([]byte, []byte) {
+	return i.start, i.end
 }
 
-// Implements Iterator.
-func (itr *memDBIterator) Value() []byte {
-	itr.assertIsValid()
-	key := []byte(itr.keys[itr.cur])
-	bytes, err := itr.db.Get(key)
-	if err != nil {
-		return nil
-	}
-	return bytes
+// Valid implements Iterator.
+func (i *memDBIterator) Valid() bool {
+	return i.item != nil
 }
 
-func (itr *memDBIterator) Error() error {
-	return nil
+// Next implements Iterator.
+func (i *memDBIterator) Next() {
+	item, ok := <-i.ch
+	switch {
+	case ok:
+		i.item = item
+	case i.item == nil:
+		panic("called Next() on invalid iterator")
+	default:
+		i.item = nil
+	}
 }
 
-// Implements Iterator.
-func (itr *memDBIterator) Close() {
-	itr.keys = nil
-	itr.db = nil
+// Error implements Iterator.
+func (i *memDBIterator) Error() error {
+	return nil // famous last words
 }
 
-func (itr *memDBIterator) assertIsValid() {
-	if !itr.Valid() {
-		panic("memDBIterator is invalid")
+// Key implements Iterator.
+func (i *memDBIterator) Key() []byte {
+	if i.item == nil {
+		panic("called Key() on invalid iterator")
 	}
+	return i.item.key
 }
 
-//----------------------------------------
-// Misc.
-
-func (db *MemDB) getSortedKeys(start, end []byte, reverse bool) []string {
-	keys := []string{}
-	for key := range db.db {
-		inDomain := IsKeyInDomain([]byte(key), start, end)
-		if inDomain {
-			keys = append(keys, key)
-		}
-	}
-	sort.Strings(keys)
-	if reverse {
-		nkeys := len(keys)
-		for i := 0; i < nkeys/2; i++ {
-			temp := keys[i]
-			keys[i] = keys[nkeys-i-1]
-			keys[nkeys-i-1] = temp
-		}
+// Value implements Iterator.
+func (i *memDBIterator) Value() []byte {
+	if i.item == nil {
+		panic("called Value() on invalid iterator")
 	}
-	return keys
+	return i.item.value
 }
diff --git a/mem_db_test.go b/mem_db_test.go
index 7f6468ee6..ee2eab9a7 100644
--- a/mem_db_test.go
+++ b/mem_db_test.go
@@ -32,3 +32,24 @@ func TestMemDB_Iterator(t *testing.T) {
 	itr.Next()
 	assert.False(t, itr.Valid())
 }
+
+func BenchmarkMemDBRangeScans1M(b *testing.B) {
+	db := NewMemDB()
+	defer db.Close()
+
+	benchmarkRangeScans(b, db, int64(1e6))
+}
+
+func BenchmarkMemDBRangeScans10M(b *testing.B) {
+	db := NewMemDB()
+	defer db.Close()
+
+	benchmarkRangeScans(b, db, int64(10e6))
+}
+
+func BenchmarkMemDBRandomReadsWrites(b *testing.B) {
+	db := NewMemDB()
+	defer db.Close()
+
+	benchmarkRandomReadsWrites(b, db)
+}