From d26b07f4a05334bb37a5d22d79171ac47e6c1639 Mon Sep 17 00:00:00 2001 From: Jian Zhang Date: Mon, 2 Jul 2018 11:13:53 +0800 Subject: [PATCH 1/4] aggfuncs: partially implement "AVG" --- executor/aggfuncs/aggfuncs.go | 15 +- executor/aggfuncs/builder.go | 44 +++++ executor/aggfuncs/func_avg.go | 291 ++++++++++++++++++++++++++++++++++ 3 files changed, 347 insertions(+), 3 deletions(-) create mode 100644 executor/aggfuncs/func_avg.go diff --git a/executor/aggfuncs/aggfuncs.go b/executor/aggfuncs/aggfuncs.go index 7e82e2c11ff75..289501dac13a9 100644 --- a/executor/aggfuncs/aggfuncs.go +++ b/executor/aggfuncs/aggfuncs.go @@ -23,9 +23,18 @@ import ( // All the AggFunc implementations are listed here for navigation. var ( -// All the AggFunc implementations for "COUNT" are listed here. -// All the AggFunc implementations for "SUM" are listed here. -// All the AggFunc implementations for "AVG" are listed here. + // All the AggFunc implementations for "COUNT" are listed here. + // All the AggFunc implementations for "SUM" are listed here. + // All the AggFunc implementations for "AVG" are listed here. + _ AggFunc = (*avgOriginal4Decimal)(nil) + _ AggFunc = (*avgPartial4Decimal)(nil) + + _ AggFunc = (*avgOriginal4Float64)(nil) + _ AggFunc = (*avgPartial4Float64)(nil) + + _ AggFunc = (*avgOriginal4Float32)(nil) + _ AggFunc = (*avgPartial4Float32)(nil) + // All the AggFunc implementations for "FIRSTROW" are listed here. // All the AggFunc implementations for "MAX" are listed here. // All the AggFunc implementations for "MIN" are listed here. diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 09fab4126eeee..9a87c2ba0e873 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -16,6 +16,7 @@ package aggfuncs import ( "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/expression/aggregation" + "github.com/pingcap/tidb/mysql" ) // Build is used to build a specific AggFunc implementation according to the @@ -58,6 +59,49 @@ func buildSum(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { // buildCount builds the AggFunc implementation for function "AVG". func buildAvg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { + base := baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, + } + switch aggFuncDesc.Mode { + // Build avg functions which consume the original data and remove the + // duplicated input of the same group. + case aggregation.DedupMode: + return nil // not implemented yet. + + // Build avg functions which consume the original data and update their + // partial results. + case aggregation.CompleteMode, aggregation.Partial1Mode: + switch aggFuncDesc.Args[0].GetType().Tp { + case mysql.TypeNewDecimal: + if aggFuncDesc.HasDistinct { + return nil // not implemented yet. + } + return &avgOriginal4Decimal{baseAvgDecimal{base}} + case mysql.TypeFloat: + if aggFuncDesc.HasDistinct { + return nil // not implemented yet. + } + return &avgOriginal4Float32{baseAvgFloat32{base}} + case mysql.TypeDouble: + if aggFuncDesc.HasDistinct { + return nil // not implemented yet. + } + return &avgOriginal4Float64{baseAvgFloat64{base}} + } + + // Build avg functions which consume the partial result of other avg + // functions and update their partial results. + case aggregation.Partial2Mode, aggregation.FinalMode: + switch aggFuncDesc.Args[1].GetType().Tp { + case mysql.TypeNewDecimal: + return &avgPartial4Decimal{baseAvgDecimal{base}} + case mysql.TypeFloat: + return &avgPartial4Float32{baseAvgFloat32{base}} + case mysql.TypeDouble: + return &avgPartial4Float64{baseAvgFloat64{base}} + } + } return nil } diff --git a/executor/aggfuncs/func_avg.go b/executor/aggfuncs/func_avg.go new file mode 100644 index 0000000000000..b428cec770ee7 --- /dev/null +++ b/executor/aggfuncs/func_avg.go @@ -0,0 +1,291 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" +) + +// All the AggFunc implementations for "AVG" are listed here. +var ( + _ AggFunc = (*avgOriginal4Decimal)(nil) + _ AggFunc = (*avgPartial4Decimal)(nil) + + _ AggFunc = (*avgOriginal4Float64)(nil) + _ AggFunc = (*avgPartial4Float64)(nil) + + _ AggFunc = (*avgOriginal4Float32)(nil) + _ AggFunc = (*avgPartial4Float32)(nil) +) + +// All the following avg function implementations return the decimal result, +// which store the partial results in "partialResult4AvgDecimal". +// +// "baseAvgDecimal" is wrapped by: +// - "avgOriginal4Decimal" +// - "avgPartial4Decimal" +type baseAvgDecimal struct { + baseAggFunc +} + +type partialResult4AvgDecimal struct { + sum types.MyDecimal + count int64 +} + +func (e *baseAvgDecimal) AllocPartialResult() PartialResult { + return PartialResult(&partialResult4AvgDecimal{}) +} + +func (e *baseAvgDecimal) ResetPartialResult(pr PartialResult) { + p := (*partialResult4AvgDecimal)(pr) + p.sum = *types.NewDecFromInt(0) + p.count = int64(0) +} + +func (e *baseAvgDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4AvgDecimal)(pr) + if p.count == 0 { + chk.AppendNull(e.ordinal) + return nil + } + decimalCount := types.NewDecFromInt(p.count) + finalResult := new(types.MyDecimal) + err := types.DecimalDiv(&p.sum, decimalCount, finalResult, types.DivFracIncr) + if err != nil { + return errors.Trace(err) + } + chk.AppendMyDecimal(e.ordinal, finalResult) + return nil +} + +type avgOriginal4Decimal struct { + baseAvgDecimal +} + +func (e *avgOriginal4Decimal) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4AvgDecimal)(pr) + newSum := new(types.MyDecimal) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalDecimal(sctx, row) + if err != nil { + return errors.Trace(err) + } else if isNull { + continue + } + err = types.DecimalAdd(&p.sum, input, newSum) + if err != nil { + return errors.Trace(err) + } + p.sum = *newSum + p.count++ + } + return nil +} + +type avgPartial4Decimal struct { + baseAvgDecimal +} + +func (e *avgPartial4Decimal) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4AvgDecimal)(pr) + newSum := new(types.MyDecimal) + for _, row := range rowsInGroup { + inputSum, isNull, err := e.args[1].EvalDecimal(sctx, row) + if err != nil { + return errors.Trace(err) + } else if isNull { + continue + } + + inputCount, isNull, err := e.args[0].EvalInt(sctx, row) + if err != nil { + return errors.Trace(err) + } else if isNull { + continue + } + + err = types.DecimalAdd(&p.sum, inputSum, newSum) + if err != nil { + return errors.Trace(err) + } + p.sum = *newSum + p.count += inputCount + } + return nil +} + +// All the following avg function implementations return the float64 result, +// which store the partial results in "partialResult4AvgFloat64". +// +// "baseAvgFloat64" is wrapped by: +// - "avgOriginal4Float64" +// - "avgPartial4Float64" +type baseAvgFloat64 struct { + baseAggFunc +} + +type partialResult4AvgFloat64 struct { + sum float64 + count int64 +} + +func (e *baseAvgFloat64) AllocPartialResult() PartialResult { + return (PartialResult)(&partialResult4AvgFloat64{}) +} + +func (e *baseAvgFloat64) ResetPartialResult(pr PartialResult) { + p := (*partialResult4AvgFloat64)(pr) + p.sum = 0 + p.count = 0 +} + +func (e *baseAvgFloat64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4AvgFloat64)(pr) + if p.count == 0 { + chk.AppendNull(e.ordinal) + } else { + chk.AppendFloat64(e.ordinal, p.sum/float64(p.count)) + } + return nil +} + +type avgOriginal4Float64 struct { + baseAvgFloat64 +} + +func (e *avgOriginal4Float64) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4AvgFloat64)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalReal(sctx, row) + if err != nil { + return errors.Trace(err) + } else if isNull { + continue + } + p.sum += input + p.count++ + } + return nil +} + +type avgPartial4Float64 struct { + baseAvgFloat64 +} + +func (e *avgPartial4Float64) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4AvgFloat64)(pr) + for _, row := range rowsInGroup { + inputSum, isNull, err := e.args[1].EvalReal(sctx, row) + if err != nil { + return errors.Trace(err) + } else if isNull { + continue + } + + inputCount, isNull, err := e.args[0].EvalInt(sctx, row) + if err != nil { + return errors.Trace(err) + } else if isNull { + continue + } + p.sum += inputSum + p.count += inputCount + } + return nil +} + +// All the following avg function implementations return the float32 result, +// which store the partial results in "partialResult4AvgFloat32". +// +// "baseAvgFloat32" is wrapped by: +// - "avgOriginal4Float32" +// - "avgPartial4Float32" +type baseAvgFloat32 struct { + baseAggFunc +} + +type partialResult4AvgFloat32 struct { + sum float32 + count int64 +} + +func (e *baseAvgFloat32) AllocPartialResult() PartialResult { + return (PartialResult)(&partialResult4AvgFloat32{}) +} + +func (e *baseAvgFloat32) ResetPartialResult(pr PartialResult) { + p := (*partialResult4AvgFloat32)(pr) + p.sum = 0 + p.count = 0 +} + +func (e *baseAvgFloat32) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4AvgFloat32)(pr) + if p.count == 0 { + chk.AppendNull(e.ordinal) + } else { + chk.AppendFloat32(e.ordinal, p.sum/float32(p.count)) + } + return nil +} + +type avgOriginal4Float32 struct { + baseAvgFloat32 +} + +func (e *avgOriginal4Float32) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4AvgFloat32)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalReal(sctx, row) + if err != nil { + return errors.Trace(err) + } else if isNull { + continue + } + + p.sum += float32(input) + p.count++ + } + return nil +} + +type avgPartial4Float32 struct { + baseAvgFloat32 +} + +func (e *avgPartial4Float32) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4AvgFloat32)(pr) + for _, row := range rowsInGroup { + inputSum, isNull, err := e.args[1].EvalReal(sctx, row) + if err != nil { + return errors.Trace(err) + } else if isNull { + continue + } + + inputCount, isNull, err := e.args[0].EvalInt(sctx, row) + if err != nil { + return errors.Trace(err) + } else if isNull { + continue + } + p.sum += float32(inputSum) + p.count += inputCount + } + return nil +} From 26deb1a71a5d52493ec0065de7fc1f6485a0d019 Mon Sep 17 00:00:00 2001 From: Jian Zhang Date: Mon, 2 Jul 2018 19:03:29 +0800 Subject: [PATCH 2/4] address comment --- executor/aggfuncs/aggfuncs.go | 3 -- executor/aggfuncs/builder.go | 9 +--- executor/aggfuncs/func_avg.go | 84 ----------------------------------- 3 files changed, 1 insertion(+), 95 deletions(-) diff --git a/executor/aggfuncs/aggfuncs.go b/executor/aggfuncs/aggfuncs.go index 289501dac13a9..7a4f3ccfe190b 100644 --- a/executor/aggfuncs/aggfuncs.go +++ b/executor/aggfuncs/aggfuncs.go @@ -32,9 +32,6 @@ var ( _ AggFunc = (*avgOriginal4Float64)(nil) _ AggFunc = (*avgPartial4Float64)(nil) - _ AggFunc = (*avgOriginal4Float32)(nil) - _ AggFunc = (*avgPartial4Float32)(nil) - // All the AggFunc implementations for "FIRSTROW" are listed here. // All the AggFunc implementations for "MAX" are listed here. // All the AggFunc implementations for "MIN" are listed here. diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 9a87c2ba0e873..e6ec86d5609f3 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -78,12 +78,7 @@ func buildAvg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { return nil // not implemented yet. } return &avgOriginal4Decimal{baseAvgDecimal{base}} - case mysql.TypeFloat: - if aggFuncDesc.HasDistinct { - return nil // not implemented yet. - } - return &avgOriginal4Float32{baseAvgFloat32{base}} - case mysql.TypeDouble: + case mysql.TypeFloat, mysql.TypeDouble: if aggFuncDesc.HasDistinct { return nil // not implemented yet. } @@ -96,8 +91,6 @@ func buildAvg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { switch aggFuncDesc.Args[1].GetType().Tp { case mysql.TypeNewDecimal: return &avgPartial4Decimal{baseAvgDecimal{base}} - case mysql.TypeFloat: - return &avgPartial4Float32{baseAvgFloat32{base}} case mysql.TypeDouble: return &avgPartial4Float64{baseAvgFloat64{base}} } diff --git a/executor/aggfuncs/func_avg.go b/executor/aggfuncs/func_avg.go index b428cec770ee7..c58b0eea1223b 100644 --- a/executor/aggfuncs/func_avg.go +++ b/executor/aggfuncs/func_avg.go @@ -27,9 +27,6 @@ var ( _ AggFunc = (*avgOriginal4Float64)(nil) _ AggFunc = (*avgPartial4Float64)(nil) - - _ AggFunc = (*avgOriginal4Float32)(nil) - _ AggFunc = (*avgPartial4Float32)(nil) ) // All the following avg function implementations return the decimal result, @@ -208,84 +205,3 @@ func (e *avgPartial4Float64) UpdatePartialResult(sctx sessionctx.Context, rowsIn } return nil } - -// All the following avg function implementations return the float32 result, -// which store the partial results in "partialResult4AvgFloat32". -// -// "baseAvgFloat32" is wrapped by: -// - "avgOriginal4Float32" -// - "avgPartial4Float32" -type baseAvgFloat32 struct { - baseAggFunc -} - -type partialResult4AvgFloat32 struct { - sum float32 - count int64 -} - -func (e *baseAvgFloat32) AllocPartialResult() PartialResult { - return (PartialResult)(&partialResult4AvgFloat32{}) -} - -func (e *baseAvgFloat32) ResetPartialResult(pr PartialResult) { - p := (*partialResult4AvgFloat32)(pr) - p.sum = 0 - p.count = 0 -} - -func (e *baseAvgFloat32) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { - p := (*partialResult4AvgFloat32)(pr) - if p.count == 0 { - chk.AppendNull(e.ordinal) - } else { - chk.AppendFloat32(e.ordinal, p.sum/float32(p.count)) - } - return nil -} - -type avgOriginal4Float32 struct { - baseAvgFloat32 -} - -func (e *avgOriginal4Float32) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { - p := (*partialResult4AvgFloat32)(pr) - for _, row := range rowsInGroup { - input, isNull, err := e.args[0].EvalReal(sctx, row) - if err != nil { - return errors.Trace(err) - } else if isNull { - continue - } - - p.sum += float32(input) - p.count++ - } - return nil -} - -type avgPartial4Float32 struct { - baseAvgFloat32 -} - -func (e *avgPartial4Float32) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { - p := (*partialResult4AvgFloat32)(pr) - for _, row := range rowsInGroup { - inputSum, isNull, err := e.args[1].EvalReal(sctx, row) - if err != nil { - return errors.Trace(err) - } else if isNull { - continue - } - - inputCount, isNull, err := e.args[0].EvalInt(sctx, row) - if err != nil { - return errors.Trace(err) - } else if isNull { - continue - } - p.sum += float32(inputSum) - p.count += inputCount - } - return nil -} From 81bde2850f2884612a456adfe6352912dda370a4 Mon Sep 17 00:00:00 2001 From: Jian Zhang Date: Mon, 2 Jul 2018 19:07:46 +0800 Subject: [PATCH 3/4] remove duplicated code --- executor/aggfuncs/func_avg.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/executor/aggfuncs/func_avg.go b/executor/aggfuncs/func_avg.go index c58b0eea1223b..f638e48035887 100644 --- a/executor/aggfuncs/func_avg.go +++ b/executor/aggfuncs/func_avg.go @@ -20,15 +20,6 @@ import ( "github.com/pingcap/tidb/util/chunk" ) -// All the AggFunc implementations for "AVG" are listed here. -var ( - _ AggFunc = (*avgOriginal4Decimal)(nil) - _ AggFunc = (*avgPartial4Decimal)(nil) - - _ AggFunc = (*avgOriginal4Float64)(nil) - _ AggFunc = (*avgPartial4Float64)(nil) -) - // All the following avg function implementations return the decimal result, // which store the partial results in "partialResult4AvgDecimal". // From f110af245811a90ad9fe02a0b570277b3108132a Mon Sep 17 00:00:00 2001 From: Jian Zhang Date: Tue, 3 Jul 2018 11:28:05 +0800 Subject: [PATCH 4/4] address comment --- executor/aggfuncs/func_avg.go | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/executor/aggfuncs/func_avg.go b/executor/aggfuncs/func_avg.go index f638e48035887..882a859029f08 100644 --- a/executor/aggfuncs/func_avg.go +++ b/executor/aggfuncs/func_avg.go @@ -72,9 +72,11 @@ func (e *avgOriginal4Decimal) UpdatePartialResult(sctx sessionctx.Context, rowsI input, isNull, err := e.args[0].EvalDecimal(sctx, row) if err != nil { return errors.Trace(err) - } else if isNull { + } + if isNull { continue } + err = types.DecimalAdd(&p.sum, input, newSum) if err != nil { return errors.Trace(err) @@ -96,14 +98,16 @@ func (e *avgPartial4Decimal) UpdatePartialResult(sctx sessionctx.Context, rowsIn inputSum, isNull, err := e.args[1].EvalDecimal(sctx, row) if err != nil { return errors.Trace(err) - } else if isNull { + } + if isNull { continue } inputCount, isNull, err := e.args[0].EvalInt(sctx, row) if err != nil { return errors.Trace(err) - } else if isNull { + } + if isNull { continue } @@ -162,9 +166,11 @@ func (e *avgOriginal4Float64) UpdatePartialResult(sctx sessionctx.Context, rowsI input, isNull, err := e.args[0].EvalReal(sctx, row) if err != nil { return errors.Trace(err) - } else if isNull { + } + if isNull { continue } + p.sum += input p.count++ } @@ -181,16 +187,19 @@ func (e *avgPartial4Float64) UpdatePartialResult(sctx sessionctx.Context, rowsIn inputSum, isNull, err := e.args[1].EvalReal(sctx, row) if err != nil { return errors.Trace(err) - } else if isNull { + } + if isNull { continue } inputCount, isNull, err := e.args[0].EvalInt(sctx, row) if err != nil { return errors.Trace(err) - } else if isNull { + } + if isNull { continue } + p.sum += inputSum p.count += inputCount }