From 87175b26c9ad9806bbe68c85f9dcada438e3da8d Mon Sep 17 00:00:00 2001 From: xuhuaiyu <391585975@qq.com> Date: Mon, 3 Sep 2018 13:01:08 +0800 Subject: [PATCH] address comment --- executor/aggfuncs/func_avg.go | 21 ++--- executor/aggfuncs/func_count.go | 11 +-- executor/aggfuncs/func_group_concat.go | 11 +-- executor/aggfuncs/func_sum.go | 21 ++--- executor/aggfuncs/sets.go | 61 -------------- executor/aggregate.go | 108 ++++++++++++------------- util/set/decimal_set.go | 36 +++++++++ util/set/float64_set.go | 32 ++++++++ util/set/string_set.go | 32 ++++++++ 9 files changed, 186 insertions(+), 147 deletions(-) delete mode 100644 executor/aggfuncs/sets.go create mode 100644 util/set/decimal_set.go create mode 100644 util/set/float64_set.go create mode 100644 util/set/string_set.go diff --git a/executor/aggfuncs/func_avg.go b/executor/aggfuncs/func_avg.go index 033aaf1f5b84e..d14116efcf4a3 100644 --- a/executor/aggfuncs/func_avg.go +++ b/executor/aggfuncs/func_avg.go @@ -18,6 +18,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/set" ) // All the following avg function implementations return the decimal result, @@ -138,7 +139,7 @@ func (e *avgPartial4Decimal) MergePartialResult(sctx sessionctx.Context, src Par type partialResult4AvgDistinctDecimal struct { partialResult4AvgDecimal - valSet decimalSet + valSet set.DecimalSet } type avgOriginal4DistinctDecimal struct { @@ -147,7 +148,7 @@ type avgOriginal4DistinctDecimal struct { func (e *avgOriginal4DistinctDecimal) AllocPartialResult() PartialResult { p := &partialResult4AvgDistinctDecimal{ - valSet: newDecimalSet(), + valSet: set.NewDecimalSet(), } return PartialResult(p) } @@ -156,7 +157,7 @@ func (e *avgOriginal4DistinctDecimal) ResetPartialResult(pr PartialResult) { p := (*partialResult4AvgDistinctDecimal)(pr) p.sum = *types.NewDecFromInt(0) p.count = int64(0) - p.valSet = newDecimalSet() + p.valSet = 
set.NewDecimalSet() } func (e *avgOriginal4DistinctDecimal) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { @@ -166,7 +167,7 @@ func (e *avgOriginal4DistinctDecimal) UpdatePartialResult(sctx sessionctx.Contex if err != nil { return errors.Trace(err) } - if isNull || p.valSet.exist(input) { + if isNull || p.valSet.Exist(input) { continue } @@ -177,7 +178,7 @@ func (e *avgOriginal4DistinctDecimal) UpdatePartialResult(sctx sessionctx.Contex } p.sum = *newSum p.count++ - p.valSet.insert(input) + p.valSet.Insert(input) } return nil } @@ -291,7 +292,7 @@ func (e *avgPartial4Float64) MergePartialResult(sctx sessionctx.Context, src Par type partialResult4AvgDistinctFloat64 struct { partialResult4AvgFloat64 - valSet float64Set + valSet set.Float64Set } type avgOriginal4DistinctFloat64 struct { @@ -300,7 +301,7 @@ type avgOriginal4DistinctFloat64 struct { func (e *avgOriginal4DistinctFloat64) AllocPartialResult() PartialResult { p := &partialResult4AvgDistinctFloat64{ - valSet: newFloat64Set(), + valSet: set.NewFloat64Set(), } return PartialResult(p) } @@ -309,7 +310,7 @@ func (e *avgOriginal4DistinctFloat64) ResetPartialResult(pr PartialResult) { p := (*partialResult4AvgDistinctFloat64)(pr) p.sum = float64(0) p.count = int64(0) - p.valSet = newFloat64Set() + p.valSet = set.NewFloat64Set() } func (e *avgOriginal4DistinctFloat64) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { @@ -319,13 +320,13 @@ func (e *avgOriginal4DistinctFloat64) UpdatePartialResult(sctx sessionctx.Contex if err != nil { return errors.Trace(err) } - if isNull || p.valSet.exist(input) { + if isNull || p.valSet.Exist(input) { continue } p.sum += input p.count++ - p.valSet.insert(input) + p.valSet.Insert(input) } return nil } diff --git a/executor/aggfuncs/func_count.go b/executor/aggfuncs/func_count.go index 8cc682e91200f..7e76f5b7a97c0 100644 --- a/executor/aggfuncs/func_count.go +++ 
b/executor/aggfuncs/func_count.go @@ -10,6 +10,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/set" ) type baseCount struct { @@ -220,20 +221,20 @@ type countOriginalWithDistinct struct { type partialResult4CountWithDistinct struct { count int64 - valSet stringSet + valSet set.StringSet } func (e *countOriginalWithDistinct) AllocPartialResult() PartialResult { return PartialResult(&partialResult4CountWithDistinct{ count: 0, - valSet: newStringSet(), + valSet: set.NewStringSet(), }) } func (e *countOriginalWithDistinct) ResetPartialResult(pr PartialResult) { p := (*partialResult4CountWithDistinct)(pr) p.count = 0 - p.valSet = newStringSet() + p.valSet = set.NewStringSet() } func (e *countOriginalWithDistinct) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { @@ -265,10 +266,10 @@ func (e *countOriginalWithDistinct) UpdatePartialResult(sctx sessionctx.Context, } } encodedString := string(encodedBytes) - if hasNull || p.valSet.exist(encodedString) { + if hasNull || p.valSet.Exist(encodedString) { continue } - p.valSet.insert(encodedString) + p.valSet.Insert(encodedString) p.count++ } diff --git a/executor/aggfuncs/func_group_concat.go b/executor/aggfuncs/func_group_concat.go index a2cf95824f1d0..f3a1b90f3d8b0 100644 --- a/executor/aggfuncs/func_group_concat.go +++ b/executor/aggfuncs/func_group_concat.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/set" ) type baseGroupConcat4String struct { @@ -132,7 +133,7 @@ func (e *groupConcat) MergePartialResult(sctx sessionctx.Context, src, dst Parti type partialResult4GroupConcatDistinct struct { basePartialResult4GroupConcat valsBuf *bytes.Buffer - valSet stringSet + valSet set.StringSet } type groupConcatDistinct struct { @@ -142,13 +143,13 @@ type 
groupConcatDistinct struct { func (e *groupConcatDistinct) AllocPartialResult() PartialResult { p := new(partialResult4GroupConcatDistinct) p.valsBuf = &bytes.Buffer{} - p.valSet = newStringSet() + p.valSet = set.NewStringSet() return PartialResult(p) } func (e *groupConcatDistinct) ResetPartialResult(pr PartialResult) { p := (*partialResult4GroupConcatDistinct)(pr) - p.buffer, p.valSet = nil, newStringSet() + p.buffer, p.valSet = nil, set.NewStringSet() } func (e *groupConcatDistinct) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (err error) { @@ -170,10 +171,10 @@ func (e *groupConcatDistinct) UpdatePartialResult(sctx sessionctx.Context, rowsI continue } joinedVals := p.valsBuf.String() - if p.valSet.exist(joinedVals) { + if p.valSet.Exist(joinedVals) { continue } - p.valSet.insert(joinedVals) + p.valSet.Insert(joinedVals) // write separator if p.buffer == nil { p.buffer = &bytes.Buffer{} diff --git a/executor/aggfuncs/func_sum.go b/executor/aggfuncs/func_sum.go index 21704b61490f7..e55d882a3d12e 100644 --- a/executor/aggfuncs/func_sum.go +++ b/executor/aggfuncs/func_sum.go @@ -18,6 +18,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/set" ) type partialResult4SumFloat64 struct { @@ -32,12 +33,12 @@ type partialResult4SumDecimal struct { type partialResult4SumDistinctFloat64 struct { partialResult4SumFloat64 - valSet float64Set + valSet set.Float64Set } type partialResult4SumDistinctDecimal struct { partialResult4SumDecimal - valSet decimalSet + valSet set.DecimalSet } type baseSumAggFunc struct { @@ -173,14 +174,14 @@ type sum4DistinctFloat64 struct { func (e *sum4DistinctFloat64) AllocPartialResult() PartialResult { p := new(partialResult4SumDistinctFloat64) p.isNull = true - p.valSet = newFloat64Set() + p.valSet = set.NewFloat64Set() return PartialResult(p) } func (e *sum4DistinctFloat64) ResetPartialResult(pr 
PartialResult) { p := (*partialResult4SumDistinctFloat64)(pr) p.isNull = true - p.valSet = newFloat64Set() + p.valSet = set.NewFloat64Set() } func (e *sum4DistinctFloat64) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { @@ -190,10 +191,10 @@ func (e *sum4DistinctFloat64) UpdatePartialResult(sctx sessionctx.Context, rowsI if err != nil { return errors.Trace(err) } - if isNull || p.valSet.exist(input) { + if isNull || p.valSet.Exist(input) { continue } - p.valSet.insert(input) + p.valSet.Insert(input) if p.isNull { p.val = input p.isNull = false @@ -221,14 +222,14 @@ type sum4DistinctDecimal struct { func (e *sum4DistinctDecimal) AllocPartialResult() PartialResult { p := new(partialResult4SumDistinctDecimal) p.isNull = true - p.valSet = newDecimalSet() + p.valSet = set.NewDecimalSet() return PartialResult(p) } func (e *sum4DistinctDecimal) ResetPartialResult(pr PartialResult) { p := (*partialResult4SumDistinctDecimal)(pr) p.isNull = true - p.valSet = newDecimalSet() + p.valSet = set.NewDecimalSet() } func (e *sum4DistinctDecimal) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { @@ -238,10 +239,10 @@ func (e *sum4DistinctDecimal) UpdatePartialResult(sctx sessionctx.Context, rowsI if err != nil { return errors.Trace(err) } - if isNull || p.valSet.exist(input) { + if isNull || p.valSet.Exist(input) { continue } - p.valSet.insert(input) + p.valSet.Insert(input) if p.isNull { p.val = *input p.isNull = false diff --git a/executor/aggfuncs/sets.go b/executor/aggfuncs/sets.go deleted file mode 100644 index 642b1e5dc1289..0000000000000 --- a/executor/aggfuncs/sets.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package aggfuncs - -import ( - "github.com/pingcap/tidb/types" -) - -type decimalSet map[types.MyDecimal]struct{} -type float64Set map[float64]struct{} -type stringSet map[string]struct{} - -func newDecimalSet() decimalSet { - return make(map[types.MyDecimal]struct{}) -} - -func (s decimalSet) exist(val *types.MyDecimal) bool { - _, ok := s[*val] - return ok -} - -func (s decimalSet) insert(val *types.MyDecimal) { - s[*val] = struct{}{} -} - -func newFloat64Set() float64Set { - return make(map[float64]struct{}) -} - -func (s float64Set) exist(val float64) bool { - _, ok := s[val] - return ok -} - -func (s float64Set) insert(val float64) { - s[val] = struct{}{} -} - -func newStringSet() stringSet { - return make(map[string]struct{}) -} - -func (s stringSet) exist(val string) bool { - _, ok := s[val] - return ok -} - -func (s stringSet) insert(val string) { - s[val] = struct{}{} -} diff --git a/executor/aggregate.go b/executor/aggregate.go index 257b8076ea5ae..9a5e35110acb0 100644 --- a/executor/aggregate.go +++ b/executor/aggregate.go @@ -14,6 +14,7 @@ package executor import ( + "sort" "sync" "github.com/juju/errors" @@ -25,7 +26,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/codec" - "github.com/pingcap/tidb/util/mvmap" + "github.com/pingcap/tidb/util/set" "github.com/spaolacci/murmur3" "golang.org/x/net/context" ) @@ -73,8 +74,7 @@ type HashAggFinalWorker struct { rowBuffer []types.Datum mutableRow chunk.MutRow partialResultMap aggPartialResultMapper - groupSet *mvmap.MVMap - groupVals [][]byte + groupSet set.StringSet inputCh chan *HashAggIntermData 
outputCh chan *AfFinalResult finalResultHolderCh chan *chunk.Chunk @@ -137,11 +137,12 @@ type HashAggExec struct { PartialAggFuncs []aggfuncs.AggFunc FinalAggFuncs []aggfuncs.AggFunc partialResultMap aggPartialResultMapper - groupMap *mvmap.MVMap - groupIterator *mvmap.Iterator + groupSet set.StringSet + sortedGroupKey []string + cursor4GroupKey int GroupByItems []expression.Expression groupKey []byte - groupVals [][]byte + groupValDatums []types.Datum // After we support parallel execution for aggregation functions with distinct, // we can remove this attribute. @@ -172,16 +173,16 @@ type HashAggInput struct { // HashAggIntermData indicates the intermediate data of aggregation execution. type HashAggIntermData struct { - groupKeys [][]byte + groupKeys []string cursor int partialResultMap aggPartialResultMapper } // getPartialResultBatch fetches a batch of partial results from HashAggIntermData. -func (d *HashAggIntermData) getPartialResultBatch(sc *stmtctx.StatementContext, prs [][]aggfuncs.PartialResult, aggFuncs []aggfuncs.AggFunc, maxChunkSize int) (_ [][]aggfuncs.PartialResult, groupKeys [][]byte, reachEnd bool) { +func (d *HashAggIntermData) getPartialResultBatch(sc *stmtctx.StatementContext, prs [][]aggfuncs.PartialResult, aggFuncs []aggfuncs.AggFunc, maxChunkSize int) (_ [][]aggfuncs.PartialResult, groupKeys []string, reachEnd bool) { keyStart := d.cursor for ; d.cursor < len(d.groupKeys) && len(prs) < maxChunkSize; d.cursor++ { - prs = append(prs, d.partialResultMap[string(d.groupKeys[d.cursor])]) + prs = append(prs, d.partialResultMap[d.groupKeys[d.cursor]]) } if d.cursor == len(d.groupKeys) { reachEnd = true @@ -193,8 +194,7 @@ func (d *HashAggIntermData) getPartialResultBatch(sc *stmtctx.StatementContext, func (e *HashAggExec) Close() error { if e.isUnparallelExec { e.childResult = nil - e.groupMap = nil - e.groupIterator = nil + e.groupSet = nil e.partialResultMap = nil return nil } @@ -232,11 +232,10 @@ func (e *HashAggExec) Open(ctx context.Context) 
error { } func (e *HashAggExec) initForUnparallelExec() { - e.groupMap = mvmap.NewMVMap() - e.groupIterator = e.groupMap.NewIterator() + e.groupSet = set.NewStringSet() e.partialResultMap = make(aggPartialResultMapper, 0) e.groupKey = make([]byte, 0, 8) - e.groupVals = make([][]byte, 0, 8) + e.groupValDatums = make([]types.Datum, 0, len(e.groupKey)) e.childResult = e.children[0].newChunk() } @@ -271,6 +270,7 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) { globalOutputCh: e.finalOutputCh, partialResultsMap: make(aggPartialResultMapper, 0), groupByItems: e.GroupByItems, + groupValDatums: make([]types.Datum, 0, len(e.GroupByItems)), chk: e.children[0].newChunk(), } @@ -286,8 +286,7 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) { e.finalWorkers[i] = HashAggFinalWorker{ baseHashAggWorker: newBaseHashAggWorker(e.finishCh, e.FinalAggFuncs, e.maxChunkSize), partialResultMap: make(aggPartialResultMapper, 0), - groupSet: mvmap.NewMVMap(), - groupVals: make([][]byte, 0, 8), + groupSet: set.NewStringSet(), inputCh: e.partialOutputChs[i], outputCh: e.finalOutputCh, finalResultHolderCh: make(chan *chunk.Chunk, 1), @@ -357,14 +356,13 @@ func (w *HashAggPartialWorker) updatePartialResult(ctx sessionctx.Context, sc *s // shuffleIntermData shuffles the intermediate data of partial workers to corresponded final workers. // We only support parallel execution for single-machine, so process of encode and decode can be skipped. 
func (w *HashAggPartialWorker) shuffleIntermData(sc *stmtctx.StatementContext, finalConcurrency int) { - groupKeysSlice := make([][][]byte, finalConcurrency) + groupKeysSlice := make([][]string, finalConcurrency) for groupKey := range w.partialResultsMap { - groupKeyBytes := []byte(groupKey) - finalWorkerIdx := int(murmur3.Sum32(groupKeyBytes)) % finalConcurrency + finalWorkerIdx := int(murmur3.Sum32([]byte(groupKey))) % finalConcurrency if groupKeysSlice[finalWorkerIdx] == nil { - groupKeysSlice[finalWorkerIdx] = make([][]byte, 0, len(w.partialResultsMap)/finalConcurrency) + groupKeysSlice[finalWorkerIdx] = make([]string, 0, len(w.partialResultsMap)/finalConcurrency) } - groupKeysSlice[finalWorkerIdx] = append(groupKeysSlice[finalWorkerIdx], groupKeyBytes) + groupKeysSlice[finalWorkerIdx] = append(groupKeysSlice[finalWorkerIdx], groupKey) } for i := range groupKeysSlice { @@ -426,7 +424,7 @@ func (w *HashAggFinalWorker) consumeIntermData(sctx sessionctx.Context) (err err input *HashAggIntermData ok bool intermDataBuffer [][]aggfuncs.PartialResult - groupKeys [][]byte + groupKeys []string sc = sctx.GetSessionVars().StmtCtx ) for { @@ -440,11 +438,11 @@ func (w *HashAggFinalWorker) consumeIntermData(sctx sessionctx.Context) (err err for reachEnd := false; !reachEnd; { intermDataBuffer, groupKeys, reachEnd = input.getPartialResultBatch(sc, intermDataBuffer[:0], w.aggFuncs, w.maxChunkSize) for i, groupKey := range groupKeys { - if len(w.groupSet.Get(groupKey, w.groupVals[:0])) == 0 { - w.groupSet.Put(groupKey, []byte{}) + if !w.groupSet.Exist(groupKey) { + w.groupSet.Insert(groupKey) } prs := intermDataBuffer[i] - finalPartialResults := w.getPartialResult(sc, groupKey, w.partialResultMap) + finalPartialResults := w.getPartialResult(sc, []byte(groupKey), w.partialResultMap) for j, af := range w.aggFuncs { if err = af.MergePartialResult(sctx, prs[j], finalPartialResults[j]); err != nil { return errors.Trace(err) @@ -456,21 +454,13 @@ func (w *HashAggFinalWorker) 
consumeIntermData(sctx sessionctx.Context) (err err } func (w *HashAggFinalWorker) getFinalResult(sctx sessionctx.Context) { - groupIter := w.groupSet.NewIterator() result, finished := w.receiveFinalResultHolder() if finished { return } result.Reset() - for { - groupKey, _ := groupIter.Next() - if groupKey == nil { - if result.NumRows() > 0 { - w.outputCh <- &AfFinalResult{chk: result, giveBackCh: w.finalResultHolderCh} - } - return - } - partialResults := w.getPartialResult(sctx.GetSessionVars().StmtCtx, groupKey, w.partialResultMap) + for groupKey := range w.groupSet { + partialResults := w.getPartialResult(sctx.GetSessionVars().StmtCtx, []byte(groupKey), w.partialResultMap) for i, af := range w.aggFuncs { af.AppendFinalResult2Chunk(sctx, partialResults[i], result) } @@ -486,6 +476,9 @@ func (w *HashAggFinalWorker) getFinalResult(sctx sessionctx.Context) { result.Reset() } } + if result.NumRows() > 0 { + w.outputCh <- &AfFinalResult{chk: result, giveBackCh: w.finalResultHolderCh} + } } func (w *HashAggFinalWorker) receiveFinalResultHolder() (*chunk.Chunk, bool) { @@ -621,22 +614,23 @@ func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) erro if err != nil { return errors.Trace(err) } - if (e.groupMap.Len() == 0) && len(e.GroupByItems) == 0 { + if (len(e.groupSet) == 0) && len(e.GroupByItems) == 0 { // If no groupby and no data, we should add an empty group. // For example: // "select count(c) from t;" should return one row [0] // "select count(c) from t group by c1;" should return empty result set. - e.groupMap.Put([]byte{}, []byte{}) + e.groupSet.Insert("") + e.sortedGroupKey = append(e.sortedGroupKey, "") } e.prepared = true } chk.Reset() - for { - groupKey, _ := e.groupIterator.Next() - if groupKey == nil { - return nil - } - partialResults := e.getPartialResults(groupKey) + + // Since we return e.maxChunkSize rows every time, we should not traverse
+ sort.Strings(e.sortedGroupKey) + for ; e.cursor4GroupKey < len(e.sortedGroupKey); e.cursor4GroupKey++ { + partialResults := e.getPartialResults(e.sortedGroupKey[e.cursor4GroupKey]) if len(e.PartialAggFuncs) == 0 { chk.SetNumVirtualRows(chk.NumRows() + 1) } @@ -644,9 +638,11 @@ func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) erro af.AppendFinalResult2Chunk(e.ctx, partialResults[i], chk) } if chk.NumRows() == e.maxChunkSize { + e.cursor4GroupKey++ return nil } } + return nil } // execute fetches Chunks from src and update each aggregate function for each row in Chunk. @@ -666,8 +662,9 @@ func (e *HashAggExec) execute(ctx context.Context) (err error) { if err != nil { return errors.Trace(err) } - if len(e.groupMap.Get(groupKey, e.groupVals[:0])) == 0 { - e.groupMap.Put(groupKey, []byte{}) + if !e.groupSet.Exist(groupKey) { + e.groupSet.Insert(groupKey) + e.sortedGroupKey = append(e.sortedGroupKey, groupKey) } partialResults := e.getPartialResults(groupKey) for i, af := range e.PartialAggFuncs { @@ -680,35 +677,34 @@ func (e *HashAggExec) execute(ctx context.Context) (err error) { } } -func (e *HashAggExec) getGroupKey(row chunk.Row) ([]byte, error) { - vals := make([]types.Datum, 0, len(e.GroupByItems)) +func (e *HashAggExec) getGroupKey(row chunk.Row) (string, error) { + e.groupValDatums = e.groupValDatums[:0] for _, item := range e.GroupByItems { v, err := item.Eval(row) if item.GetType().Tp == mysql.TypeNewDecimal { v.SetLength(0) } if err != nil { - return nil, errors.Trace(err) + return "", errors.Trace(err) } - vals = append(vals, v) + e.groupValDatums = append(e.groupValDatums, v) } var err error - e.groupKey, err = codec.EncodeValue(e.sc, e.groupKey[:0], vals...) + e.groupKey, err = codec.EncodeValue(e.sc, e.groupKey[:0], e.groupValDatums...) 
if err != nil { - return nil, errors.Trace(err) + return "", errors.Trace(err) } - return e.groupKey, nil + return string(e.groupKey), nil } -func (e *HashAggExec) getPartialResults(groupKey []byte) []aggfuncs.PartialResult { - groupKeyString := string(groupKey) - partialResults, ok := e.partialResultMap[groupKeyString] +func (e *HashAggExec) getPartialResults(groupKey string) []aggfuncs.PartialResult { + partialResults, ok := e.partialResultMap[groupKey] if !ok { partialResults = make([]aggfuncs.PartialResult, 0, len(e.PartialAggFuncs)) for _, af := range e.PartialAggFuncs { partialResults = append(partialResults, af.AllocPartialResult()) } - e.partialResultMap[groupKeyString] = partialResults + e.partialResultMap[groupKey] = partialResults } return partialResults } diff --git a/util/set/decimal_set.go b/util/set/decimal_set.go new file mode 100644 index 0000000000000..4fb1059564b0c --- /dev/null +++ b/util/set/decimal_set.go @@ -0,0 +1,36 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package set + +import ( + "github.com/pingcap/tidb/types" +) + +type DecimalSet map[types.MyDecimal]struct{} + +// NewDecimalSet builds a decimal set. +func NewDecimalSet() DecimalSet { + return make(map[types.MyDecimal]struct{}) +} + +// Exist checks whether `val` exists in `s`. +func (s DecimalSet) Exist(val *types.MyDecimal) bool { + _, ok := s[*val] + return ok +} + +// Insert inserts `val` into `s`. 
+func (s DecimalSet) Insert(val *types.MyDecimal) { + s[*val] = struct{}{} +} diff --git a/util/set/float64_set.go b/util/set/float64_set.go new file mode 100644 index 0000000000000..ebdc92f727ea1 --- /dev/null +++ b/util/set/float64_set.go @@ -0,0 +1,32 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package set + +type Float64Set map[float64]struct{} + +// NewFloat64Set builds a float64 set. +func NewFloat64Set() Float64Set { + return make(map[float64]struct{}) +} + +// Exist checks whether `val` exists in `s`. +func (s Float64Set) Exist(val float64) bool { + _, ok := s[val] + return ok +} + +// Insert inserts `val` into `s`. +func (s Float64Set) Insert(val float64) { + s[val] = struct{}{} +} diff --git a/util/set/string_set.go b/util/set/string_set.go new file mode 100644 index 0000000000000..a732c17c1d610 --- /dev/null +++ b/util/set/string_set.go @@ -0,0 +1,32 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package set + +type StringSet map[string]struct{} + +// NewStringSet builds a string set.
+func NewStringSet() StringSet { + return make(map[string]struct{}) +} + +// Exist checks whether `val` exists in `s`. +func (s StringSet) Exist(val string) bool { + _, ok := s[val] + return ok +} + +// Insert inserts `val` into `s`. +func (s StringSet) Insert(val string) { + s[val] = struct{}{} +}