Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression, planner: push cast down to control function with enum type. (#24542) #30857

Merged
merged 6 commits into from
Dec 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -1811,6 +1811,7 @@ func BuildCastFunction4Union(ctx sessionctx.Context, expr Expression, tp *types.

// BuildCastFunction builds a CAST ScalarFunction from the Expression.
func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) {
expr = TryPushCastIntoControlFunctionForHybridType(ctx, expr, tp)
var fc functionClass
switch tp.EvalType() {
case types.ETInt:
Expand Down Expand Up @@ -1985,3 +1986,92 @@ func WrapWithCastAsJSON(ctx sessionctx.Context, expr Expression) Expression {
}
return BuildCastFunction(ctx, expr, tp)
}

// TryPushCastIntoControlFunctionForHybridType try to push cast into control function for Hybrid Type.
// If necessary, it will rebuild control function using changed args.
// When a hybrid type is the output of a control function, the result may be as a numeric type to subsequent calculation
// We should perform the `Cast` operation early to avoid using the wrong type for calculation
// For example, the condition `if(1, e, 'a') = 1`, `if` function will output `e` and compare with `1`.
// If the evaltype is ETString, it will get wrong result. So we can rewrite the condition to
// `IfInt(1, cast(e as int), cast('a' as int)) = 1` to get the correct result.
func TryPushCastIntoControlFunctionForHybridType(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) {
sf, ok := expr.(*ScalarFunction)
if !ok {
return expr
}

var wrapCastFunc func(ctx sessionctx.Context, expr Expression) Expression
switch tp.EvalType() {
case types.ETInt:
wrapCastFunc = WrapWithCastAsInt
case types.ETReal:
wrapCastFunc = WrapWithCastAsReal
default:
return expr
}

isHybrid := func(ft *types.FieldType) bool {
// todo: compatible with mysql control function using bit type. issue 24725
return ft.Hybrid() && ft.Tp != mysql.TypeBit
}

args := sf.GetArgs()
switch sf.FuncName.L {
case ast.If:
if isHybrid(args[1].GetType()) || isHybrid(args[2].GetType()) {
args[1] = wrapCastFunc(ctx, args[1])
args[2] = wrapCastFunc(ctx, args[2])
f, err := funcs[ast.If].getFunction(ctx, args)
if err != nil {
return expr
}
sf.RetType, sf.Function = f.getRetTp(), f
return sf
}
case ast.Case:
hasHybrid := false
for i := 0; i < len(args)-1; i += 2 {
hasHybrid = hasHybrid || isHybrid(args[i+1].GetType())
}
if len(args)%2 == 1 {
hasHybrid = hasHybrid || isHybrid(args[len(args)-1].GetType())
}
if !hasHybrid {
return expr
}

for i := 0; i < len(args)-1; i += 2 {
args[i+1] = wrapCastFunc(ctx, args[i+1])
}
if len(args)%2 == 1 {
args[len(args)-1] = wrapCastFunc(ctx, args[len(args)-1])
}
f, err := funcs[ast.Case].getFunction(ctx, args)
if err != nil {
return expr
}
sf.RetType, sf.Function = f.getRetTp(), f
return sf
case ast.Elt:
hasHybrid := false
for i := 1; i < len(args); i++ {
hasHybrid = hasHybrid || isHybrid(args[i].GetType())
}
if !hasHybrid {
return expr
}

for i := 1; i < len(args); i++ {
args[i] = wrapCastFunc(ctx, args[i])
}
f, err := funcs[ast.Elt].getFunction(ctx, args)
if err != nil {
return expr
}
sf.RetType, sf.Function = f.getRetTp(), f
return sf
default:
return expr
}
return expr
}
73 changes: 73 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9364,6 +9364,79 @@ func (s *testIntegrationSuite) TestIssue29244(c *C) {
tk.MustQuery("select microsecond(a) from t;").Check(testkit.Rows("123500", "123500"))
}

func (s *testIntegrationSuite) TestControlFunctionWithEnumOrSet(c *C) {
defer s.cleanEnv(c)

// issue 23114
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists e;")
tk.MustExec("create table e(e enum('c', 'b', 'a'));")
tk.MustExec("insert into e values ('a'),('b'),('a'),('b');")
tk.MustQuery("select e from e where if(e>1, e, e);").Sort().Check(
testkit.Rows("a", "a", "b", "b"))
tk.MustQuery("select e from e where case e when 1 then e else e end;").Sort().Check(
testkit.Rows("a", "a", "b", "b"))
tk.MustQuery("select e from e where case 1 when e then e end;").Check(testkit.Rows())

tk.MustQuery("select if(e>1,e,e)='a' from e").Sort().Check(
testkit.Rows("0", "0", "1", "1"))
tk.MustQuery("select if(e>1,e,e)=1 from e").Sort().Check(
testkit.Rows("0", "0", "0", "0"))
// if and if
tk.MustQuery("select if(e>2,e,e) and if(e<=2,e,e) from e;").Sort().Check(
testkit.Rows("1", "1", "1", "1"))
tk.MustQuery("select if(e>2,e,e) and (if(e<3,0,e) or if(e>=2,0,e)) from e;").Sort().Check(
testkit.Rows("0", "0", "1", "1"))
tk.MustQuery("select * from e where if(e>2,e,e) and if(e<=2,e,e);").Sort().Check(
testkit.Rows("a", "a", "b", "b"))
tk.MustQuery("select * from e where if(e>2,e,e) and (if(e<3,0,e) or if(e>=2,0,e));").Sort().Check(
testkit.Rows("a", "a"))

// issue 24494
tk.MustExec("drop table if exists t;")
tk.MustExec("create table t(a int,b enum(\"b\",\"y\",\"1\"));")
tk.MustExec("insert into t values(0,\"y\"),(1,\"b\"),(null,null),(2,\"1\");")
tk.MustQuery("SELECT count(*) FROM t where if(a,b ,null);").Check(testkit.Rows("2"))

tk.MustExec("drop table if exists t;")
tk.MustExec("create table t(a int,b enum(\"b\"),c enum(\"c\"));")
tk.MustExec("insert into t values(1,1,1),(2,1,1),(1,1,1),(2,1,1);")
tk.MustQuery("select a from t where if(a=1,b,c)=\"b\";").Check(testkit.Rows("1", "1"))
tk.MustQuery("select a from t where if(a=1,b,c)=\"c\";").Check(testkit.Rows("2", "2"))
tk.MustQuery("select a from t where if(a=1,b,c)=1;").Sort().Check(testkit.Rows("1", "1", "2", "2"))
tk.MustQuery("select a from t where if(a=1,b,c);").Sort().Check(testkit.Rows("1", "1", "2", "2"))

tk.MustExec("drop table if exists e;")
tk.MustExec("create table e(e enum('c', 'b', 'a'));")
tk.MustExec("insert into e values(3)")
tk.MustQuery("select elt(1,e) = 'a' from e").Check(testkit.Rows("1"))
tk.MustQuery("select elt(1,e) = 3 from e").Check(testkit.Rows("1"))
tk.MustQuery("select e from e where elt(1,e)").Check(testkit.Rows("a"))

// test set type
tk.MustExec("drop table if exists s;")
tk.MustExec("create table s(s set('c', 'b', 'a'));")
tk.MustExec("insert into s values ('a'),('b'),('a'),('b');")
tk.MustQuery("select s from s where if(s>1, s, s);").Sort().Check(
testkit.Rows("a", "a", "b", "b"))
tk.MustQuery("select s from s where case s when 1 then s else s end;").Sort().Check(
testkit.Rows("a", "a", "b", "b"))
tk.MustQuery("select s from s where case 1 when s then s end;").Check(testkit.Rows())

tk.MustQuery("select if(s>1,s,s)='a' from s").Sort().Check(
testkit.Rows("0", "0", "1", "1"))
tk.MustQuery("select if(s>1,s,s)=4 from s").Sort().Check(
testkit.Rows("0", "0", "1", "1"))

tk.MustExec("drop table if exists s;")
tk.MustExec("create table s(s set('c', 'b', 'a'));")
tk.MustExec("insert into s values('a')")
tk.MustQuery("select elt(1,s) = 'a' from s").Check(testkit.Rows("1"))
tk.MustQuery("select elt(1,s) = 4 from s").Check(testkit.Rows("1"))
tk.MustQuery("select s from s where elt(1,s)").Check(testkit.Rows("a"))
}

func (s *testIntegrationSuite) TestIssue29513(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
13 changes: 13 additions & 0 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,19 @@ func (b *PlanBuilder) buildSelection(ctx context.Context, p LogicalPlan, where a
if len(cnfExpres) == 0 {
return p, nil
}
// check expr field types.
for i, expr := range cnfExpres {
if expr.GetType().EvalType() == types.ETString {
tp := &types.FieldType{
Tp: mysql.TypeDouble,
Flag: expr.GetType().Flag,
Flen: mysql.MaxRealWidth,
Decimal: types.UnspecifiedLength,
}
types.SetBinChsClnFlag(tp)
cnfExpres[i] = expression.TryPushCastIntoControlFunctionForHybridType(b.ctx, expr, tp)
}
}
selection.Conditions = cnfExpres
selection.SetChildren(p)
return selection, nil
Expand Down