diff --git a/executor/aggfuncs/aggfuncs.go b/executor/aggfuncs/aggfuncs.go index b599ae081486f..c493aa2112ee9 100644 --- a/executor/aggfuncs/aggfuncs.go +++ b/executor/aggfuncs/aggfuncs.go @@ -55,6 +55,11 @@ var ( _ AggFunc = (*avgPartial4Float64)(nil) _ AggFunc = (*avgOriginal4DistinctFloat64)(nil) + _ AggFunc = (*sum4DistinctFloat64)(nil) + _ AggFunc = (*sum4DistinctDecimal)(nil) + _ AggFunc = (*sum4Decimal)(nil) + _ AggFunc = (*sum4Float64)(nil) + // All the AggFunc implementations for "GROUP_CONCAT" are listed here. _ AggFunc = (*groupConcatDistinct)(nil) _ AggFunc = (*groupConcat)(nil) diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 5f2e59fc3aaab..e15efb6e0c5a3 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -97,7 +97,31 @@ func buildCount(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { // buildSum builds the AggFunc implementation for function "SUM". func buildSum(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { - return nil + base := baseSumAggFunc{ + baseAggFunc: baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, + }, + } + switch aggFuncDesc.Mode { + case aggregation.DedupMode: + return nil + default: + switch aggFuncDesc.Args[0].GetType().Tp { + case mysql.TypeFloat, mysql.TypeDouble: + if aggFuncDesc.HasDistinct { + return &sum4DistinctFloat64{base} + } + return &sum4Float64{base} + case mysql.TypeNewDecimal: + if aggFuncDesc.HasDistinct { + return &sum4DistinctDecimal{base} + } + return &sum4Decimal{base} + default: + return nil + } + } } // buildAvg builds the AggFunc implementation for function "AVG". diff --git a/executor/aggfuncs/func_sum.go b/executor/aggfuncs/func_sum.go new file mode 100644 index 0000000000000..47b3fbbcd1e0d --- /dev/null +++ b/executor/aggfuncs/func_sum.go @@ -0,0 +1,242 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" +) + +type partialResult4SumFloat64 struct { + val float64 + isNull bool +} + +type partialResult4SumDecimal struct { + val types.MyDecimal + isNull bool +} + +type partialResult4SumDistinctFloat64 struct { + partialResult4SumFloat64 + valSet float64Set +} + +type partialResult4SumDistinctDecimal struct { + partialResult4SumDecimal + valSet decimalSet +} + +type baseSumAggFunc struct { + baseAggFunc +} + +type sum4Float64 struct { + baseSumAggFunc +} + +func (e *sum4Float64) AllocPartialResult() PartialResult { + p := new(partialResult4SumFloat64) + p.isNull = true + return PartialResult(p) +} + +func (e *sum4Float64) ResetPartialResult(pr PartialResult) { + p := (*partialResult4SumFloat64)(pr) + p.val = 0 + p.isNull = true +} + +func (e *sum4Float64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4SumFloat64)(pr) + if p.isNull { + chk.AppendNull(e.ordinal) + return nil + } + chk.AppendFloat64(e.ordinal, p.val) + return nil +} + +func (e *sum4Float64) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4SumFloat64)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalReal(sctx, &row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + if p.isNull { + p.val = input + p.isNull = false + continue + } + p.val += input + } + return nil +} + +type sum4Decimal struct { + baseSumAggFunc +} + +func (e *sum4Decimal) AllocPartialResult() PartialResult { + p := new(partialResult4SumDecimal) + p.isNull = true + return PartialResult(p) +} + +func (e *sum4Decimal) ResetPartialResult(pr PartialResult) { + p := (*partialResult4SumDecimal)(pr) + p.isNull = true +} + +func (e *sum4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4SumDecimal)(pr) + if p.isNull { + chk.AppendNull(e.ordinal) + return nil + } + chk.AppendMyDecimal(e.ordinal, &p.val) + return nil +} + +func (e *sum4Decimal) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4SumDecimal)(pr) + newSum := new(types.MyDecimal) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalDecimal(sctx, &row) + if err != nil { + return errors.Trace(err) + } + if isNull { + continue + } + if p.isNull { + p.val = *input + p.isNull = false + continue + } + + err = types.DecimalAdd(&p.val, input, newSum) + if err != nil { + return errors.Trace(err) + } + p.val = *newSum + } + return nil +} + +type sum4DistinctFloat64 struct { + baseSumAggFunc +} + +func (e *sum4DistinctFloat64) AllocPartialResult() PartialResult { + p := new(partialResult4SumDistinctFloat64) + p.isNull = true + p.valSet = newFloat64Set() + return PartialResult(p) +} + +func (e *sum4DistinctFloat64) ResetPartialResult(pr PartialResult) { + p := (*partialResult4SumDistinctFloat64)(pr) + p.isNull = true + p.valSet = newFloat64Set() +} + +func (e *sum4DistinctFloat64) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4SumDistinctFloat64)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalReal(sctx, &row) + if err != nil { + return errors.Trace(err) + } + if isNull || p.valSet.exist(input) { + continue + } + p.valSet.insert(input) + if p.isNull { + p.val = input + p.isNull = false + continue + } + p.val += input + } + return nil +} + +func (e *sum4DistinctFloat64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4SumDistinctFloat64)(pr) + if p.isNull { + chk.AppendNull(e.ordinal) + return nil + } + chk.AppendFloat64(e.ordinal, p.val) + return nil +} + +type sum4DistinctDecimal struct { + baseSumAggFunc +} + +func (e *sum4DistinctDecimal) AllocPartialResult() PartialResult { + p := new(partialResult4SumDistinctDecimal) + p.isNull = true + p.valSet = newDecimalSet() + return PartialResult(p) +} + +func (e *sum4DistinctDecimal) ResetPartialResult(pr PartialResult) { + p := (*partialResult4SumDistinctDecimal)(pr) + p.isNull = true + p.valSet = newDecimalSet() +} + +func (e *sum4DistinctDecimal) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4SumDistinctDecimal)(pr) + var newSum types.MyDecimal + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalDecimal(sctx, &row) + if err != nil { + return errors.Trace(err) + } + if isNull || p.valSet.exist(input) { + continue + } + p.valSet.insert(input) + if p.isNull { + p.val = *input + p.isNull = false + continue + } + if err = types.DecimalAdd(&p.val, input, &newSum); err != nil { + return errors.Trace(err) + } + p.val = newSum + } + return nil +} + +func (e *sum4DistinctDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4SumDistinctDecimal)(pr) + if p.isNull { + chk.AppendNull(e.ordinal) + return nil + } + chk.AppendMyDecimal(e.ordinal, &p.val) + return nil +}