executor: vectorize hash calculation in hashJoin (#12048) (#12076)

sduzh authored and sre-bot committed Sep 11, 2019
1 parent 963f182 commit d29751c

Showing 4 changed files with 311 additions and 19 deletions.
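This commit replaces the hash join build side's row-at-a-time key hashing (one shared fnv hasher, reset for every row) with a column-at-a-time pass that keeps one hash state per row and streams each key column into all of them in a single sweep. Below is a minimal, self-contained sketch of the two access patterns; the function names and toy []byte keys are illustrative only, not TiDB code:

```go
package main

import (
	"fmt"
	"hash"
	"hash/fnv"
)

// Row-at-a-time: one shared hasher, reset for every row.
func hashRows(keys [][]byte) []uint64 {
	h := fnv.New64()
	sums := make([]uint64, len(keys))
	for i, k := range keys {
		h.Reset()
		_, _ = h.Write(k)
		sums[i] = h.Sum64()
	}
	return sums
}

// Column-at-a-time: one hash state per row; each column feeds all rows in one sweep.
func hashColumns(cols [][][]byte, numRows int) []uint64 {
	hs := make([]hash.Hash64, numRows)
	for i := range hs {
		hs[i] = fnv.New64()
	}
	for _, col := range cols {
		for i := 0; i < numRows; i++ {
			_, _ = hs[i].Write(col[i])
		}
	}
	sums := make([]uint64, numRows)
	for i, h := range hs {
		sums[i] = h.Sum64()
	}
	return sums
}

func main() {
	// The same two logical rows, laid out row-wise and column-wise.
	rows := [][]byte{[]byte("a1"), []byte("b2")}
	cols := [][][]byte{{[]byte("a"), []byte("b")}, {[]byte("1"), []byte("2")}}
	fmt.Println(hashRows(rows))       // identical output to the line below,
	fmt.Println(hashColumns(cols, 2)) // because FNV streams bytes
}
```

Both functions return identical sums for the same logical rows; the column-wise shape is what lets the per-type dispatch in HashChunkColumns below run once per chunk column instead of once per row.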
51 changes: 37 additions & 14 deletions executor/hash_table.go
@@ -15,6 +15,7 @@ package executor
 
 import (
 	"hash"
+	"hash/fnv"
 
 	"github.com/pingcap/errors"
 	"github.com/pingcap/tidb/sessionctx"
@@ -47,8 +48,28 @@ const (
 type hashContext struct {
 	allTypes  []*types.FieldType
 	keyColIdx []int
-	h         hash.Hash64
 	buf       []byte
+	hashVals  []hash.Hash64
+	hasNull   []bool
 }
 
+func (hc *hashContext) initHash(rows int) {
+	if hc.buf == nil {
+		hc.buf = make([]byte, 1)
+	}
+
+	if len(hc.hashVals) < rows {
+		hc.hasNull = make([]bool, rows)
+		hc.hashVals = make([]hash.Hash64, rows)
+		for i := 0; i < rows; i++ {
+			hc.hashVals[i] = fnv.New64()
+		}
+	} else {
+		for i := 0; i < rows; i++ {
+			hc.hasNull[i] = false
+			hc.hashVals[i].Reset()
+		}
+	}
+}
+
 // hashRowContainer handles the rows and the hash map of a table.
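initHash reuses the per-row hasher slice across chunks and allocates only when a chunk has more rows than any previous one. A standalone sketch of that reuse contract, using a trimmed mirror of the struct above (illustrative, not the executor wiring):

```go
package main

import (
	"fmt"
	"hash"
	"hash/fnv"
)

// Trimmed mirror of the commit's hashContext, for illustration.
type hashContext struct {
	buf      []byte
	hashVals []hash.Hash64
	hasNull  []bool
}

func (hc *hashContext) initHash(rows int) {
	if hc.buf == nil {
		hc.buf = make([]byte, 1)
	}
	if len(hc.hashVals) < rows {
		// Grow: allocate fresh hashers for the larger chunk.
		hc.hasNull = make([]bool, rows)
		hc.hashVals = make([]hash.Hash64, rows)
		for i := 0; i < rows; i++ {
			hc.hashVals[i] = fnv.New64()
		}
	} else {
		// Reuse: reset the existing hashers instead of reallocating.
		for i := 0; i < rows; i++ {
			hc.hasNull[i] = false
			hc.hashVals[i].Reset()
		}
	}
}

func main() {
	hc := &hashContext{}
	for _, rows := range []int{1024, 512, 2048} {
		hc.initHash(rows)
		fmt.Printf("chunk of %4d rows -> len(hashVals) = %d\n", rows, len(hc.hashVals))
	}
	// Prints 1024, then 1024 again (reused), then 2048 (regrown).
}
```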
@@ -133,22 +154,24 @@ func (c *hashRowContainer) matchJoinKey(buildRow, probeRow chunk.Row, probeHCtx
 // value of hash table: RowPtr of the corresponding row
 func (c *hashRowContainer) PutChunk(chk *chunk.Chunk) error {
 	chkIdx := uint32(c.records.NumChunks())
-	c.records.Add(chk)
-	var (
-		hasNull bool
-		err     error
-		key     uint64
-	)
 	numRows := chk.NumRows()
-	for j := 0; j < numRows; j++ {
-		hasNull, key, err = c.getJoinKeyFromChkRow(c.sc, chk.GetRow(j), c.hCtx)
+
+	c.records.Add(chk)
+	c.hCtx.initHash(numRows)
+
+	hCtx := c.hCtx
+	for _, colIdx := range c.hCtx.keyColIdx {
+		err := codec.HashChunkColumns(c.sc, hCtx.hashVals, chk, hCtx.allTypes[colIdx], colIdx, hCtx.buf, hCtx.hasNull)
 		if err != nil {
 			return errors.Trace(err)
 		}
-		if hasNull {
+	}
+	for i := 0; i < numRows; i++ {
+		if c.hCtx.hasNull[i] {
 			continue
 		}
-		rowPtr := chunk.RowPtr{ChkIdx: chkIdx, RowIdx: uint32(j)}
+		key := c.hCtx.hashVals[i].Sum64()
+		rowPtr := chunk.RowPtr{ChkIdx: chkIdx, RowIdx: uint32(i)}
 		c.hashTable.Put(key, rowPtr)
 	}
 	return nil
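PutChunk is now two-phase: phase one streams every key column into all per-row hashers via HashChunkColumns; phase two walks the rows once, skips NULL keys, and inserts the finished sums. A toy standalone version of that shape (the map-based table, nil-as-NULL convention, and rowPtr struct are illustrative, not TiDB's hashTable or chunk types):

```go
package main

import (
	"fmt"
	"hash"
	"hash/fnv"
)

type rowPtr struct{ chkIdx, rowIdx uint32 }

func main() {
	// Two key columns over three rows; nil marks a NULL join key.
	cols := [][][]byte{
		{[]byte("a"), []byte("b"), nil},
		{[]byte("1"), []byte("2"), []byte("3")},
	}
	numRows := 3
	hashVals := make([]hash.Hash64, numRows)
	hasNull := make([]bool, numRows)
	for i := range hashVals {
		hashVals[i] = fnv.New64()
	}
	// Phase 1: column-wise, stream each key column into every row's hasher.
	for _, col := range cols {
		for i, v := range col {
			if v == nil {
				hasNull[i] = true
				continue
			}
			_, _ = hashVals[i].Write(v)
		}
	}
	// Phase 2: row-wise, insert non-NULL keys into the hash table.
	table := make(map[uint64][]rowPtr)
	for i := 0; i < numRows; i++ {
		if hasNull[i] {
			continue // a NULL key never matches, so it is not inserted
		}
		key := hashVals[i].Sum64()
		table[key] = append(table[key], rowPtr{chkIdx: 0, rowIdx: uint32(i)})
	}
	fmt.Println(len(table)) // 2: the row with the NULL key was skipped
}
```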
@@ -161,9 +184,9 @@ func (*hashRowContainer) getJoinKeyFromChkRow(sc *stmtctx.StatementContext, row
 			return true, 0, nil
 		}
 	}
-	hCtx.h.Reset()
-	err = codec.HashChunkRow(sc, hCtx.h, row, hCtx.allTypes, hCtx.keyColIdx, hCtx.buf)
-	return false, hCtx.h.Sum64(), err
+	hCtx.initHash(1)
+	err = codec.HashChunkRow(sc, hCtx.hashVals[0], row, hCtx.allTypes, hCtx.keyColIdx, hCtx.buf)
+	return false, hCtx.hashVals[0].Sum64(), err
 }
 
 func (c hashRowContainer) Len() int {
5 changes: 0 additions & 5 deletions executor/join.go
@@ -16,7 +16,6 @@ package executor
 import (
 	"context"
 	"fmt"
-	"hash/fnv"
 	"sync"
 	"sync/atomic"
 
@@ -334,8 +333,6 @@ func (e *HashJoinExec) runJoinWorker(workerID uint, outerKeyColIdx []int) {
 	hCtx := &hashContext{
 		allTypes:  retTypes(e.outerExec),
 		keyColIdx: outerKeyColIdx,
-		h:         fnv.New64(),
-		buf:       make([]byte, 1),
 	}
 	for ok := true; ok; {
 		if e.finished.Load().(bool) {
@@ -506,8 +503,6 @@ func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk)
 	hCtx := &hashContext{
 		allTypes:  allTypes,
 		keyColIdx: innerKeyColIdx,
-		h:         fnv.New64(),
-		buf:       make([]byte, 1),
 	}
 	initList := chunk.NewList(allTypes, e.initCap, e.maxChunkSize)
 	e.rowContainer = newHashRowContainer(e.ctx, int(e.innerEstCount), hCtx, initList)
215 changes: 215 additions & 0 deletions util/codec/codec.go
@@ -16,6 +16,7 @@ package codec
 import (
 	"bytes"
 	"encoding/binary"
+	"hash"
 	"io"
 	"time"
 	"unsafe"
@@ -45,6 +46,11 @@
 	maxFlag byte = 250
 )
 
+const (
+	sizeUint64  = unsafe.Sizeof(uint64(0))
+	sizeFloat64 = unsafe.Sizeof(float64(0))
+)
+
 func preRealloc(b []byte, vals []types.Datum, comparable bool) []byte {
 	var size int
 	for i := range vals {
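These sizes let HashChunkColumns reinterpret a float64 or a packed uint64 as its raw bytes with no copy. A small sketch of the pointer-cast trick next to a copying equivalent (the comparison assumes a little-endian host):

```go
package main

import (
	"encoding/binary"
	"fmt"
	"math"
	"unsafe"
)

const sizeFloat64 = unsafe.Sizeof(float64(0))

func main() {
	f := 3.14
	// Zero-copy byte view of the float64, as used in HashChunkColumns.
	b := (*[sizeFloat64]byte)(unsafe.Pointer(&f))[:]
	// Copying equivalent via math.Float64bits.
	safe := make([]byte, 8)
	binary.LittleEndian.PutUint64(safe, math.Float64bits(f))
	fmt.Println(string(b) == string(safe)) // true on little-endian hosts
}
```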
@@ -352,6 +358,215 @@ func encodeHashChunkRowIdx(sc *stmtctx.StatementContext, row chunk.Row, tp *type
 	return
 }
 
+// HashChunkColumns writes the encoded value of each row's column, whose index is `colIdx`, to h.
+func HashChunkColumns(sc *stmtctx.StatementContext, h []hash.Hash64, chk *chunk.Chunk, tp *types.FieldType, colIdx int, buf []byte, isNull []bool) (err error) {
+	var b []byte
+	column := chk.Column(colIdx)
+	rows := chk.NumRows()
+	switch tp.Tp {
+	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
+		i64s := column.Int64s()
+		for i, v := range i64s {
+			if column.IsNull(i) {
+				buf[0], b = NilFlag, nil
+				isNull[i] = true
+			} else {
+				buf[0] = varintFlag
+				if mysql.HasUnsignedFlag(tp.Flag) && v < 0 {
+					buf[0] = uvarintFlag
+				}
+				b = column.GetRaw(i)
+			}
+
+			// As the Go docs describe, `Hash.Write` never returns an error.
+			// See https://golang.org/pkg/hash/#Hash
+			_, _ = h[i].Write(buf)
+			_, _ = h[i].Write(b)
+		}
+	case mysql.TypeFloat:
+		f32s := column.Float32s()
+		for i, f := range f32s {
+			if column.IsNull(i) {
+				buf[0], b = NilFlag, nil
+				isNull[i] = true
+			} else {
+				buf[0] = floatFlag
+				d := float64(f)
+				b = (*[sizeFloat64]byte)(unsafe.Pointer(&d))[:]
+			}
+
+			// As the Go docs describe, `Hash.Write` never returns an error.
+			// See https://golang.org/pkg/hash/#Hash
+			_, _ = h[i].Write(buf)
+			_, _ = h[i].Write(b)
+		}
+	case mysql.TypeDouble:
+		f64s := column.Float64s()
+		for i, f := range f64s {
+			if column.IsNull(i) {
+				buf[0], b = NilFlag, nil
+				isNull[i] = true
+			} else {
+				buf[0] = floatFlag
+				b = (*[sizeFloat64]byte)(unsafe.Pointer(&f))[:]
+			}
+
+			// As the Go docs describe, `Hash.Write` never returns an error.
+			// See https://golang.org/pkg/hash/#Hash
+			_, _ = h[i].Write(buf)
+			_, _ = h[i].Write(b)
+		}
+	case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
+		for i := 0; i < rows; i++ {
+			if column.IsNull(i) {
+				buf[0], b = NilFlag, nil
+				isNull[i] = true
+			} else {
+				buf[0] = compactBytesFlag
+				b = column.GetBytes(i)
+			}
+
+			// As the Go docs describe, `Hash.Write` never returns an error.
+			// See https://golang.org/pkg/hash/#Hash
+			_, _ = h[i].Write(buf)
+			_, _ = h[i].Write(b)
+		}
+	case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
+		ts := column.Times()
+		for i, t := range ts {
+			if column.IsNull(i) {
+				buf[0], b = NilFlag, nil
+				isNull[i] = true
+			} else {
+				buf[0] = uintFlag
+				// Encoding a timestamp needs to consider the time zone.
+				// If it is not in UTC, transform it to UTC first.
+				if t.Type == mysql.TypeTimestamp && sc.TimeZone != time.UTC {
+					err = t.ConvertTimeZone(sc.TimeZone, time.UTC)
+					if err != nil {
+						return
+					}
+				}
+				var v uint64
+				v, err = t.ToPackedUint()
+				if err != nil {
+					return
+				}
+				b = (*[sizeUint64]byte)(unsafe.Pointer(&v))[:]
+			}
+
+			// As the Go docs describe, `Hash.Write` never returns an error.
+			// See https://golang.org/pkg/hash/#Hash
+			_, _ = h[i].Write(buf)
+			_, _ = h[i].Write(b)
+		}
+	case mysql.TypeDuration:
+		for i := 0; i < rows; i++ {
+			if column.IsNull(i) {
+				buf[0], b = NilFlag, nil
+				isNull[i] = true
+			} else {
+				buf[0] = durationFlag
+				// A duration may be negative, so we cannot encode it via String directly.
+				b = column.GetRaw(i)
+			}
+
+			// As the Go docs describe, `Hash.Write` never returns an error.
+			// See https://golang.org/pkg/hash/#Hash
+			_, _ = h[i].Write(buf)
+			_, _ = h[i].Write(b)
+		}
+	case mysql.TypeNewDecimal:
+		ds := column.Decimals()
+		for i, d := range ds {
+			if column.IsNull(i) {
+				buf[0], b = NilFlag, nil
+				isNull[i] = true
+			} else {
+				buf[0] = decimalFlag
+				// For hashing, we only consider the original value of this decimal and ignore its precision.
+				b, err = d.ToHashKey()
+				if err != nil {
+					return
+				}
+			}
+
+			// As the Go docs describe, `Hash.Write` never returns an error.
+			// See https://golang.org/pkg/hash/#Hash
+			_, _ = h[i].Write(buf)
+			_, _ = h[i].Write(b)
+		}
+	case mysql.TypeEnum:
+		for i := 0; i < rows; i++ {
+			if column.IsNull(i) {
+				buf[0], b = NilFlag, nil
+				isNull[i] = true
+			} else {
+				buf[0] = uvarintFlag
+				v := uint64(column.GetEnum(i).ToNumber())
+				b = (*[sizeUint64]byte)(unsafe.Pointer(&v))[:]
+			}
+
+			// As the Go docs describe, `Hash.Write` never returns an error.
+			// See https://golang.org/pkg/hash/#Hash
+			_, _ = h[i].Write(buf)
+			_, _ = h[i].Write(b)
+		}
+	case mysql.TypeSet:
+		for i := 0; i < rows; i++ {
+			if column.IsNull(i) {
+				buf[0], b = NilFlag, nil
+				isNull[i] = true
+			} else {
+				buf[0] = uvarintFlag
+				v := uint64(column.GetSet(i).ToNumber())
+				b = (*[sizeUint64]byte)(unsafe.Pointer(&v))[:]
+			}
+
+			// As the Go docs describe, `Hash.Write` never returns an error.
+			// See https://golang.org/pkg/hash/#Hash
+			_, _ = h[i].Write(buf)
+			_, _ = h[i].Write(b)
+		}
+	case mysql.TypeBit:
+		for i := 0; i < rows; i++ {
+			if column.IsNull(i) {
+				buf[0], b = NilFlag, nil
+				isNull[i] = true
+			} else {
+				// We do not need to handle errors here since convertToMysqlBit ensures the literal fits in a uint64.
+				buf[0] = uvarintFlag
+				v, err1 := types.BinaryLiteral(column.GetBytes(i)).ToInt(sc)
+				terror.Log(errors.Trace(err1))
+				b = (*[sizeUint64]byte)(unsafe.Pointer(&v))[:]
+			}
+
+			// As the Go docs describe, `Hash.Write` never returns an error.
+			// See https://golang.org/pkg/hash/#Hash
+			_, _ = h[i].Write(buf)
+			_, _ = h[i].Write(b)
+		}
+	case mysql.TypeJSON:
+		for i := 0; i < rows; i++ {
+			if column.IsNull(i) {
+				buf[0], b = NilFlag, nil
+				isNull[i] = true
+			} else {
+				buf[0] = jsonFlag
+				b = column.GetBytes(i)
+			}
+
+			// As the Go docs describe, `Hash.Write` never returns an error.
+			// See https://golang.org/pkg/hash/#Hash
+			_, _ = h[i].Write(buf)
+			_, _ = h[i].Write(b)
+		}
+	default:
+		return errors.Errorf("unsupport column type for encode %d", tp.Tp)
+	}
+	return
+}
+
 // HashChunkRow writes the encoded values to w.
 // If two rows are logically equal, it will generate the same bytes.
 func HashChunkRow(sc *stmtctx.StatementContext, w io.Writer, row chunk.Row, allTypes []*types.FieldType, colIdx []int, buf []byte) (err error) {
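For reference, a hedged usage sketch of the new codec.HashChunkColumns on a single BIGINT key column; the import paths and chunk helpers reflect TiDB as of this commit and should be treated as assumptions rather than a verified build:

```go
package main

import (
	"fmt"
	"hash"
	"hash/fnv"

	"github.com/pingcap/parser/mysql"
	"github.com/pingcap/tidb/sessionctx/stmtctx"
	"github.com/pingcap/tidb/types"
	"github.com/pingcap/tidb/util/chunk"
	"github.com/pingcap/tidb/util/codec"
)

func main() {
	tp := types.NewFieldType(mysql.TypeLonglong)
	chk := chunk.NewChunkWithCapacity([]*types.FieldType{tp}, 3)
	chk.AppendInt64(0, 7)
	chk.AppendNull(0)
	chk.AppendInt64(0, 42)

	n := chk.NumRows()
	hashVals := make([]hash.Hash64, n)
	for i := range hashVals {
		hashVals[i] = fnv.New64()
	}
	hasNull := make([]bool, n)
	buf := make([]byte, 1) // one-byte scratch for the type flag

	sc := new(stmtctx.StatementContext)
	if err := codec.HashChunkColumns(sc, hashVals, chk, tp, 0, buf, hasNull); err != nil {
		panic(err)
	}
	for i := 0; i < n; i++ {
		fmt.Printf("row %d: null=%v sum=%d\n", i, hasNull[i], hashVals[i].Sum64())
	}
}
```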