diff --git a/executor/aggregate.go b/executor/aggregate.go index 61c55b5a3584d..b6595ba8f01ae 100644 --- a/executor/aggregate.go +++ b/executor/aggregate.go @@ -90,8 +90,9 @@ type HashAggFinalWorker struct { // AfFinalResult indicates aggregation functions final result. type AfFinalResult struct { - chk *chunk.Chunk - err error + chk *chunk.Chunk + err error + giveBackCh chan *chunk.Chunk } // HashAggExec deals with all the aggregate functions. @@ -150,7 +151,6 @@ type HashAggExec struct { finishCh chan struct{} finalOutputCh chan *AfFinalResult - finalInputCh chan *chunk.Chunk partialOutputChs []chan *HashAggIntermData inputCh chan *HashAggInput partialInputChs []chan *chunk.Chunk @@ -271,7 +271,6 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) { partialConcurrency := sessionVars.HashAggPartialConcurrency e.isChildReturnEmpty = true e.finalOutputCh = make(chan *AfFinalResult, finalConcurrency) - e.finalInputCh = make(chan *chunk.Chunk, finalConcurrency) e.inputCh = make(chan *HashAggInput, partialConcurrency) e.finishCh = make(chan struct{}, 1) @@ -316,11 +315,12 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) { groupSet: set.NewStringSet(), inputCh: e.partialOutputChs[i], outputCh: e.finalOutputCh, - finalResultHolderCh: e.finalInputCh, + finalResultHolderCh: make(chan *chunk.Chunk, 1), rowBuffer: make([]types.Datum, 0, e.Schema().Len()), mutableRow: chunk.MutRowFromTypes(retTypes(e)), groupKeys: make([][]byte, 0, 8), } + e.finalWorkers[i].finalResultHolderCh <- newFirstChunk(e) } } @@ -540,14 +540,14 @@ func (w *HashAggFinalWorker) getFinalResult(sctx sessionctx.Context) { result.SetNumVirtualRows(result.NumRows() + 1) } if result.IsFull() { - w.outputCh <- &AfFinalResult{chk: result} + w.outputCh <- &AfFinalResult{chk: result, giveBackCh: w.finalResultHolderCh} result, finished = w.receiveFinalResultHolder() if finished { return } } } - w.outputCh <- &AfFinalResult{chk: result} + w.outputCh <- &AfFinalResult{chk: result, giveBackCh: w.finalResultHolderCh} } func (w *HashAggFinalWorker) receiveFinalResultHolder() (*chunk.Chunk, bool) { @@ -668,28 +668,26 @@ func (e *HashAggExec) parallelExec(ctx context.Context, chk *chunk.Chunk) error if e.executed { return nil } - for !chk.IsFull() { - e.finalInputCh <- chk + for { result, ok := <-e.finalOutputCh - if !ok { // all finalWorkers exited + if !ok { e.executed = true - if chk.NumRows() > 0 { // but there are some data left - return nil - } if e.isChildReturnEmpty && e.defaultVal != nil { chk.Append(e.defaultVal, 0, 1) } - e.isChildReturnEmpty = false return nil } if result.err != nil { return result.err } + chk.SwapColumns(result.chk) + result.chk.Reset() + result.giveBackCh <- result.chk if chk.NumRows() > 0 { e.isChildReturnEmpty = false + return nil } } - return nil } // unparallelExec executes hash aggregation algorithm in single thread. diff --git a/executor/executor_required_rows_test.go b/executor/executor_required_rows_test.go index 76dcb742c1796..e3896f4d66ddd 100644 --- a/executor/executor_required_rows_test.go +++ b/executor/executor_required_rows_test.go @@ -677,67 +677,6 @@ func (s *testExecSuite) TestStreamAggRequiredRows(c *C) { } } -func (s *testExecSuite) TestHashAggParallelRequiredRows(c *C) { - maxChunkSize := defaultCtx().GetSessionVars().MaxChunkSize - testCases := []struct { - totalRows int - aggFunc string - requiredRows []int - expectedRows []int - expectedRowsDS []int - gen func(valType *types.FieldType) interface{} - }{ - { - totalRows: maxChunkSize, - aggFunc: ast.AggFuncSum, - requiredRows: []int{1, 2, 3, 4, 5, 6, 7}, - expectedRows: []int{1, 2, 3, 4, 5, 6, 7}, - expectedRowsDS: []int{maxChunkSize, 0}, - gen: divGenerator(1), - }, - { - totalRows: maxChunkSize * 3, - aggFunc: ast.AggFuncAvg, - requiredRows: []int{1, 3}, - expectedRows: []int{1, 2}, - expectedRowsDS: []int{maxChunkSize, maxChunkSize, maxChunkSize, 0}, - gen: divGenerator(maxChunkSize), - }, - { - totalRows: maxChunkSize * 3, - aggFunc: ast.AggFuncAvg, - requiredRows: []int{maxChunkSize, maxChunkSize}, - expectedRows: []int{maxChunkSize, maxChunkSize / 2}, - expectedRowsDS: []int{maxChunkSize, maxChunkSize, maxChunkSize, 0}, - gen: divGenerator(2), - }, - } - - for _, hasDistinct := range []bool{false, true} { - for _, testCase := range testCases { - sctx := defaultCtx() - ctx := context.Background() - ds := newRequiredRowsDataSourceWithGenerator(sctx, testCase.totalRows, testCase.expectedRowsDS, testCase.gen) - childCols := ds.Schema().Columns - schema := expression.NewSchema(childCols...) - groupBy := []expression.Expression{childCols[1]} - aggFunc, err := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, hasDistinct) - c.Assert(err, IsNil) - aggFuncs := []*aggregation.AggFuncDesc{aggFunc} - exec := buildHashAggExecutor(sctx, ds, schema, aggFuncs, groupBy) - c.Assert(exec.Open(ctx), IsNil) - chk := newFirstChunk(exec) - for i := range testCase.requiredRows { - chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize) - c.Assert(exec.Next(ctx, chk), IsNil) - c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i]) - } - c.Assert(exec.Close(), IsNil) - c.Assert(ds.checkNumNextCalled(), IsNil) - } - } -} - func (s *testExecSuite) TestMergeJoinRequiredRows(c *C) { justReturn1 := func(valType *types.FieldType) interface{} { switch valType.Tp {