Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: refine StreamAggExec when child is empty #7002

Merged
merged 13 commits into from
Jul 11, 2018
Merged
54 changes: 28 additions & 26 deletions executor/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ type HashAggExec struct {
partialWorkers []HashAggPartialWorker
finalWorkers []HashAggFinalWorker
defaultVal *chunk.Chunk
// isChildExecReturnEmpty indicates whether the child executor only returns an empty input.
isChildExecReturnEmpty bool
// isChildReturnEmpty indicates whether the child executor only returns an empty input.
isChildReturnEmpty bool
}

// HashAggInput indicates the input of hash agg exec.
Expand Down Expand Up @@ -253,7 +253,7 @@ func (e *HashAggExec) initForParallelExec() {
sessionVars := e.ctx.GetSessionVars()
finalConcurrency := sessionVars.HashAggFinalConcurrency
partialConcurrency := sessionVars.HashAggPartialConcurrency
e.isChildExecReturnEmpty = true
e.isChildReturnEmpty = true
e.finalOutputCh = make(chan *AfFinalResult, finalConcurrency)
e.inputCh = make(chan *HashAggInput, partialConcurrency)
e.finishCh = make(chan struct{}, 1)
Expand Down Expand Up @@ -621,13 +621,13 @@ func (e *HashAggExec) parallelExec(ctx context.Context, chk *chunk.Chunk) error
if result != nil {
return errors.Trace(result.err)
}
if e.isChildExecReturnEmpty && e.defaultVal != nil {
if e.isChildReturnEmpty && e.defaultVal != nil {
chk.Append(e.defaultVal, 0, 1)
}
e.isChildExecReturnEmpty = false
e.isChildReturnEmpty = false
return nil
}
e.isChildExecReturnEmpty = false
e.isChildReturnEmpty = false
chk.SwapColumns(result.chk)
// Put result.chk back to the corresponded final worker's finalResultHolderCh.
result.giveBackCh <- result.chk
Expand Down Expand Up @@ -744,20 +744,23 @@ func (e *HashAggExec) getContexts(groupKey []byte) []*aggregation.AggEvaluateCon
type StreamAggExec struct {
baseExecutor

executed bool
hasData bool
StmtCtx *stmtctx.StatementContext
AggFuncs []aggregation.Aggregation
aggCtxs []*aggregation.AggEvaluateContext
GroupByItems []expression.Expression
curGroupKey []types.Datum
tmpGroupKey []types.Datum
executed bool
// isChildReturnEmpty indicates whether the child executor only returns an empty input.
isChildReturnEmpty bool
StmtCtx *stmtctx.StatementContext
AggFuncs []aggregation.Aggregation
aggCtxs []*aggregation.AggEvaluateContext
GroupByItems []expression.Expression
curGroupKey []types.Datum
tmpGroupKey []types.Datum

inputIter *chunk.Iterator4Chunk
inputRow chunk.Row
mutableRow chunk.MutRow
rowBuffer []types.Datum

defaultVal *chunk.Chunk

// for the new execution framework of aggregate functions
newAggFuncs []aggfuncs.AggFunc
partialResults []aggfuncs.PartialResult
Expand All @@ -771,7 +774,7 @@ func (e *StreamAggExec) Open(ctx context.Context) error {
}

e.executed = false
e.hasData = false
e.isChildReturnEmpty = true
e.inputIter = chunk.NewIterator4Chunk(e.childrenResults[0])
e.inputRow = e.inputIter.End()
e.mutableRow = chunk.MutRowFromTypes(e.retTypes())
Expand Down Expand Up @@ -860,38 +863,37 @@ func (e *StreamAggExec) consumeGroupRows() error {
return nil
}

func (e *StreamAggExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Chunk) error {
func (e *StreamAggExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Chunk) (err error) {
if e.inputRow != e.inputIter.End() {
return nil
}

// Before fetching a new batch of input, we should consume the last group.
if e.newAggFuncs != nil {
err := e.consumeGroupRows()
err = e.consumeGroupRows()
if err != nil {
return errors.Trace(err)
}
}

err := e.children[0].Next(ctx, e.childrenResults[0])
err = e.children[0].Next(ctx, e.childrenResults[0])
if err != nil {
return errors.Trace(err)
}

// No more data.
if e.childrenResults[0].NumRows() == 0 {
if e.hasData || len(e.GroupByItems) == 0 {
err := e.appendResult2Chunk(chk)
if err != nil {
return errors.Trace(err)
}
if !e.isChildReturnEmpty {
err = e.appendResult2Chunk(chk)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why isn't the value of err checked here? If there is a particular reason for that, would it be better to add it here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

err is checked in line 892

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it.

} else if e.defaultVal != nil {
chk.Append(e.defaultVal, 0, 1)
}
e.executed = true
return nil
return errors.Trace(err)
}

// Reach here, "e.childrenResults[0].NumRows() > 0" is guaranteed.
e.isChildReturnEmpty = false
e.inputRow = e.inputIter.Begin()
e.hasData = true
return nil
}

Expand Down
10 changes: 10 additions & 0 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,16 @@ func (s *testSuite) TestAggregation(c *C) {
tk.MustExec(`insert into t values (7, '{"i": -1, "n": "n7"}')`)
tk.MustQuery("select sum(tags->'$.i') from t").Check(testkit.Rows("14"))

// test agg with empty input
result = tk.MustQuery("select id, count(95), sum(95), avg(95), bit_or(95), bit_and(95), bit_or(95), max(95), min(95), group_concat(95) from t where null")
result.Check(testkit.Rows("<nil> 0 <nil> <nil> 0 18446744073709551615 0 <nil> <nil> <nil>"))
tk.MustExec("truncate table t")
tk.MustExec("create table s(id int)")
result = tk.MustQuery("select t.id, count(95), sum(95), avg(95), bit_or(95), bit_and(95), bit_or(95), max(95), min(95), group_concat(95) from t left join s on t.id = s.id")
result.Check(testkit.Rows("<nil> 0 <nil> <nil> 0 18446744073709551615 0 <nil> <nil> <nil>"))
tk.MustExec(`insert into t values (1, '{"i": 1, "n": "n1"}')`)
result = tk.MustQuery("select t.id, count(95), sum(95), avg(95), bit_or(95), bit_and(95), bit_or(95), max(95), min(95), group_concat(95) from t left join s on t.id = s.id")
result.Check(testkit.Rows("1 1 95 95.0000 95 95 95 95 95 95"))
tk.MustExec("set @@tidb_hash_join_concurrency=5")
}

Expand Down
16 changes: 12 additions & 4 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -946,10 +946,8 @@ func (b *executorBuilder) buildHashAgg(v *plan.PhysicalHashAgg) Executor {
}
e.AggFuncs = append(e.AggFuncs, aggDesc.GetAggFunc())
if e.defaultVal != nil {
value, existsDefaultValue := aggDesc.CalculateDefaultValue(e.ctx, e.children[0].Schema())
if existsDefaultValue {
e.defaultVal.AppendDatum(i, &value)
}
value := aggDesc.GetDefaultValue()
e.defaultVal.AppendDatum(i, &value)
}
}

Expand All @@ -970,9 +968,19 @@ func (b *executorBuilder) buildStreamAgg(v *plan.PhysicalStreamAgg) Executor {
AggFuncs: make([]aggregation.Aggregation, 0, len(v.AggFuncs)),
GroupByItems: v.GroupByItems,
}
if len(v.GroupByItems) != 0 || aggregation.IsAllFirstRow(v.AggFuncs) {
e.defaultVal = nil
} else {
e.defaultVal = chunk.NewChunkWithCapacity(e.retTypes(), 1)
}
newAggFuncs := make([]aggfuncs.AggFunc, 0, len(v.AggFuncs))
for i, aggDesc := range v.AggFuncs {
e.AggFuncs = append(e.AggFuncs, aggDesc.GetAggFunc())
if e.defaultVal != nil {
value := aggDesc.GetDefaultValue()
e.defaultVal.AppendDatum(i, &value)
}
// For new aggregate evaluation framework.
newAggFunc := aggfuncs.Build(aggDesc, i)
if newAggFunc != nil {
newAggFuncs = append(newAggFuncs, newAggFunc)
Expand Down
104 changes: 79 additions & 25 deletions expression/aggregation/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,39 +110,86 @@ func (a *AggFuncDesc) typeInfer(ctx sessionctx.Context) {
}
}

// CalculateDefaultValue gets the default value when the aggregation function's input is null.
// The input stands for the schema of Aggregation's child. If the function can't produce a default value, the second
// EvalNullValueInOuterJoin gets the null value when the aggregation is upon an outer join,
// and the aggregation function's input is null.
// If there is no matching row for the inner table of an outer join,
// an aggregation function that only involves constants and/or columns belonging to the inner table
// will be set to the null value.
// The input stands for the schema of Aggregation's child. If the function can't produce a null value, the second
// return value will be false.
// e.g.
// Table t with only one row:
// +-------+---------+---------+
// | Table | Field | Type |
// +-------+---------+---------+
// | t | a | int(11) |
// +-------+---------+---------+
// +------+
// | a |
// +------+
// | 1 |
// +------+
//
// According to MySQL, DefaultValue of the aggregation function can be tested as the following sql:
// Table s which is empty:
// +-------+---------+---------+
// | Table | Field | Type |
// +-------+---------+---------+
// | s | a | int(11) |
// +-------+---------+---------+
//
// mysql> CREATE TABLE `t` (
// -> `a` int(11) DEFAULT NULL,
// -> `b` int(11) DEFAULT NULL
// -> );
// mysql>
// +------+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+
// | a | avg(a) | sum(a) | count(a) | bit_xor(a) | bit_or(a) | bit_and(a) | max(a) | min(a) | group_concat(a) |
// +------+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+
// | NULL | NULL | NULL | 0 | 0 | 0 | 18446744073709551615 | NULL | NULL | NULL |
// +------+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+
// 1 row in set (0.01 sec)
func (a *AggFuncDesc) CalculateDefaultValue(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) {
// Query: `select t.a as `t.a`, count(95), sum(95), avg(95), bit_or(95), bit_and(95), bit_or(95), max(95), min(95), s.a as `s.a`, avg(s.a) from t left join s on t.a = s.a;`
// +------+-----------+---------+---------+------------+-------------+------------+---------+---------+------+----------+
// | t.a | count(95) | sum(95) | avg(95) | bit_or(95) | bit_and(95) | bit_or(95) | max(95) | min(95) | s.a | avg(s.a) |
// +------+-----------+---------+---------+------------+-------------+------------+---------+---------+------+----------+
// | 1 | 1 | 95 | 95.0000 | 95 | 95 | 95 | 95 | 95 | NULL | NULL |
// +------+-----------+---------+---------+------------+-------------+------------+---------+---------+------+----------+
func (a *AggFuncDesc) EvalNullValueInOuterJoin(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) {
switch a.Name {
case ast.AggFuncCount:
return a.calculateDefaultValue4Count(ctx, schema)
return a.evalNullValueInOuterJoin4Count(ctx, schema)
case ast.AggFuncSum, ast.AggFuncMax, ast.AggFuncMin,
ast.AggFuncFirstRow, ast.AggFuncAvg, ast.AggFuncGroupConcat:
return a.calculateDefaultValue4Sum(ctx, schema)
ast.AggFuncFirstRow:
return a.evalNullValueInOuterJoin4Sum(ctx, schema)
case ast.AggFuncAvg, ast.AggFuncGroupConcat:
return types.Datum{}, false
case ast.AggFuncBitAnd:
return a.calculateDefaultValue4BitAnd(ctx, schema)
return a.evalNullValueInOuterJoin4BitAnd(ctx, schema)
case ast.AggFuncBitOr, ast.AggFuncBitXor:
return a.calculateDefaultValue4BitOr(ctx, schema)
return a.evalNullValueInOuterJoin4BitOr(ctx, schema)
default:
panic("unsupported agg function")
}
}

// GetDefaultValue gets the default value when the aggregation function's input is null.
// According to MySQL, default values of the aggregation function are listed as follows:
// e.g.
// Table t which is empty:
// +-------+---------+---------+
// | Table | Field | Type |
// +-------+---------+---------+
// | t | a | int(11) |
// +-------+---------+---------+
//
// Query: `select a, avg(a), sum(a), count(a), bit_xor(a), bit_or(a), bit_and(a), max(a), min(a), group_concat(a) from t;`
// +------+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+
// | a | avg(a) | sum(a) | count(a) | bit_xor(a) | bit_or(a) | bit_and(a) | max(a) | min(a) | group_concat(a) |
// +------+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+
// | NULL | NULL | NULL | 0 | 0 | 0 | 18446744073709551615 | NULL | NULL | NULL |
// +------+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+
func (a *AggFuncDesc) GetDefaultValue() (v types.Datum) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this function be moved to the Aggregation interface? It seems to belong to the execution layer.

switch a.Name {
case ast.AggFuncCount, ast.AggFuncBitOr, ast.AggFuncBitXor:
v = types.NewIntDatum(0)
case ast.AggFuncFirstRow, ast.AggFuncAvg, ast.AggFuncSum, ast.AggFuncMax,
ast.AggFuncMin, ast.AggFuncGroupConcat:
v = types.Datum{}
case ast.AggFuncBitAnd:
v = types.NewUintDatum(uint64(math.MaxUint64))
}
return v
}

// GetAggFunc gets an evaluator according to the aggregation function signature.
func (a *AggFuncDesc) GetAggFunc() Aggregation {
aggFunc := aggFunction{AggFuncDesc: a}
Expand Down Expand Up @@ -246,11 +293,18 @@ func (a *AggFuncDesc) typeInfer4BitFuncs(ctx sessionctx.Context) {
// TODO: a.Args[0] = expression.WrapWithCastAsInt(ctx, a.Args[0])
}

func (a *AggFuncDesc) calculateDefaultValue4Count(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) {
return types.NewDatum(0), true
func (a *AggFuncDesc) evalNullValueInOuterJoin4Count(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) {
for _, arg := range a.Args {
result := expression.EvaluateExprWithNull(ctx, schema, arg)
con, ok := result.(*expression.Constant)
if !ok || con.Value.IsNull() {
return types.Datum{}, ok
}
}
return types.NewDatum(1), true
}

func (a *AggFuncDesc) calculateDefaultValue4Sum(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) {
func (a *AggFuncDesc) evalNullValueInOuterJoin4Sum(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) {
result := expression.EvaluateExprWithNull(ctx, schema, a.Args[0])
con, ok := result.(*expression.Constant)
if !ok || con.Value.IsNull() {
Expand All @@ -259,7 +313,7 @@ func (a *AggFuncDesc) calculateDefaultValue4Sum(ctx sessionctx.Context, schema *
return con.Value, true
}

func (a *AggFuncDesc) calculateDefaultValue4BitAnd(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) {
func (a *AggFuncDesc) evalNullValueInOuterJoin4BitAnd(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) {
result := expression.EvaluateExprWithNull(ctx, schema, a.Args[0])
con, ok := result.(*expression.Constant)
if !ok || con.Value.IsNull() {
Expand All @@ -268,7 +322,7 @@ func (a *AggFuncDesc) calculateDefaultValue4BitAnd(ctx sessionctx.Context, schem
return con.Value, true
}

func (a *AggFuncDesc) calculateDefaultValue4BitOr(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) {
func (a *AggFuncDesc) evalNullValueInOuterJoin4BitOr(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) {
result := expression.EvaluateExprWithNull(ctx, schema, a.Args[0])
con, ok := result.(*expression.Constant)
if !ok || con.Value.IsNull() {
Expand Down
20 changes: 20 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2744,6 +2744,26 @@ func (s *testIntegrationSuite) TestAggregationBuiltin(c *C) {
tk.MustExec("insert into t values(1.123456), (1.123456)")
result := tk.MustQuery("select avg(a) from t")
result.Check(testkit.Rows("1.1234560000"))

tk.MustExec("use test")
tk.MustExec("drop table t")
tk.MustExec("CREATE TABLE `t` ( `a` int, KEY `idx_a` (`a`))")
result = tk.MustQuery("select avg(a) from t")
result.Check(testkit.Rows("<nil>"))
result = tk.MustQuery("select max(a), min(a) from t")
result.Check(testkit.Rows("<nil> <nil>"))
result = tk.MustQuery("select distinct a from t")
result.Check(testkit.Rows())
result = tk.MustQuery("select sum(a) from t")
result.Check(testkit.Rows("<nil>"))
result = tk.MustQuery("select count(a) from t")
result.Check(testkit.Rows("0"))
result = tk.MustQuery("select bit_or(a) from t")
result.Check(testkit.Rows("0"))
result = tk.MustQuery("select bit_xor(a) from t")
result.Check(testkit.Rows("0"))
result = tk.MustQuery("select bit_and(a) from t")
result.Check(testkit.Rows("18446744073709551615"))
}

func (s *testIntegrationSuite) TestAggregationBuiltinBitOr(c *C) {
Expand Down
2 changes: 1 addition & 1 deletion plan/rule_aggregation_push_down.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func (a *aggregationOptimizer) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncD
func (a *aggregationOptimizer) getDefaultValues(agg *LogicalAggregation) ([]types.Datum, bool) {
defaultValues := make([]types.Datum, 0, agg.Schema().Len())
for _, aggFunc := range agg.AggFuncs {
value, existsDefaultValue := aggFunc.CalculateDefaultValue(agg.ctx, agg.children[0].Schema())
value, existsDefaultValue := aggFunc.EvalNullValueInOuterJoin(agg.ctx, agg.children[0].Schema())
if !existsDefaultValue {
return nil, false
}
Expand Down