Skip to content

Commit

Permalink
executor: fix aggregating enum zero value gets different results from…
Browse files Browse the repository at this point in the history
… mysql (#36208)

close #26885
  • Loading branch information
ywqzzy authored Jul 22, 2022
1 parent 0b1ad27 commit 065563a
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
14 changes: 14 additions & 0 deletions executor/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -565,11 +565,24 @@ func getGroupKey(ctx sessionctx.Context, input *chunk.Chunk, groupKey [][]byte,

for _, item := range groupByItems {
tp := item.GetType()

buf, err := expression.GetColumn(tp.EvalType(), numRows)
if err != nil {
return nil, err
}

// In strict sql mode like ‘STRICT_TRANS_TABLES’,can not insert an invalid enum value like 0.
// While in sql mode like '', can insert an invalid enum value like 0,
// then the enum value 0 will have the enum name '', which maybe conflict with user defined enum ''.
// Ref to issue #26885.
// This check is used to handle invalid enum name same with user defined enum name.
// Use enum value as groupKey instead of enum name.
if item.GetType().GetType() == mysql.TypeEnum {
newTp := *tp
newTp.AddFlag(mysql.EnumSetAsIntFlag)
tp = &newTp
}

if err := expression.EvalExpr(ctx, item, tp.EvalType(), input, buf); err != nil {
expression.PutColumn(buf)
return nil, err
Expand All @@ -580,6 +593,7 @@ func getGroupKey(ctx sessionctx.Context, input *chunk.Chunk, groupKey [][]byte,
newTp.SetFlen(0)
tp = &newTp
}

groupKey, err = codec.HashGroupKey(ctx.GetSessionVars().StmtCtx, input.NumRows(), buf, groupKey, tp)
if err != nil {
expression.PutColumn(buf)
Expand Down
33 changes: 33 additions & 0 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1643,3 +1643,36 @@ func TestIssue27751(t *testing.T) {
tk.MustQuery("select group_concat(nname order by 1 separator '#' ) from t;").Check(testkit.Rows("11#1"))
tk.MustQuery("select group_concat(nname order by 1 desc separator '#' ) from t;").Check(testkit.Rows("33#2"))
}

func TestIssue26885(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)
tk.MustExec(`use test`)
tk.MustExec(`SET sql_mode = 'NO_ENGINE_SUBSTITUTION';`)
tk.MustExec(`DROP TABLE IF EXISTS t1;`)

tk.MustExec("CREATE TABLE t1 (c1 ENUM('a', '', 'b'));")
tk.MustExec("INSERT INTO t1 (c1) VALUES ('b');")
tk.MustExec("INSERT INTO t1 (c1) VALUES ('');")
tk.MustExec("INSERT INTO t1 (c1) VALUES ('a');")
tk.MustExec("INSERT INTO t1 (c1) VALUES ('');")
tk.MustExec("INSERT INTO t1 (c1) VALUES (0);")
tk.MustQuery("select * from t1").Check(testkit.Rows("b", "", "a", "", ""))
tk.MustQuery("select c1 + 0 from t1").Check(testkit.Rows("3", "2", "1", "2", "0"))
tk.MustQuery("SELECT c1 + 0, COUNT(c1) FROM t1 GROUP BY c1 order by c1;").Check(testkit.Rows("0 1", "1 1", "2 2", "3 1"))

tk.MustExec("alter table t1 add index idx(c1); ")
tk.MustQuery("select c1 + 0 from t1").Check(testkit.Rows("3", "2", "1", "2", "0"))
tk.MustQuery("SELECT c1 + 0, COUNT(c1) FROM t1 GROUP BY c1 order by c1;").Check(testkit.Rows("0 1", "1 1", "2 2", "3 1"))

tk.MustExec(`DROP TABLE IF EXISTS t1;`)
tk.MustExec("CREATE TABLE t1 (c1 ENUM('a', 'b', 'c'));")
tk.MustExec("INSERT INTO t1 (c1) VALUES ('b');")
tk.MustExec("INSERT INTO t1 (c1) VALUES ('a');")
tk.MustExec("INSERT INTO t1 (c1) VALUES ('b');")
tk.MustExec("INSERT INTO t1 (c1) VALUES ('c');")
tk.MustExec("INSERT INTO t1 (c1) VALUES (0);")
tk.MustQuery("select * from t1").Check(testkit.Rows("b", "a", "b", "c", ""))
tk.MustQuery("SELECT c1 + 0, COUNT(c1) FROM t1 GROUP BY c1 order by c1;").Check(testkit.Rows("0 1", "1 1", "2 2", "3 1"))
}

0 comments on commit 065563a

Please sign in to comment.