Skip to content

Commit

Permalink
statistics: support building CMSketch with Top N (#10163)
Browse files Browse the repository at this point in the history
  • Loading branch information
erjiaqing authored and zz-jason committed Apr 24, 2019
1 parent a28d877 commit fa2d6f0
Show file tree
Hide file tree
Showing 6 changed files with 337 additions and 25 deletions.
198 changes: 185 additions & 13 deletions statistics/cmsketch.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,41 @@
package statistics

import (
"bytes"
"math"
"sort"

"github.com/cznic/mathutil"
"github.com/cznic/sortutil"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tipb/go-tipb"
"github.com/spaolacci/murmur3"
)

// topNThreshold is the minimum ratio of the number of topn elements in CMSketch, 10 means 1 / 10 = 10%.
const topNThreshold = uint64(10)

// CMSketch is used to estimate point queries.
// Refer: https://en.wikipedia.org/wiki/Count-min_sketch
type CMSketch struct {
depth int32
width int32
depth int32
width int32
count uint64 // TopN is not counted in count
defaultValue uint64 // In sampled data, if cmsketch returns a small value (less than avg value / 2), then this will returned.
table [][]uint32
topN map[uint64][]topNMeta
}

// topNMeta is a simple counter used by BuildTopN
type topNMeta struct {
h1 uint64
h2 uint64
data []byte
count uint64
table [][]uint32
}

// NewCMSketch returns a new CM sketch.
Expand All @@ -44,29 +60,173 @@ func NewCMSketch(d, w int32) *CMSketch {
return &CMSketch{depth: d, width: w, table: tbl}
}

// topNHelper wraps some variables used when building cmsketch with top n.
type topNHelper struct {
sampleSize uint64
numTop uint32
counter map[hack.MutableString]uint64
sorted []uint64
onlyOnceItems uint64
sumTopN uint64
lastVal uint64
}

func newTopNHelper(sample [][]byte, numTop uint32) *topNHelper {
counter := make(map[hack.MutableString]uint64)
for i := range sample {
counter[hack.String(sample[i])]++
}
sorted, onlyOnceItems := make([]uint64, 0, len(counter)), uint64(0)
for _, cnt := range counter {
sorted = append(sorted, cnt)
if cnt == 1 {
onlyOnceItems++
}
}
sort.Slice(sorted, func(i, j int) bool {
return sorted[i] > sorted[j]
})

var (
// last is the last element in top N index should occurres atleast `last` times.
last uint64
sumTopN uint64
sampleNDV = uint32(len(sorted))
)
numTop = mathutil.MinUint32(sampleNDV, numTop) // Ensure numTop no larger than sampNDV.
// Only element whose frequency is not smaller than 2/3 multiples the
// frequency of the n-th element are added to the TopN statistics. We chose
// 2/3 as an empirical value because the average cardinality estimation
// error is relatively small compared with 1/2.
for i := uint32(0); i < sampleNDV && i < numTop*2; i++ {
if i >= numTop && sorted[i]*3 < sorted[numTop-1]*2 && last != sorted[i] {
break
}
last = sorted[i]
sumTopN += sorted[i]
}

return &topNHelper{uint64(len(sample)), numTop, counter, sorted, onlyOnceItems, sumTopN, last}
}

// NewCMSketchWithTopN returns a new CM sketch with TopN elements.
func NewCMSketchWithTopN(d, w int32, sample [][]byte, numTop uint32, rowCount uint64) *CMSketch {
helper := newTopNHelper(sample, numTop)
// rowCount is not a accurate value when fast analyzing
// In some cases, if user triggers fast analyze when rowCount is close to sampleSize, unexpected bahavior might happen.
rowCount = mathutil.MaxUint64(rowCount, uint64(len(sample)))
estimateNDV, scaleRatio := calculateEstimateNDV(helper, rowCount)
c := buildCMSWithTopN(helper, d, w, scaleRatio)
c.calculateDefaultVal(helper, estimateNDV, scaleRatio, rowCount)
return c
}

func buildCMSWithTopN(helper *topNHelper, d, w int32, scaleRatio uint64) (c *CMSketch) {
c, helper.sumTopN, helper.numTop = NewCMSketch(d, w), 0, 0
enableTopN := helper.sampleSize/topNThreshold <= helper.sumTopN
if enableTopN {
c.topN = make(map[uint64][]topNMeta)
}
for counterKey, cnt := range helper.counter {
data, scaledCount := hack.Slice(string(counterKey)), cnt*scaleRatio
if enableTopN && cnt >= helper.lastVal {
h1, h2 := murmur3.Sum128(data)
c.topN[h1] = append(c.topN[h1], topNMeta{h1, h2, data, scaledCount})
helper.sumTopN += scaledCount
helper.numTop++
} else {
c.updateBytesWithDelta(data, scaledCount)
}
}
return
}

func (c *CMSketch) calculateDefaultVal(helper *topNHelper, estimateNDV, scaleRatio, rowCount uint64) {
sampleNDV := uint64(len(helper.sorted))
if rowCount <= helper.sumTopN {
c.defaultValue = 1
} else if estimateNDV <= uint64(helper.numTop) {
c.defaultValue = 1
} else if estimateNDV+helper.onlyOnceItems <= uint64(sampleNDV) {
c.defaultValue = 1
} else {
estimateRemainingCount := rowCount - (helper.sampleSize-uint64(helper.onlyOnceItems))*scaleRatio
c.defaultValue = estimateRemainingCount / (estimateNDV - uint64(sampleNDV) + helper.onlyOnceItems)
}
}

// queryAddTopN TopN adds count to CMSketch.topN if exists, and returns the count of such elements after insert.
// If such elements does not in topn elements, nothing will happen and false will be returned.
func (c *CMSketch) updateTopNWithDelta(h1, h2 uint64, d []byte, delta uint64) bool {
if c.topN == nil {
return false
}
for _, cnt := range c.topN[h1] {
if cnt.h2 == h2 && bytes.Equal(d, cnt.data) {
cnt.count += delta
return true
}
}
return false
}

func (c *CMSketch) queryTopN(h1, h2 uint64, d []byte) (uint64, bool) {
if c.topN == nil {
return 0, false
}
for _, cnt := range c.topN[h1] {
if cnt.h2 == h2 && bytes.Equal(d, cnt.data) {
return cnt.count, true
}
}
return 0, false
}

// InsertBytes inserts the bytes value into the CM Sketch.
func (c *CMSketch) InsertBytes(bytes []byte) {
c.count++
c.updateBytesWithDelta(bytes, 1)
}

// updateBytesWithDelta adds the bytes value into the CM Sketch by delta.
func (c *CMSketch) updateBytesWithDelta(bytes []byte, delta uint64) {
h1, h2 := murmur3.Sum128(bytes)
if c.updateTopNWithDelta(h1, h2, bytes, delta) {
return
}
c.count += delta
for i := range c.table {
j := (h1 + h2*uint64(i)) % uint64(c.width)
c.table[i][j]++
c.table[i][j] += uint32(delta)
}
}

func (c *CMSketch) considerDefVal(cnt uint64) bool {
return cnt < 2*(c.count/uint64(c.width)) && c.defaultValue > 0
}

// setValue sets the count for value that hashed into (h1, h2).
func (c *CMSketch) setValue(h1, h2 uint64, count uint32) {
oriCount := c.queryHashValue(h1, h2)
c.count += uint64(count) - uint64(oriCount)
if c.considerDefVal(oriCount) {
// We should update c.defaultValue if we used c.defaultValue when getting the estimate count.
// This should make estimation better, remove this line if it does not work as expected.
c.defaultValue = uint64(float64(c.defaultValue)*0.95 + float64(c.defaultValue)*0.05)
if c.defaultValue == 0 {
// c.defaultValue never guess 0 since we are using a sampled data.
c.defaultValue = 1
}
}

c.count += uint64(count) - oriCount
// let it overflow naturally
deltaCount := count - oriCount
deltaCount := count - uint32(oriCount)
for i := range c.table {
j := (h1 + h2*uint64(i)) % uint64(c.width)
c.table[i][j] = c.table[i][j] + deltaCount
}
}

func (c *CMSketch) queryValue(sc *stmtctx.StatementContext, val types.Datum) (uint32, error) {
func (c *CMSketch) queryValue(sc *stmtctx.StatementContext, val types.Datum) (uint64, error) {
bytes, err := codec.EncodeValue(sc, nil, val)
if err != nil {
return 0, errors.Trace(err)
Expand All @@ -75,12 +235,15 @@ func (c *CMSketch) queryValue(sc *stmtctx.StatementContext, val types.Datum) (ui
}

// QueryBytes is used to query the count of specified bytes.
func (c *CMSketch) QueryBytes(bytes []byte) uint32 {
h1, h2 := murmur3.Sum128(bytes)
func (c *CMSketch) QueryBytes(d []byte) uint64 {
h1, h2 := murmur3.Sum128(d)
if count, ok := c.queryTopN(h1, h2, d); ok {
return count
}
return c.queryHashValue(h1, h2)
}

func (c *CMSketch) queryHashValue(h1, h2 uint64) uint32 {
func (c *CMSketch) queryHashValue(h1, h2 uint64) uint64 {
vals := make([]uint32, c.depth)
min := uint32(math.MaxUint32)
for i := range c.table {
Expand All @@ -98,16 +261,23 @@ func (c *CMSketch) queryHashValue(h1, h2 uint64) uint32 {
sort.Sort(sortutil.Uint32Slice(vals))
res := vals[(c.depth-1)/2] + (vals[c.depth/2]-vals[(c.depth-1)/2])/2
if res > min {
return min
res = min
}
if c.considerDefVal(uint64(res)) {
return c.defaultValue
}
return res
return uint64(res)
}

// MergeCMSketch merges two CM Sketch.
// Call with CMSketch with Top-N initialized may downgrade the result
func (c *CMSketch) MergeCMSketch(rc *CMSketch) error {
if c.depth != rc.depth || c.width != rc.width {
return errors.New("Dimensions of Count-Min Sketch should be the same")
}
if c.topN != nil || rc.topN != nil {
return errors.New("CMSketch with Top-N does not support merge")
}
c.count += rc.count
for i := range c.table {
for j := range c.table[i] {
Expand All @@ -118,6 +288,7 @@ func (c *CMSketch) MergeCMSketch(rc *CMSketch) error {
}

// CMSketchToProto converts CMSketch to its protobuf representation.
// TODO: Encode/Decode cmsketch with Top-N
func CMSketchToProto(c *CMSketch) *tipb.CMSketch {
protoSketch := &tipb.CMSketch{Rows: make([]*tipb.CMSketchRow, c.depth)}
for i := range c.table {
Expand All @@ -130,6 +301,7 @@ func CMSketchToProto(c *CMSketch) *tipb.CMSketch {
}

// CMSketchFromProto converts CMSketch from its protobuf representation.
// TODO: Encode/Decode cmsketch with Top-N
func CMSketchFromProto(protoSketch *tipb.CMSketch) *CMSketch {
if protoSketch == nil {
return nil
Expand Down
Loading

0 comments on commit fa2d6f0

Please sign in to comment.