Skip to content

Commit

Permalink
expression: vectorize hash calculation during probing (pingcap#12048) (
Browse files Browse the repository at this point in the history
  • Loading branch information
sduzh authored and XiaTianliang committed Dec 21, 2019
1 parent a5ab5f9 commit 77234a7
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 19 deletions.
47 changes: 39 additions & 8 deletions executor/benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,31 +530,35 @@ func BenchmarkWindowFunctions(b *testing.B) {

type hashJoinTestCase struct {
rows int
cols []*types.FieldType
concurrency int
ctx sessionctx.Context
keyIdx []int
disk bool
}

func (tc hashJoinTestCase) columns() []*expression.Column {
return []*expression.Column{
{Index: 0, RetType: types.NewFieldType(mysql.TypeLonglong)},
{Index: 1, RetType: types.NewFieldType(mysql.TypeVarString)},
ret := make([]*expression.Column, 0)
for i, t := range tc.cols {
column := &expression.Column{Index: i, RetType: t}
ret = append(ret, column)
}
return ret
}

func (tc hashJoinTestCase) String() string {
return fmt.Sprintf("(rows:%v, concurency:%v, joinKeyIdx: %v, disk:%v)",
tc.rows, tc.concurrency, tc.keyIdx, tc.disk)
return fmt.Sprintf("(rows:%v, cols:%v, concurency:%v, joinKeyIdx: %v, disk:%v)",
tc.rows, tc.cols, tc.concurrency, tc.keyIdx, tc.disk)
}

func defaultHashJoinTestCase() *hashJoinTestCase {
func defaultHashJoinTestCase(cols []*types.FieldType) *hashJoinTestCase {
ctx := mock.NewContext()
ctx.GetSessionVars().InitChunkSize = variable.DefInitChunkSize
ctx.GetSessionVars().MaxChunkSize = variable.DefMaxChunkSize
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(nil, -1)
ctx.GetSessionVars().IndexLookupJoinConcurrency = 4
tc := &hashJoinTestCase{rows: 100000, concurrency: 4, ctx: ctx, keyIdx: []int{0, 1}}
tc.cols = cols
return tc
}

Expand Down Expand Up @@ -606,6 +610,8 @@ func benchmarkHashJoinExecWithCase(b *testing.B, casTest *hashJoinTestCase) {
return int64(row)
case mysql.TypeVarString:
return rawData
case mysql.TypeDouble:
return float64(row)
default:
panic("not implement")
}
Expand Down Expand Up @@ -651,8 +657,13 @@ func BenchmarkHashJoinExec(b *testing.B) {
log.SetLevel(zapcore.ErrorLevel)
defer log.SetLevel(lvl)

cols := []*types.FieldType{
types.NewFieldType(mysql.TypeLonglong),
types.NewFieldType(mysql.TypeVarString),
}

b.ReportAllocs()
cas := defaultHashJoinTestCase()
cas := defaultHashJoinTestCase(cols)
b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) {
benchmarkHashJoinExecWithCase(b, cas)
})
Expand All @@ -674,6 +685,21 @@ func BenchmarkHashJoinExec(b *testing.B) {
b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) {
benchmarkHashJoinExecWithCase(b, cas)
})

// Replace the wide string column with double column
cols = []*types.FieldType{
types.NewFieldType(mysql.TypeLonglong),
types.NewFieldType(mysql.TypeDouble),
}
cas = defaultHashJoinTestCase(cols)
b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) {
benchmarkHashJoinExecWithCase(b, cas)
})

cas.keyIdx = []int{0}
b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) {
benchmarkHashJoinExecWithCase(b, cas)
})
}

func benchmarkBuildHashTableForList(b *testing.B, casTest *hashJoinTestCase) {
Expand Down Expand Up @@ -732,8 +758,13 @@ func BenchmarkBuildHashTableForList(b *testing.B) {
log.SetLevel(zapcore.ErrorLevel)
defer log.SetLevel(lvl)

cols := []*types.FieldType{
types.NewFieldType(mysql.TypeLonglong),
types.NewFieldType(mysql.TypeVarString),
}

b.ReportAllocs()
cas := defaultHashJoinTestCase()
cas := defaultHashJoinTestCase(cols)
rows := []int{10, 100000}
keyIdxs := [][]int{{0, 1}, {0}}
disks := []bool{false, true}
Expand Down
8 changes: 2 additions & 6 deletions executor/hash_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,8 @@ func newHashRowContainer(sCtx sessionctx.Context, estCount int, hCtx *hashContex
// GetMatchedRows get matched rows from probeRow. It can be called
// in multiple goroutines while each goroutine should keep its own
// h and buf.
func (c *hashRowContainer) GetMatchedRows(probeRow chunk.Row, hCtx *hashContext) (matched []chunk.Row, err error) {
hasNull, key, err := c.getJoinKeyFromChkRow(c.sc, probeRow, hCtx)
if err != nil || hasNull {
return
}
innerPtrs := c.hashTable.Get(key)
func (c *hashRowContainer) GetMatchedRows(probeKey uint64, probeRow chunk.Row, hCtx *hashContext) (matched []chunk.Row, err error) {
innerPtrs := c.hashTable.Get(probeKey)
if len(innerPtrs) == 0 {
return
}
Expand Down
2 changes: 1 addition & 1 deletion executor/hash_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func (s *pkgTestSuite) testHashRowContainer(c *C, hashFunc func() hash.Hash64, s
}
probeCtx.hasNull = make([]bool, 1)
probeCtx.hashVals = append(hCtx.hashVals, hashFunc())
matched, err := rowContainer.GetMatchedRows(probeRow, probeCtx)
matched, err := rowContainer.GetMatchedRows(hCtx.hashVals[1].Sum64(), probeRow, probeCtx)
c.Assert(err, IsNil)
c.Assert(len(matched), Equals, 2)
c.Assert(matched[0].GetDatumRow(colTypes), DeepEquals, chk0.GetRow(1).GetDatumRow(colTypes))
Expand Down
20 changes: 16 additions & 4 deletions executor/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/memory"
"github.com/pingcap/tidb/util/stringutil"
)
Expand Down Expand Up @@ -362,9 +363,9 @@ func (e *HashJoinExec) runJoinWorker(workerID uint, probeKeyColIdx []int) {
}
}

func (e *HashJoinExec) joinMatchedProbeSideRow2Chunk(workerID uint, probeSideRow chunk.Row, hCtx *hashContext,
func (e *HashJoinExec) joinMatchedProbeSideRow2Chunk(workerID uint, probeKey uint64, probeSideRow chunk.Row, hCtx *hashContext,
joinResult *hashjoinWorkerResult) (bool, *hashjoinWorkerResult) {
buildSideRows, err := e.rowContainer.GetMatchedRows(probeSideRow, hCtx)
buildSideRows, err := e.rowContainer.GetMatchedRows(probeKey, probeSideRow, hCtx)
if err != nil {
joinResult.err = err
return false, joinResult
Expand Down Expand Up @@ -419,11 +420,22 @@ func (e *HashJoinExec) join2Chunk(workerID uint, probeSideChk *chunk.Chunk, hCtx
joinResult.err = err
return false, joinResult
}

hCtx.initHash(probeSideChk.NumRows())
for _, i := range hCtx.keyColIdx {
err = codec.HashChunkSelected(e.rowContainer.sc, hCtx.hashVals, probeSideChk, hCtx.allTypes[i], i, hCtx.buf, hCtx.hasNull, selected)
if err != nil {
joinResult.err = err
return false, joinResult
}
}

for i := range selected {
if !selected[i] { // process unmatched probe side rows
if !selected[i] || hCtx.hasNull[i] { // process unmatched probe side rows
e.joiners[workerID].onMissMatch(false, probeSideChk.GetRow(i), joinResult.chk)
} else { // process matched probe side rows
ok, joinResult = e.joinMatchedProbeSideRow2Chunk(workerID, probeSideChk.GetRow(i), hCtx, joinResult)
probeKey, probeRow := hCtx.hashVals[i].Sum64(), probeSideChk.GetRow(i)
ok, joinResult = e.joinMatchedProbeSideRow2Chunk(workerID, probeKey, probeRow, hCtx, joinResult)
if !ok {
return false, joinResult
}
Expand Down
40 changes: 40 additions & 0 deletions util/codec/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,13 +360,23 @@ func encodeHashChunkRowIdx(sc *stmtctx.StatementContext, row chunk.Row, tp *type

// HashChunkColumns writes the encoded value of each row's column, which of index `colIdx`, to h.
func HashChunkColumns(sc *stmtctx.StatementContext, h []hash.Hash64, chk *chunk.Chunk, tp *types.FieldType, colIdx int, buf []byte, isNull []bool) (err error) {
return HashChunkSelected(sc, h, chk, tp, colIdx, buf, isNull, nil)
}

// HashChunkSelected writes the encoded value of selected row's column, which of index `colIdx`, to h.
// sel indicates which rows are selected. If it is nil, all rows are selected.
func HashChunkSelected(sc *stmtctx.StatementContext, h []hash.Hash64, chk *chunk.Chunk, tp *types.FieldType, colIdx int, buf []byte,
isNull, sel []bool) (err error) {
var b []byte
column := chk.Column(colIdx)
rows := chk.NumRows()
switch tp.Tp {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
i64s := column.Int64s()
for i, v := range i64s {
if sel != nil && !sel[i] {
continue
}
if column.IsNull(i) {
buf[0], b = NilFlag, nil
isNull[i] = true
Expand All @@ -386,6 +396,9 @@ func HashChunkColumns(sc *stmtctx.StatementContext, h []hash.Hash64, chk *chunk.
case mysql.TypeFloat:
f32s := column.Float32s()
for i, f := range f32s {
if sel != nil && !sel[i] {
continue
}
if column.IsNull(i) {
buf[0], b = NilFlag, nil
isNull[i] = true
Expand All @@ -403,6 +416,9 @@ func HashChunkColumns(sc *stmtctx.StatementContext, h []hash.Hash64, chk *chunk.
case mysql.TypeDouble:
f64s := column.Float64s()
for i, f := range f64s {
if sel != nil && !sel[i] {
continue
}
if column.IsNull(i) {
buf[0], b = NilFlag, nil
isNull[i] = true
Expand All @@ -418,6 +434,9 @@ func HashChunkColumns(sc *stmtctx.StatementContext, h []hash.Hash64, chk *chunk.
}
case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
for i := 0; i < rows; i++ {
if sel != nil && !sel[i] {
continue
}
if column.IsNull(i) {
buf[0], b = NilFlag, nil
isNull[i] = true
Expand All @@ -434,6 +453,9 @@ func HashChunkColumns(sc *stmtctx.StatementContext, h []hash.Hash64, chk *chunk.
case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
ts := column.Times()
for i, t := range ts {
if sel != nil && !sel[i] {
continue
}
if column.IsNull(i) {
buf[0], b = NilFlag, nil
isNull[i] = true
Expand Down Expand Up @@ -462,6 +484,9 @@ func HashChunkColumns(sc *stmtctx.StatementContext, h []hash.Hash64, chk *chunk.
}
case mysql.TypeDuration:
for i := 0; i < rows; i++ {
if sel != nil && !sel[i] {
continue
}
if column.IsNull(i) {
buf[0], b = NilFlag, nil
isNull[i] = true
Expand All @@ -479,6 +504,9 @@ func HashChunkColumns(sc *stmtctx.StatementContext, h []hash.Hash64, chk *chunk.
case mysql.TypeNewDecimal:
ds := column.Decimals()
for i, d := range ds {
if sel != nil && !sel[i] {
continue
}
if column.IsNull(i) {
buf[0], b = NilFlag, nil
isNull[i] = true
Expand All @@ -498,6 +526,9 @@ func HashChunkColumns(sc *stmtctx.StatementContext, h []hash.Hash64, chk *chunk.
}
case mysql.TypeEnum:
for i := 0; i < rows; i++ {
if sel != nil && !sel[i] {
continue
}
if column.IsNull(i) {
buf[0], b = NilFlag, nil
isNull[i] = true
Expand All @@ -514,6 +545,9 @@ func HashChunkColumns(sc *stmtctx.StatementContext, h []hash.Hash64, chk *chunk.
}
case mysql.TypeSet:
for i := 0; i < rows; i++ {
if sel != nil && !sel[i] {
continue
}
if column.IsNull(i) {
buf[0], b = NilFlag, nil
isNull[i] = true
Expand All @@ -530,6 +564,9 @@ func HashChunkColumns(sc *stmtctx.StatementContext, h []hash.Hash64, chk *chunk.
}
case mysql.TypeBit:
for i := 0; i < rows; i++ {
if sel != nil && !sel[i] {
continue
}
if column.IsNull(i) {
buf[0], b = NilFlag, nil
isNull[i] = true
Expand All @@ -548,6 +585,9 @@ func HashChunkColumns(sc *stmtctx.StatementContext, h []hash.Hash64, chk *chunk.
}
case mysql.TypeJSON:
for i := 0; i < rows; i++ {
if sel != nil && !sel[i] {
continue
}
if column.IsNull(i) {
buf[0], b = NilFlag, nil
isNull[i] = true
Expand Down

0 comments on commit 77234a7

Please sign in to comment.