diff --git a/executor/benchmark_test.go b/executor/benchmark_test.go index 99bb6ceec9103..7ff162239ebb9 100644 --- a/executor/benchmark_test.go +++ b/executor/benchmark_test.go @@ -916,18 +916,20 @@ func prepare4HashJoin(testCase *hashJoinTestCase, innerExec, outerExec Executor) } e := &HashJoinExec{ baseExecutor: newBaseExecutor(testCase.ctx, joinSchema, 5, innerExec, outerExec), - probeSideTupleFetcher: probeSideTupleFetcher{ + hashJoinCtx: &hashJoinCtx{ + joinType: testCase.joinType, // 0 for InnerJoin, 1 for LeftOutersJoin, 2 for RightOuterJoin + isOuterJoin: false, + useOuterToBuild: testCase.useOuterToBuild, + }, + probeSideTupleFetcher: &probeSideTupleFetcher{ probeSideExec: outerExec, }, probeWorkers: make([]probeWorker, testCase.concurrency), concurrency: uint(testCase.concurrency), - joinType: testCase.joinType, // 0 for InnerJoin, 1 for LeftOutersJoin, 2 for RightOuterJoin - isOuterJoin: false, buildKeys: joinKeys, probeKeys: probeKeys, buildSideExec: innerExec, buildSideEstCount: float64(testCase.rows), - useOuterToBuild: testCase.useOuterToBuild, } childrenUsedSchema := markChildrenUsedCols(e.Schema(), e.children[0].Schema(), e.children[1].Schema()) diff --git a/executor/builder.go b/executor/builder.go index dfd44d549eff1..ea718aa9df6dc 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -1412,11 +1412,14 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo } e := &HashJoinExec{ - baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ID(), leftExec, rightExec), - concurrency: v.Concurrency, - joinType: v.JoinType, - isOuterJoin: v.JoinType.IsOuterJoin(), - useOuterToBuild: v.UseOuterToBuild, + baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ID(), leftExec, rightExec), + probeSideTupleFetcher: &probeSideTupleFetcher{}, + hashJoinCtx: &hashJoinCtx{ + isOuterJoin: v.JoinType.IsOuterJoin(), + useOuterToBuild: v.UseOuterToBuild, + joinType: v.JoinType, + }, + concurrency: v.Concurrency, } defaultValues := v.DefaultValues lhsTypes, rhsTypes := retTypes(leftExec), retTypes(rightExec) diff --git a/executor/join.go b/executor/join.go index 95ecee42c02d4..072f0106d50ef 100644 --- a/executor/join.go +++ b/executor/join.go @@ -46,8 +46,22 @@ var ( _ Executor = &NestedLoopApplyExec{} ) +type hashJoinCtx struct { + joinResultCh chan *hashjoinWorkerResult + // closeCh add a lock for closing executor. + closeCh chan struct{} + finished atomic.Bool + useOuterToBuild bool + isOuterJoin bool + buildFinished chan error + rowContainer *hashRowContainer + joinType plannercore.JoinType +} + // probeSideTupleFetcher reads tuples from probeSideExec and send them to probeWorkers. type probeSideTupleFetcher struct { + *hashJoinCtx + probeSideExec Executor probeChkResourceCh chan *probeChkResource probeResultChs []chan *chunk.Chunk @@ -73,7 +87,8 @@ type probeWorker struct { type HashJoinExec struct { baseExecutor - probeSideTupleFetcher + probeSideTupleFetcher *probeSideTupleFetcher + *hashJoinCtx probeWorkers []probeWorker buildSideExec Executor buildSideEstCount float64 @@ -87,29 +102,19 @@ type HashJoinExec struct { buildTypes []*types.FieldType // concurrency is the number of partition, build and join workers. - concurrency uint - rowContainer *hashRowContainer - buildFinished chan error + concurrency uint - // closeCh add a lock for closing executor. - closeCh chan struct{} - worker util.WaitGroupWrapper - waiter util.WaitGroupWrapper - joinType plannercore.JoinType + worker util.WaitGroupWrapper + waiter util.WaitGroupWrapper joinChkResourceCh []chan *chunk.Chunk - joinResultCh chan *hashjoinWorkerResult memTracker *memory.Tracker // track memory usage. diskTracker *disk.Tracker // track disk usage. outerMatchedStatus []*bitmap.ConcurrentBitmap - useOuterToBuild bool - prepared bool - isOuterJoin bool - - finished atomic.Bool + prepared bool stats *hashJoinRuntimeStats } @@ -212,32 +217,32 @@ func (e *HashJoinExec) Open(ctx context.Context) error { // fetchProbeSideChunks get chunks from fetches chunks from the big table in a background goroutine // and sends the chunks to multiple channels which will be read by multiple join workers. -func (e *HashJoinExec) fetchProbeSideChunks(ctx context.Context) { +func (fetcher *probeSideTupleFetcher) fetchProbeSideChunks(ctx context.Context, maxChunkSize int) { hasWaitedForBuild := false for { - if e.finished.Load() { + if fetcher.finished.Load() { return } var probeSideResource *probeChkResource var ok bool select { - case <-e.closeCh: + case <-fetcher.closeCh: return - case probeSideResource, ok = <-e.probeSideTupleFetcher.probeChkResourceCh: + case probeSideResource, ok = <-fetcher.probeChkResourceCh: if !ok { return } } probeSideResult := probeSideResource.chk - if e.isOuterJoin { - required := int(atomic.LoadInt64(&e.probeSideTupleFetcher.requiredRows)) - probeSideResult.SetRequiredRows(required, e.maxChunkSize) + if fetcher.isOuterJoin { + required := int(atomic.LoadInt64(&fetcher.requiredRows)) + probeSideResult.SetRequiredRows(required, maxChunkSize) } - err := Next(ctx, e.probeSideTupleFetcher.probeSideExec, probeSideResult) + err := Next(ctx, fetcher.probeSideExec, probeSideResult) failpoint.Inject("ConsumeRandomPanic", nil) if err != nil { - e.joinResultCh <- &hashjoinWorkerResult{ + fetcher.joinResultCh <- &hashjoinWorkerResult{ err: err, } return @@ -248,23 +253,18 @@ func (e *HashJoinExec) fetchProbeSideChunks(ctx context.Context) { probeSideResult.Reset() } }) - if probeSideResult.NumRows() == 0 && !e.useOuterToBuild { - e.finished.Store(true) + if probeSideResult.NumRows() == 0 && !fetcher.useOuterToBuild { + fetcher.finished.Store(true) } - emptyBuild, buildErr := e.wait4BuildSide() + emptyBuild, buildErr := fetcher.wait4BuildSide() if buildErr != nil { - e.joinResultCh <- &hashjoinWorkerResult{ + fetcher.joinResultCh <- &hashjoinWorkerResult{ err: buildErr, } return } else if emptyBuild { return } - // after building is finished. the hash null bucket slice is allocated and determined. - // copy it for multi probe worker. - for _, w := range e.probeWorkers { - w.rowContainerForProbe.hashNANullBucket = e.rowContainer.hashNANullBucket - } hasWaitedForBuild = true } @@ -276,16 +276,16 @@ func (e *HashJoinExec) fetchProbeSideChunks(ctx context.Context) { } } -func (e *HashJoinExec) wait4BuildSide() (emptyBuild bool, err error) { +func (fetcher *probeSideTupleFetcher) wait4BuildSide() (emptyBuild bool, err error) { select { - case <-e.closeCh: + case <-fetcher.closeCh: return true, nil - case err := <-e.buildFinished: + case err := <-fetcher.buildFinished: if err != nil { return false, err } } - if e.rowContainer.Len() == uint64(0) && (e.joinType == plannercore.InnerJoin || e.joinType == plannercore.SemiJoin) { + if fetcher.rowContainer.Len() == uint64(0) && (fetcher.joinType == plannercore.InnerJoin || fetcher.joinType == plannercore.SemiJoin) { return true, nil } return false, nil @@ -329,6 +329,11 @@ func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chu } func (e *HashJoinExec) initializeForProbe() { + // e.joinResultCh is for transmitting the join result chunks to the main + // thread. + e.joinResultCh = make(chan *hashjoinWorkerResult, e.concurrency+1) + + e.probeSideTupleFetcher.hashJoinCtx = e.hashJoinCtx // e.probeSideTupleFetcher.probeResultChs is for transmitting the chunks which store the data of // probeSideExec, it'll be written by probe side worker goroutine, and read by join // workers. @@ -354,18 +359,14 @@ func (e *HashJoinExec) initializeForProbe() { e.joinChkResourceCh[i] = make(chan *chunk.Chunk, 1) e.joinChkResourceCh[i] <- newFirstChunk(e) } - - // e.joinResultCh is for transmitting the join result chunks to the main - // thread. - e.joinResultCh = make(chan *hashjoinWorkerResult, e.concurrency+1) } func (e *HashJoinExec) fetchAndProbeHashTable(ctx context.Context) { e.initializeForProbe() e.worker.RunWithRecover(func() { defer trace.StartRegion(ctx, "HashJoinProbeSideFetcher").End() - e.fetchProbeSideChunks(ctx) - }, e.handleProbeSideFetcherPanic) + e.probeSideTupleFetcher.fetchProbeSideChunks(ctx, e.maxChunkSize) + }, e.probeSideTupleFetcher.handleProbeSideFetcherPanic) probeKeyColIdx := make([]int, len(e.probeKeys)) probeNAKeColIdx := make([]int, len(e.probeNAKeys)) @@ -375,7 +376,6 @@ func (e *HashJoinExec) fetchAndProbeHashTable(ctx context.Context) { for i := range e.probeNAKeys { probeNAKeColIdx[i] = e.probeNAKeys[i].Index } - for i := uint(0); i < e.concurrency; i++ { workID := i e.worker.RunWithRecover(func() { @@ -386,12 +386,12 @@ func (e *HashJoinExec) fetchAndProbeHashTable(ctx context.Context) { e.waiter.RunWithRecover(e.waitJoinWorkersAndCloseResultChan, nil) } -func (e *HashJoinExec) handleProbeSideFetcherPanic(r interface{}) { - for i := range e.probeSideTupleFetcher.probeResultChs { - close(e.probeSideTupleFetcher.probeResultChs[i]) +func (fetcher *probeSideTupleFetcher) handleProbeSideFetcherPanic(r interface{}) { + for i := range fetcher.probeResultChs { + close(fetcher.probeResultChs[i]) } if r != nil { - e.joinResultCh <- &hashjoinWorkerResult{err: errors.Errorf("%v", r)} + fetcher.joinResultCh <- &hashjoinWorkerResult{err: errors.Errorf("%v", r)} } }