From 7de83a3c8f11571d5ac4821541e5df1f84f38b70 Mon Sep 17 00:00:00 2001 From: Zhang Jian Date: Wed, 12 Sep 2018 20:29:50 +0800 Subject: [PATCH] plan: use the inferred type as the column type in the schema (#7624) --- executor/aggregate_test.go | 11 +++++++++++ plan/logical_plan_builder.go | 4 +++- plan/rule_aggregation_push_down.go | 4 +++- plan/rule_decorrelate.go | 11 +++++++++-- 4 files changed, 26 insertions(+), 4 deletions(-) diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index 3ce424ca18106..0ef14aefa4d39 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -602,3 +602,14 @@ func (s *testSuite) TestBuildProjBelowAgg(c *C) { "3 3 15 6,6,6 7", "4 3 18 7,7,7 8")) } + +func (s *testSuite) TestFirstRowEnum(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec(`use test;`) + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(a enum('a', 'b'));`) + tk.MustExec(`insert into t values('a');`) + tk.MustQuery(`select a from t group by a;`).Check(testkit.Rows( + `a`, + )) +} diff --git a/plan/logical_plan_builder.go b/plan/logical_plan_builder.go index 1d1711fd331ed..bb1888d6c85be 100644 --- a/plan/logical_plan_builder.go +++ b/plan/logical_plan_builder.go @@ -115,7 +115,9 @@ func (b *planBuilder) buildAggregation(p LogicalPlan, aggFuncList []*ast.Aggrega for _, col := range p.Schema().Columns { newFunc := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, newFunc) - schema4Agg.Append(col) + newCol, _ := col.Clone().(*expression.Column) + newCol.RetType = newFunc.RetTp + schema4Agg.Append(newCol) } plan4Agg.SetChildren(p) plan4Agg.GroupByItems = gbyItems diff --git a/plan/rule_aggregation_push_down.go b/plan/rule_aggregation_push_down.go index 3089f1d5c9cfd..0cd8d2c0fbc94 100644 --- a/plan/rule_aggregation_push_down.go +++ b/plan/rule_aggregation_push_down.go @@ -257,8 +257,10 @@ func (a *aggregationOptimizer) makeNewAgg(ctx sessionctx.Context, aggFuncs []*ag } for _, gbyCol := range gbyCols { firstRow := aggregation.NewAggFuncDesc(agg.ctx, ast.AggFuncFirstRow, []expression.Expression{gbyCol}, false) + newCol, _ := gbyCol.Clone().(*expression.Column) + newCol.RetType = firstRow.RetTp newAggFuncDescs = append(newAggFuncDescs, firstRow) - schema.Append(gbyCol) + schema.Append(newCol) } agg.AggFuncs = newAggFuncDescs agg.SetSchema(schema) diff --git a/plan/rule_decorrelate.go b/plan/rule_decorrelate.go index 9ea443146a67f..70dedc22d5d08 100644 --- a/plan/rule_decorrelate.go +++ b/plan/rule_decorrelate.go @@ -182,13 +182,19 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { agg.SetSchema(apply.Schema()) agg.GroupByItems = expression.Column2Exprs(outerPlan.Schema().Keys[0]) newAggFuncs := make([]*aggregation.AggFuncDesc, 0, apply.Schema().Len()) - for _, col := range outerPlan.Schema().Columns { + + outerColsInSchema := make([]*expression.Column, 0, outerPlan.Schema().Len()) + for i, col := range outerPlan.Schema().Columns { first := aggregation.NewAggFuncDesc(agg.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) newAggFuncs = append(newAggFuncs, first) + + outerCol, _ := outerPlan.Schema().Columns[i].Clone().(*expression.Column) + outerCol.RetType = first.RetTp + outerColsInSchema = append(outerColsInSchema, outerCol) } newAggFuncs = append(newAggFuncs, agg.AggFuncs...) agg.AggFuncs = newAggFuncs - apply.SetSchema(expression.MergeSchema(outerPlan.Schema(), innerPlan.Schema())) + apply.SetSchema(expression.MergeSchema(expression.NewSchema(outerColsInSchema...), innerPlan.Schema())) np, err := s.optimize(p) if err != nil { return nil, errors.Trace(err) @@ -229,6 +235,7 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { newFunc := aggregation.NewAggFuncDesc(apply.ctx, ast.AggFuncFirstRow, []expression.Expression{clonedCol}, false) agg.AggFuncs = append(agg.AggFuncs, newFunc) agg.schema.Append(clonedCol.(*expression.Column)) + agg.schema.Columns[agg.schema.Len()-1].RetType = newFunc.RetTp } // If group by cols don't contain the join key, add it into this. if agg.getGbyColIndex(eqCond.GetArgs()[1].(*expression.Column)) == -1 {