executor: vectorize hash calculation in hashJoin (#12048) #12076

Merged 2 commits on Sep 11, 2019
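In short, the build side used to hash one row at a time (a single fnv hash, Reset before each row); after this PR it keeps one hash.Hash64 per row and feeds each key column into all of them in a tight per-column loop. A minimal standalone sketch of that column-wise pattern (stdlib only; hashInt64Column is an illustrative helper, not code from the PR):

package main

import (
	"encoding/binary"
	"fmt"
	"hash"
	"hash/fnv"
)

// hashInt64Column feeds one int64 key column into the per-row hashes,
// column-at-a-time instead of row-at-a-time (illustrative only).
func hashInt64Column(hs []hash.Hash64, col []int64) {
	var b [8]byte
	for i, v := range col {
		binary.LittleEndian.PutUint64(b[:], uint64(v))
		hs[i].Write(b[:])
	}
}

func main() {
	col1 := []int64{1, 2, 3}    // first key column
	col2 := []int64{10, 20, 30} // second key column

	// One reusable hash per row, as hashContext.initHash sets up.
	hs := make([]hash.Hash64, len(col1))
	for i := range hs {
		hs[i] = fnv.New64()
	}

	// Each key column is consumed in one pass over the column.
	hashInt64Column(hs, col1)
	hashInt64Column(hs, col2)

	for i := range hs {
		fmt.Printf("row %d key hash: %d\n", i, hs[i].Sum64())
	}
}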
51 changes: 37 additions & 14 deletions executor/hash_table.go
@@ -15,6 +15,7 @@ package executor

 import (
 	"hash"
+	"hash/fnv"

 	"github.com/pingcap/errors"
 	"github.com/pingcap/tidb/sessionctx"
@@ -47,8 +48,28 @@ const (
 type hashContext struct {
 	allTypes  []*types.FieldType
 	keyColIdx []int
-	h         hash.Hash64
 	buf       []byte
+	hashVals  []hash.Hash64
+	hasNull   []bool
 }

+func (hc *hashContext) initHash(rows int) {
+	if hc.buf == nil {
+		hc.buf = make([]byte, 1)
+	}
+
+	if len(hc.hashVals) < rows {
+		hc.hasNull = make([]bool, rows)
+		hc.hashVals = make([]hash.Hash64, rows)
+		for i := 0; i < rows; i++ {
+			hc.hashVals[i] = fnv.New64()
+		}
+	} else {
+		for i := 0; i < rows; i++ {
+			hc.hasNull[i] = false
+			hc.hashVals[i].Reset()
+		}
+	}
+}

// hashRowContainer handles the rows and the hash map of a table.
@@ -133,22 +154,24 @@ func (c *hashRowContainer) matchJoinKey(buildRow, probeRow chunk.Row, probeHCtx
 // value of hash table: RowPtr of the corresponding row
 func (c *hashRowContainer) PutChunk(chk *chunk.Chunk) error {
 	chkIdx := uint32(c.records.NumChunks())
-	c.records.Add(chk)
-	var (
-		hasNull bool
-		err     error
-		key     uint64
-	)
 	numRows := chk.NumRows()
-	for j := 0; j < numRows; j++ {
-		hasNull, key, err = c.getJoinKeyFromChkRow(c.sc, chk.GetRow(j), c.hCtx)
+
+	c.records.Add(chk)
+	c.hCtx.initHash(numRows)
+
+	hCtx := c.hCtx
+	for _, colIdx := range c.hCtx.keyColIdx {
+		err := codec.HashChunkColumns(c.sc, hCtx.hashVals, chk, hCtx.allTypes[colIdx], colIdx, hCtx.buf, hCtx.hasNull)
 		if err != nil {
 			return errors.Trace(err)
 		}
-		if hasNull {
+	}
+	for i := 0; i < numRows; i++ {
+		if c.hCtx.hasNull[i] {
 			continue
 		}
-		rowPtr := chunk.RowPtr{ChkIdx: chkIdx, RowIdx: uint32(j)}
+		key := c.hCtx.hashVals[i].Sum64()
+		rowPtr := chunk.RowPtr{ChkIdx: chkIdx, RowIdx: uint32(i)}
 		c.hashTable.Put(key, rowPtr)
 	}
 	return nil
@@ -161,9 +184,9 @@ func (*hashRowContainer) getJoinKeyFromChkRow(sc *stmtctx.StatementContext, row
 			return true, 0, nil
 		}
 	}
-	hCtx.h.Reset()
-	err = codec.HashChunkRow(sc, hCtx.h, row, hCtx.allTypes, hCtx.keyColIdx, hCtx.buf)
-	return false, hCtx.h.Sum64(), err
+	hCtx.initHash(1)
+	err = codec.HashChunkRow(sc, hCtx.hashVals[0], row, hCtx.allTypes, hCtx.keyColIdx, hCtx.buf)
+	return false, hCtx.hashVals[0].Sum64(), err
 }

 func (c hashRowContainer) Len() int {
5 changes: 0 additions & 5 deletions executor/join.go
@@ -16,7 +16,6 @@ package executor
 import (
 	"context"
 	"fmt"
-	"hash/fnv"
 	"sync"
 	"sync/atomic"

@@ -334,8 +333,6 @@ func (e *HashJoinExec) runJoinWorker(workerID uint, outerKeyColIdx []int) {
 	hCtx := &hashContext{
 		allTypes:  retTypes(e.outerExec),
 		keyColIdx: outerKeyColIdx,
-		h:         fnv.New64(),
-		buf:       make([]byte, 1),
 	}
 	for ok := true; ok; {
 		if e.finished.Load().(bool) {
@@ -506,8 +503,6 @@ func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk)
 	hCtx := &hashContext{
 		allTypes:  allTypes,
 		keyColIdx: innerKeyColIdx,
-		h:         fnv.New64(),
-		buf:       make([]byte, 1),
 	}
 	initList := chunk.NewList(allTypes, e.initCap, e.maxChunkSize)
 	e.rowContainer = newHashRowContainer(e.ctx, int(e.innerEstCount), hCtx, initList)
215 changes: 215 additions & 0 deletions util/codec/codec.go
@@ -16,6 +16,7 @@ package codec
import (
	"bytes"
	"encoding/binary"
	"hash"
	"io"
	"time"
	"unsafe"
@@ -45,6 +46,11 @@
	maxFlag byte = 250
)

const (
	sizeUint64  = unsafe.Sizeof(uint64(0))
	sizeFloat64 = unsafe.Sizeof(float64(0))
)

func preRealloc(b []byte, vals []types.Datum, comparable bool) []byte {
	var size int
	for i := range vals {
@@ -352,6 +358,215 @@ func encodeHashChunkRowIdx(sc *stmtctx.StatementContext, row chunk.Row, tp *type
	return
}

// HashChunkColumns writes the encoded value of the column at index `colIdx` of each row to the corresponding hash in h.
func HashChunkColumns(sc *stmtctx.StatementContext, h []hash.Hash64, chk *chunk.Chunk, tp *types.FieldType, colIdx int, buf []byte, isNull []bool) (err error) {
	var b []byte
	column := chk.Column(colIdx)
	rows := chk.NumRows()
	switch tp.Tp {
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
		i64s := column.Int64s()
Member: How about:

for i := 0; i < rows; i++ {
	if column.IsNull(i) {
		h[i].Write([]byte{NilFlag})
		continue
	}
	h[i].Write(column.GetRaw(i))
}

Contributor (author): Do you mean there is no need to write the flag byte?

Contributor (author): Ignoring the flag byte would generate a different hash value from the one generated by HashChunkRow.

Member: Okay. Maybe we can optimize this in another PR:

  1. Optimize the functions we use to calculate the hash value.
  2. Vectorize the hash calculation for the outer table when performing a hash join.
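To see why the flag byte matters: HashChunkRow (used for the probe side) writes a flag byte before each value, so the column-wise build path must do the same or equal keys would hash differently on the two sides. A minimal standalone demonstration (stdlib only; the flag values are illustrative stand-ins for the util/codec constants):

package main

import (
	"fmt"
	"hash/fnv"
)

// Illustrative stand-ins for the codec flag bytes, not the real values.
const (
	nilFlag    byte = 0
	varintFlag byte = 8
)

func main() {
	key := []byte{42} // raw bytes of one join-key value

	// Row-wise hashing (as HashChunkRow does): flag byte, then value bytes.
	rowWise := fnv.New64()
	rowWise.Write([]byte{varintFlag})
	rowWise.Write(key)

	// Column-wise hashing writes the same flag byte, so hashes match.
	colWise := fnv.New64()
	colWise.Write([]byte{varintFlag})
	colWise.Write(key)

	// Dropping the flag byte changes the hash and would break the join.
	noFlag := fnv.New64()
	noFlag.Write(key)

	fmt.Println(rowWise.Sum64() == colWise.Sum64()) // true
	fmt.Println(rowWise.Sum64() == noFlag.Sum64())  // false
}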

		for i, v := range i64s {
			if column.IsNull(i) {
				buf[0], b = NilFlag, nil
				isNull[i] = true
			} else {
				buf[0] = varintFlag
				if mysql.HasUnsignedFlag(tp.Flag) && v < 0 {
					buf[0] = uvarintFlag
				}
				b = column.GetRaw(i)
			}

			// As the golang doc described, `Hash.Write` never returns an error.
			// See https://golang.org/pkg/hash/#Hash
			_, _ = h[i].Write(buf)
			_, _ = h[i].Write(b)
		}
	case mysql.TypeFloat:
		f32s := column.Float32s()
		for i, f := range f32s {
			if column.IsNull(i) {
				buf[0], b = NilFlag, nil
				isNull[i] = true
			} else {
				buf[0] = floatFlag
				d := float64(f)
				b = (*[sizeFloat64]byte)(unsafe.Pointer(&d))[:]
			}

			// As the golang doc described, `Hash.Write` never returns an error.
			// See https://golang.org/pkg/hash/#Hash
			_, _ = h[i].Write(buf)
			_, _ = h[i].Write(b)
		}
	case mysql.TypeDouble:
		f64s := column.Float64s()
		for i, f := range f64s {
			if column.IsNull(i) {
				buf[0], b = NilFlag, nil
				isNull[i] = true
			} else {
				buf[0] = floatFlag
				b = (*[sizeFloat64]byte)(unsafe.Pointer(&f))[:]
			}

			// As the golang doc described, `Hash.Write` never returns an error.
			// See https://golang.org/pkg/hash/#Hash
			_, _ = h[i].Write(buf)
			_, _ = h[i].Write(b)
		}
	case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
		for i := 0; i < rows; i++ {
			if column.IsNull(i) {
				buf[0], b = NilFlag, nil
				isNull[i] = true
			} else {
				buf[0] = compactBytesFlag
				b = column.GetBytes(i)
			}

			// As the golang doc described, `Hash.Write` never returns an error.
			// See https://golang.org/pkg/hash/#Hash
			_, _ = h[i].Write(buf)
			_, _ = h[i].Write(b)
		}
	case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
		ts := column.Times()
		for i, t := range ts {
			if column.IsNull(i) {
				buf[0], b = NilFlag, nil
				isNull[i] = true
			} else {
				buf[0] = uintFlag
				// Encoding a timestamp needs to consider the time zone.
				// If it's not in UTC, transform it to UTC first.
				if t.Type == mysql.TypeTimestamp && sc.TimeZone != time.UTC {
					err = t.ConvertTimeZone(sc.TimeZone, time.UTC)
					if err != nil {
						return
					}
				}
				var v uint64
				v, err = t.ToPackedUint()
				if err != nil {
					return
				}
				b = (*[sizeUint64]byte)(unsafe.Pointer(&v))[:]
			}

			// As the golang doc described, `Hash.Write` never returns an error.
			// See https://golang.org/pkg/hash/#Hash
			_, _ = h[i].Write(buf)
			_, _ = h[i].Write(b)
		}
	case mysql.TypeDuration:
		for i := 0; i < rows; i++ {
			if column.IsNull(i) {
				buf[0], b = NilFlag, nil
				isNull[i] = true
			} else {
				buf[0] = durationFlag
				// A duration may have a negative value, so we cannot use String to encode it directly.
				b = column.GetRaw(i)
			}

			// As the golang doc described, `Hash.Write` never returns an error.
			// See https://golang.org/pkg/hash/#Hash
			_, _ = h[i].Write(buf)
			_, _ = h[i].Write(b)
		}
	case mysql.TypeNewDecimal:
		ds := column.Decimals()
		for i, d := range ds {
			if column.IsNull(i) {
				buf[0], b = NilFlag, nil
				isNull[i] = true
			} else {
				buf[0] = decimalFlag
				// We only consider the original value of this decimal and ignore its precision.
				b, err = d.ToHashKey()
				if err != nil {
					return
				}
			}

			// As the golang doc described, `Hash.Write` never returns an error.
			// See https://golang.org/pkg/hash/#Hash
			_, _ = h[i].Write(buf)
			_, _ = h[i].Write(b)
		}
	case mysql.TypeEnum:
		for i := 0; i < rows; i++ {
			if column.IsNull(i) {
				buf[0], b = NilFlag, nil
				isNull[i] = true
			} else {
				buf[0] = uvarintFlag
				v := uint64(column.GetEnum(i).ToNumber())
				b = (*[sizeUint64]byte)(unsafe.Pointer(&v))[:]
			}

			// As the golang doc described, `Hash.Write` never returns an error.
			// See https://golang.org/pkg/hash/#Hash
			_, _ = h[i].Write(buf)
			_, _ = h[i].Write(b)
		}
	case mysql.TypeSet:
		for i := 0; i < rows; i++ {
			if column.IsNull(i) {
				buf[0], b = NilFlag, nil
				isNull[i] = true
			} else {
				buf[0] = uvarintFlag
				v := uint64(column.GetSet(i).ToNumber())
				b = (*[sizeUint64]byte)(unsafe.Pointer(&v))[:]
			}

			// As the golang doc described, `Hash.Write` never returns an error.
			// See https://golang.org/pkg/hash/#Hash
			_, _ = h[i].Write(buf)
			_, _ = h[i].Write(b)
		}
	case mysql.TypeBit:
		for i := 0; i < rows; i++ {
			if column.IsNull(i) {
				buf[0], b = NilFlag, nil
				isNull[i] = true
			} else {
				// We don't need to handle errors here since the literal is guaranteed to fit in a uint64 by convertToMysqlBit.
				buf[0] = uvarintFlag
				v, err1 := types.BinaryLiteral(column.GetBytes(i)).ToInt(sc)
				terror.Log(errors.Trace(err1))
				b = (*[sizeUint64]byte)(unsafe.Pointer(&v))[:]
			}

			// As the golang doc described, `Hash.Write` never returns an error.
			// See https://golang.org/pkg/hash/#Hash
			_, _ = h[i].Write(buf)
			_, _ = h[i].Write(b)
		}
	case mysql.TypeJSON:
		for i := 0; i < rows; i++ {
			if column.IsNull(i) {
				buf[0], b = NilFlag, nil
				isNull[i] = true
			} else {
				buf[0] = jsonFlag
				b = column.GetBytes(i)
			}

			// As the golang doc described, `Hash.Write` never returns an error.
			// See https://golang.org/pkg/hash/#Hash
			_, _ = h[i].Write(buf)
			_, _ = h[i].Write(b)
		}
	default:
		return errors.Errorf("unsupported column type for encode %d", tp.Tp)
	}
	return
}
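Several cases above reinterpret a fixed-size value as its raw bytes via unsafe instead of copying it into a scratch buffer. A minimal standalone sketch of that trick (stdlib only):

package main

import (
	"fmt"
	"unsafe"
)

const sizeFloat64 = unsafe.Sizeof(float64(0))

func main() {
	f := 3.5
	// View the float64's eight bytes in place, without a copy; this is the
	// same (*[N]byte)(unsafe.Pointer(&v))[:] pattern HashChunkColumns uses
	// to feed fixed-size values into the hashes.
	b := (*[sizeFloat64]byte)(unsafe.Pointer(&f))[:]
	fmt.Println(len(b), b)
}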

// HashChunkRow writes the encoded values to w.
// If two rows are logically equal, it will generate the same bytes.
func HashChunkRow(sc *stmtctx.StatementContext, w io.Writer, row chunk.Row, allTypes []*types.FieldType, colIdx []int, buf []byte) (err error) {