diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 3a78f317768a2..7be3731fbe3fa 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -1811,6 +1811,7 @@ func BuildCastFunction4Union(ctx sessionctx.Context, expr Expression, tp *types. // BuildCastFunction builds a CAST ScalarFunction from the Expression. func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) { + expr = TryPushCastIntoControlFunctionForHybridType(ctx, expr, tp) var fc functionClass switch tp.EvalType() { case types.ETInt: @@ -1985,3 +1986,92 @@ func WrapWithCastAsJSON(ctx sessionctx.Context, expr Expression) Expression { } return BuildCastFunction(ctx, expr, tp) } + +// TryPushCastIntoControlFunctionForHybridType try to push cast into control function for Hybrid Type. +// If necessary, it will rebuild control function using changed args. +// When a hybrid type is the output of a control function, the result may be as a numeric type to subsequent calculation +// We should perform the `Cast` operation early to avoid using the wrong type for calculation +// For example, the condition `if(1, e, 'a') = 1`, `if` function will output `e` and compare with `1`. +// If the evaltype is ETString, it will get wrong result. So we can rewrite the condition to +// `IfInt(1, cast(e as int), cast('a' as int)) = 1` to get the correct result. +func TryPushCastIntoControlFunctionForHybridType(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) { + sf, ok := expr.(*ScalarFunction) + if !ok { + return expr + } + + var wrapCastFunc func(ctx sessionctx.Context, expr Expression) Expression + switch tp.EvalType() { + case types.ETInt: + wrapCastFunc = WrapWithCastAsInt + case types.ETReal: + wrapCastFunc = WrapWithCastAsReal + default: + return expr + } + + isHybrid := func(ft *types.FieldType) bool { + // todo: compatible with mysql control function using bit type. issue 24725 + return ft.Hybrid() && ft.Tp != mysql.TypeBit + } + + args := sf.GetArgs() + switch sf.FuncName.L { + case ast.If: + if isHybrid(args[1].GetType()) || isHybrid(args[2].GetType()) { + args[1] = wrapCastFunc(ctx, args[1]) + args[2] = wrapCastFunc(ctx, args[2]) + f, err := funcs[ast.If].getFunction(ctx, args) + if err != nil { + return expr + } + sf.RetType, sf.Function = f.getRetTp(), f + return sf + } + case ast.Case: + hasHybrid := false + for i := 0; i < len(args)-1; i += 2 { + hasHybrid = hasHybrid || isHybrid(args[i+1].GetType()) + } + if len(args)%2 == 1 { + hasHybrid = hasHybrid || isHybrid(args[len(args)-1].GetType()) + } + if !hasHybrid { + return expr + } + + for i := 0; i < len(args)-1; i += 2 { + args[i+1] = wrapCastFunc(ctx, args[i+1]) + } + if len(args)%2 == 1 { + args[len(args)-1] = wrapCastFunc(ctx, args[len(args)-1]) + } + f, err := funcs[ast.Case].getFunction(ctx, args) + if err != nil { + return expr + } + sf.RetType, sf.Function = f.getRetTp(), f + return sf + case ast.Elt: + hasHybrid := false + for i := 1; i < len(args); i++ { + hasHybrid = hasHybrid || isHybrid(args[i].GetType()) + } + if !hasHybrid { + return expr + } + + for i := 1; i < len(args); i++ { + args[i] = wrapCastFunc(ctx, args[i]) + } + f, err := funcs[ast.Elt].getFunction(ctx, args) + if err != nil { + return expr + } + sf.RetType, sf.Function = f.getRetTp(), f + return sf + default: + return expr + } + return expr +} diff --git a/expression/integration_test.go b/expression/integration_test.go index 7031fc13287a6..57eefcdc72804 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -9364,6 +9364,79 @@ func (s *testIntegrationSuite) TestIssue29244(c *C) { tk.MustQuery("select microsecond(a) from t;").Check(testkit.Rows("123500", "123500")) } +func (s *testIntegrationSuite) TestControlFunctionWithEnumOrSet(c *C) { + defer s.cleanEnv(c) + + // issue 23114 + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists e;") + tk.MustExec("create table e(e enum('c', 'b', 'a'));") + tk.MustExec("insert into e values ('a'),('b'),('a'),('b');") + tk.MustQuery("select e from e where if(e>1, e, e);").Sort().Check( + testkit.Rows("a", "a", "b", "b")) + tk.MustQuery("select e from e where case e when 1 then e else e end;").Sort().Check( + testkit.Rows("a", "a", "b", "b")) + tk.MustQuery("select e from e where case 1 when e then e end;").Check(testkit.Rows()) + + tk.MustQuery("select if(e>1,e,e)='a' from e").Sort().Check( + testkit.Rows("0", "0", "1", "1")) + tk.MustQuery("select if(e>1,e,e)=1 from e").Sort().Check( + testkit.Rows("0", "0", "0", "0")) + // if and if + tk.MustQuery("select if(e>2,e,e) and if(e<=2,e,e) from e;").Sort().Check( + testkit.Rows("1", "1", "1", "1")) + tk.MustQuery("select if(e>2,e,e) and (if(e<3,0,e) or if(e>=2,0,e)) from e;").Sort().Check( + testkit.Rows("0", "0", "1", "1")) + tk.MustQuery("select * from e where if(e>2,e,e) and if(e<=2,e,e);").Sort().Check( + testkit.Rows("a", "a", "b", "b")) + tk.MustQuery("select * from e where if(e>2,e,e) and (if(e<3,0,e) or if(e>=2,0,e));").Sort().Check( + testkit.Rows("a", "a")) + + // issue 24494 + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a int,b enum(\"b\",\"y\",\"1\"));") + tk.MustExec("insert into t values(0,\"y\"),(1,\"b\"),(null,null),(2,\"1\");") + tk.MustQuery("SELECT count(*) FROM t where if(a,b ,null);").Check(testkit.Rows("2")) + + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a int,b enum(\"b\"),c enum(\"c\"));") + tk.MustExec("insert into t values(1,1,1),(2,1,1),(1,1,1),(2,1,1);") + tk.MustQuery("select a from t where if(a=1,b,c)=\"b\";").Check(testkit.Rows("1", "1")) + tk.MustQuery("select a from t where if(a=1,b,c)=\"c\";").Check(testkit.Rows("2", "2")) + tk.MustQuery("select a from t where if(a=1,b,c)=1;").Sort().Check(testkit.Rows("1", "1", "2", "2")) + tk.MustQuery("select a from t where if(a=1,b,c);").Sort().Check(testkit.Rows("1", "1", "2", "2")) + + tk.MustExec("drop table if exists e;") + tk.MustExec("create table e(e enum('c', 'b', 'a'));") + tk.MustExec("insert into e values(3)") + tk.MustQuery("select elt(1,e) = 'a' from e").Check(testkit.Rows("1")) + tk.MustQuery("select elt(1,e) = 3 from e").Check(testkit.Rows("1")) + tk.MustQuery("select e from e where elt(1,e)").Check(testkit.Rows("a")) + + // test set type + tk.MustExec("drop table if exists s;") + tk.MustExec("create table s(s set('c', 'b', 'a'));") + tk.MustExec("insert into s values ('a'),('b'),('a'),('b');") + tk.MustQuery("select s from s where if(s>1, s, s);").Sort().Check( + testkit.Rows("a", "a", "b", "b")) + tk.MustQuery("select s from s where case s when 1 then s else s end;").Sort().Check( + testkit.Rows("a", "a", "b", "b")) + tk.MustQuery("select s from s where case 1 when s then s end;").Check(testkit.Rows()) + + tk.MustQuery("select if(s>1,s,s)='a' from s").Sort().Check( + testkit.Rows("0", "0", "1", "1")) + tk.MustQuery("select if(s>1,s,s)=4 from s").Sort().Check( + testkit.Rows("0", "0", "1", "1")) + + tk.MustExec("drop table if exists s;") + tk.MustExec("create table s(s set('c', 'b', 'a'));") + tk.MustExec("insert into s values('a')") + tk.MustQuery("select elt(1,s) = 'a' from s").Check(testkit.Rows("1")) + tk.MustQuery("select elt(1,s) = 4 from s").Check(testkit.Rows("1")) + tk.MustQuery("select s from s where elt(1,s)").Check(testkit.Rows("a")) +} + func (s *testIntegrationSuite) TestIssue29513(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 0fd2db36fd10c..a95cba8fe2dce 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -977,6 +977,19 @@ func (b *PlanBuilder) buildSelection(ctx context.Context, p LogicalPlan, where a if len(cnfExpres) == 0 { return p, nil } + // check expr field types. + for i, expr := range cnfExpres { + if expr.GetType().EvalType() == types.ETString { + tp := &types.FieldType{ + Tp: mysql.TypeDouble, + Flag: expr.GetType().Flag, + Flen: mysql.MaxRealWidth, + Decimal: types.UnspecifiedLength, + } + types.SetBinChsClnFlag(tp) + cnfExpres[i] = expression.TryPushCastIntoControlFunctionForHybridType(b.ctx, expr, tp) + } + } selection.Conditions = cnfExpres selection.SetChildren(p) return selection, nil