From 0ef52acf82d75c19e9c29e420ba8a4fb4fb54f2a Mon Sep 17 00:00:00 2001 From: HuaiyuXu <391585975@qq.com> Date: Thu, 12 Jul 2018 15:01:16 +0800 Subject: [PATCH] executor: support MAX/MIN in new evaluation framework partially (#6971) --- executor/aggfuncs/aggfuncs.go | 16 +- executor/aggfuncs/builder.go | 54 ++++-- executor/aggfuncs/func_max_min.go | 298 ++++++++++++++++++++++++++++++ executor/aggregate_test.go | 2 +- 4 files changed, 348 insertions(+), 22 deletions(-) create mode 100644 executor/aggfuncs/func_max_min.go diff --git a/executor/aggfuncs/aggfuncs.go b/executor/aggfuncs/aggfuncs.go index 54d6832b9460f..73041aa3e07e9 100644 --- a/executor/aggfuncs/aggfuncs.go +++ b/executor/aggfuncs/aggfuncs.go @@ -25,6 +25,13 @@ import ( var ( // All the AggFunc implementations for "COUNT" are listed here. // All the AggFunc implementations for "SUM" are listed here. + // All the AggFunc implementations for "FIRSTROW" are listed here. + // All the AggFunc implementations for "MAX"/"MIN" are listed here. + _ AggFunc = (*maxMin4Int)(nil) + _ AggFunc = (*maxMin4Float32)(nil) + _ AggFunc = (*maxMin4Float64)(nil) + _ AggFunc = (*maxMin4Decimal)(nil) + // All the AggFunc implementations for "AVG" are listed here. _ AggFunc = (*avgOriginal4Decimal)(nil) _ AggFunc = (*avgOriginal4DistinctDecimal)(nil) @@ -34,15 +41,12 @@ var ( _ AggFunc = (*avgPartial4Float64)(nil) _ AggFunc = (*avgOriginal4DistinctFloat64)(nil) - // All the AggFunc implementations for "FIRSTROW" are listed here. - // All the AggFunc implementations for "MAX" are listed here. - // All the AggFunc implementations for "MIN" are listed here. // All the AggFunc implementations for "GROUP_CONCAT" are listed here. // All the AggFunc implementations for "BIT_OR" are listed here. + // All the AggFunc implementations for "BIT_XOR" are listed here. + // All the AggFunc implementations for "BIT_AND" are listed here. + // All the AggFunc implementations for "BIT_OR" are listed here. _ AggFunc = (*bitOrUint64)(nil) - -// All the AggFunc implementations for "BIT_XOR" are listed here. -// All the AggFunc implementations for "BIT_AND" are listed here. ) // PartialResult represents data structure to store the partial result for the diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 3325487749a64..2c01170d0e665 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -33,9 +33,9 @@ func Build(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { case ast.AggFuncFirstRow: return buildFirstRow(aggFuncDesc, ordinal) case ast.AggFuncMax: - return buildMax(aggFuncDesc, ordinal) + return buildMaxMin(aggFuncDesc, ordinal, true) case ast.AggFuncMin: - return buildMin(aggFuncDesc, ordinal) + return buildMaxMin(aggFuncDesc, ordinal, false) case ast.AggFuncGroupConcat: return buildGroupConcat(aggFuncDesc, ordinal) case ast.AggFuncBitOr: @@ -53,12 +53,12 @@ func buildCount(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { return nil } -// buildCount builds the AggFunc implementation for function "SUM". +// buildSum builds the AggFunc implementation for function "SUM". func buildSum(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { return nil } -// buildCount builds the AggFunc implementation for function "AVG". +// buildAvg builds the AggFunc implementation for function "AVG". func buildAvg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { base := baseAggFunc{ args: aggFuncDesc.Args, @@ -99,27 +99,51 @@ func buildAvg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { return nil } -// buildCount builds the AggFunc implementation for function "FIRST_ROW". +// buildFirstRow builds the AggFunc implementation for function "FIRST_ROW". func buildFirstRow(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { return nil } -// buildCount builds the AggFunc implementation for function "MAX". -func buildMax(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { - return nil -} +// buildMaxMin builds the AggFunc implementation for function "MAX" and "MIN". +func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool) AggFunc { + base := baseMaxMinAggFunc{ + baseAggFunc: baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, + }, + isMax: isMax, + } -// buildCount builds the AggFunc implementation for function "MIN". -func buildMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { + evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp + switch aggFuncDesc.Mode { + case aggregation.DedupMode: + default: + switch evalType { + case types.ETInt: + if mysql.HasUnsignedFlag(fieldType.Flag) { + return &maxMin4Uint{base} + } + return &maxMin4Int{base} + case types.ETReal: + switch fieldType.Tp { + case mysql.TypeFloat: + return &maxMin4Float32{base} + case mysql.TypeDouble: + return &maxMin4Float64{base} + } + case types.ETDecimal: + return &maxMin4Decimal{base} + } + } return nil } -// buildCount builds the AggFunc implementation for function "GROUP_CONCAT". +// buildGroupConcat builds the AggFunc implementation for function "GROUP_CONCAT". func buildGroupConcat(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { return nil } -// buildCount builds the AggFunc implementation for function "BIT_OR". +// buildBitOr builds the AggFunc implementation for function "BIT_OR". func buildBitOr(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { // BIT_OR doesn't need to handle the distinct property. switch aggFuncDesc.Args[0].GetType().EvalType() { @@ -133,12 +157,12 @@ func buildBitOr(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { return nil } -// buildCount builds the AggFunc implementation for function "BIT_XOR". +// buildBitXor builds the AggFunc implementation for function "BIT_XOR". func buildBitXor(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { return nil } -// buildCount builds the AggFunc implementation for function "BIT_AND". +// buildBitAnd builds the AggFunc implementation for function "BIT_AND". func buildBitAnd(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { return nil } diff --git a/executor/aggfuncs/func_max_min.go b/executor/aggfuncs/func_max_min.go new file mode 100644 index 0000000000000..f5d6cbd79654e --- /dev/null +++ b/executor/aggfuncs/func_max_min.go @@ -0,0 +1,298 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" +) + +type partialResult4MaxMinInt struct { + val int64 + // isNull is used to indicates: + // 1. whether the partial result is the initialization value which should not be compared during evaluation; + // 2. whether all the values of arg are all null, if so, we should return null as the default value for MAX/MIN. + isNull bool +} + +type partialResult4MaxMinUint struct { + val uint64 + isNull bool +} + +type partialResult4MaxMinDecimal struct { + val types.MyDecimal + isNull bool +} + +type partialResult4MaxMinFloat32 struct { + val float32 + isNull bool +} + +type partialResult4MaxMinFloat64 struct { + val float64 + isNull bool +} + +type baseMaxMinAggFunc struct { + baseAggFunc + + isMax bool +} + +type maxMin4Int struct { + baseMaxMinAggFunc +} + +func (e *maxMin4Int) AllocPartialResult() PartialResult { + p := new(partialResult4MaxMinInt) + p.isNull = true + return PartialResult(p) +} + +func (e *maxMin4Int) ResetPartialResult(pr PartialResult) { + p := (*partialResult4MaxMinInt)(pr) + p.val = 0 + p.isNull = true +} + +func (e *maxMin4Int) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4MaxMinInt)(pr) + if p.isNull { + chk.AppendNull(e.ordinal) + return nil + } + chk.AppendInt64(e.ordinal, p.val) + return nil +} + +func (e *maxMin4Int) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4MaxMinInt)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalInt(sctx, row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + if p.isNull { + p.val = input + p.isNull = false + continue + } + if e.isMax && input > p.val || !e.isMax && input < p.val { + p.val = input + } + } + return nil +} + +type maxMin4Uint struct { + baseMaxMinAggFunc +} + +func (e *maxMin4Uint) AllocPartialResult() PartialResult { + p := new(partialResult4MaxMinUint) + p.isNull = true + return PartialResult(p) +} + +func (e *maxMin4Uint) ResetPartialResult(pr PartialResult) { + p := (*partialResult4MaxMinUint)(pr) + p.val = 0 + p.isNull = true +} + +func (e *maxMin4Uint) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4MaxMinUint)(pr) + if p.isNull { + chk.AppendNull(e.ordinal) + return nil + } + chk.AppendUint64(e.ordinal, p.val) + return nil +} + +func (e *maxMin4Uint) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4MaxMinUint)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalInt(sctx, row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + uintVal := uint64(input) + if p.isNull { + p.val = uintVal + p.isNull = false + continue + } + if e.isMax && uintVal > p.val || !e.isMax && uintVal < p.val { + p.val = uintVal + } + } + return nil +} + +// maxMin4Float32 gets a float32 input and returns a float32 result. +type maxMin4Float32 struct { + baseMaxMinAggFunc +} + +func (e *maxMin4Float32) AllocPartialResult() PartialResult { + p := new(partialResult4MaxMinFloat32) + p.isNull = true + return PartialResult(p) +} + +func (e *maxMin4Float32) ResetPartialResult(pr PartialResult) { + p := (*partialResult4MaxMinFloat32)(pr) + p.val = 0 + p.isNull = true +} + +func (e *maxMin4Float32) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4MaxMinFloat32)(pr) + if p.isNull { + chk.AppendNull(e.ordinal) + return nil + } + chk.AppendFloat32(e.ordinal, p.val) + return nil +} + +func (e *maxMin4Float32) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4MaxMinFloat32)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalReal(sctx, row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + f := float32(input) + if p.isNull { + p.val = f + p.isNull = false + continue + } + if e.isMax && f > p.val || !e.isMax && f < p.val { + p.val = f + } + } + return nil +} + +type maxMin4Float64 struct { + baseMaxMinAggFunc +} + +func (e *maxMin4Float64) AllocPartialResult() PartialResult { + p := new(partialResult4MaxMinFloat64) + p.isNull = true + return PartialResult(p) +} + +func (e *maxMin4Float64) ResetPartialResult(pr PartialResult) { + p := (*partialResult4MaxMinFloat64)(pr) + p.val = 0 + p.isNull = true +} + +func (e *maxMin4Float64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4MaxMinFloat64)(pr) + if p.isNull { + chk.AppendNull(e.ordinal) + return nil + } + chk.AppendFloat64(e.ordinal, p.val) + return nil +} + +func (e *maxMin4Float64) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4MaxMinFloat64)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalReal(sctx, row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + if p.isNull { + p.val = input + p.isNull = false + continue + } + if e.isMax && input > p.val || !e.isMax && input < p.val { + p.val = input + } + } + return nil +} + +type maxMin4Decimal struct { + baseMaxMinAggFunc +} + +func (e *maxMin4Decimal) AllocPartialResult() PartialResult { + p := new(partialResult4MaxMinDecimal) + p.isNull = true + return PartialResult(p) +} + +func (e *maxMin4Decimal) ResetPartialResult(pr PartialResult) { + p := (*partialResult4MaxMinDecimal)(pr) + p.isNull = true +} + +func (e *maxMin4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4MaxMinDecimal)(pr) + if p.isNull { + chk.AppendNull(e.ordinal) + return nil + } + chk.AppendMyDecimal(e.ordinal, &p.val) + return nil +} + +func (e *maxMin4Decimal) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4MaxMinDecimal)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalDecimal(sctx, row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + if p.isNull { + p.val = *input + p.isNull = false + continue + } + cmp := input.Compare(&p.val) + if e.isMax && cmp == 1 || !e.isMax && cmp == -1 { + p.val = *input + } + } + return nil +} diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index 915e7b697876b..73d13dcce44cd 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -537,7 +537,7 @@ func (s *testSuite) TestAggEliminator(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) tk.MustExec("create table t(a int primary key, b int)") - tk.MustQuery("select min(a) from t").Check(testkit.Rows("")) + tk.MustQuery("select min(a), min(a) from t").Check(testkit.Rows(" ")) tk.MustExec("insert into t values(1, -1), (2, -2), (3, 1), (4, NULL)") tk.MustQuery("select max(a) from t").Check(testkit.Rows("4")) tk.MustQuery("select min(b) from t").Check(testkit.Rows("-2"))