diff --git a/plan/expression_rewriter.go b/plan/expression_rewriter.go index 09bcd13dbeafb..6784419e18456 100644 --- a/plan/expression_rewriter.go +++ b/plan/expression_rewriter.go @@ -382,14 +382,15 @@ func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression. plan4Agg.AggFuncs = []*aggregation.AggFuncDesc{funcMaxOrMin} cond := expression.NewFunctionInternal(er.ctx, cmpFunc, types.NewFieldType(mysql.TypeTiny), lexpr, colMaxOrMin.Clone()) - er.buildQuantifierPlan(plan4Agg, cond, rexpr, all) + er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, all) } // buildQuantifierPlan adds extra condition for any / all subquery. -func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, cond, rexpr expression.Expression, all bool) { - funcIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr.Clone()) +func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, cond, lexpr, rexpr expression.Expression, all bool) { + innerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr) + outerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), lexpr) - funcSum := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{funcIsNull}, false) + funcSum := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{innerIsNull}, false) colSum := &expression.Column{ ColName: model.NewCIStr("agg_col_sum"), FromID: plan4Agg.id, @@ -398,30 +399,39 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, } plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcSum) plan4Agg.schema.Append(colSum) + innerHasNull := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero) + + // Build `count(1)` aggregation to check if subquery is empty. + funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{expression.One}, false) + colCount := &expression.Column{ + ColName: model.NewCIStr("agg_col_cnt"), + FromID: plan4Agg.id, + Position: plan4Agg.schema.Len(), + RetType: funcCount.RetTp, + } + plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount) + plan4Agg.schema.Append(colCount) if all { - funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{funcIsNull.Clone()}, false) - colCount := &expression.Column{ - ColName: model.NewCIStr("agg_col_cnt"), - FromID: plan4Agg.id, - Position: plan4Agg.schema.Len(), - RetType: funcCount.RetTp, - } - plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount) - plan4Agg.schema.Append(colCount) // All of the inner record set should not contain null value. So for t.id < all(select s.id from s), it - // should be rewrote to t.id < min(s.id) and if(sum(s.id is null) = 0, true, null). - hasNotNull := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colSum.Clone(), expression.Zero) - nullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), hasNotNull, expression.One, expression.Null) - cond = expression.ComposeCNFCondition(er.ctx, cond, nullChecker) - // If the set is empty, it should always return true. - checkEmpty := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount.Clone(), expression.Zero) - cond = expression.ComposeDNFCondition(er.ctx, cond, checkEmpty) + // should be rewrote to t.id < min(s.id) and if(sum(s.id is null) != 0, null, true). + innerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.One) + cond = expression.ComposeCNFCondition(er.ctx, cond, innerNullChecker) + // If the subquery is empty, it should always return true. + emptyChecker := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero) + // If outer key is null, and subquery is not empty, it should always return null, even when it is `null = all (1, 2)`. + outerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.Zero) + cond = expression.ComposeDNFCondition(er.ctx, cond, emptyChecker, outerNullChecker) } else { - // For "any" expression, if the record set has null and the cond return false, the result should be NULL. - hasNull := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum.Clone(), expression.Zero) - nullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), hasNull, expression.Null, expression.Zero) - cond = expression.ComposeDNFCondition(er.ctx, cond, nullChecker) + // For "any" expression, if the subquery has null and the cond returns false, the result should be NULL. + // Specifically, `t.id < any (select s.id from s)` would be rewrote to `t.id < max(s.id) or if(sum(s.id is null) != 0, null, false)` + innerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.Zero) + cond = expression.ComposeDNFCondition(er.ctx, cond, innerNullChecker) + // If the subquery is empty, it should always return false. + emptyChecker := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero) + // If outer key is null, and subquery is not empty, it should return null. + outerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.One) + cond = expression.ComposeCNFCondition(er.ctx, cond, emptyChecker, outerNullChecker) } // TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions. @@ -477,7 +487,7 @@ func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np gtFunc := expression.NewFunctionInternal(er.ctx, ast.GT, types.NewFieldType(mysql.TypeTiny), count.Clone(), expression.One) neCond := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol.Clone()) cond := expression.ComposeDNFCondition(er.ctx, gtFunc, neCond) - er.buildQuantifierPlan(plan4Agg, cond, rexpr, false) + er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, false) } // handleEQAll handles the case of = all. For example, if the query is t.id = all (select s.id from s), it will be rewrote to @@ -505,7 +515,7 @@ func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np leFunc := expression.NewFunctionInternal(er.ctx, ast.LE, types.NewFieldType(mysql.TypeTiny), count.Clone(), expression.One) eqCond := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol.Clone()) cond := expression.ComposeCNFCondition(er.ctx, leFunc, eqCond) - er.buildQuantifierPlan(plan4Agg, cond, rexpr, true) + er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, true) } func (er *expressionRewriter) handleExistSubquery(v *ast.ExistsSubqueryExpr) (ast.Node, bool) { diff --git a/plan/expression_rewriter_test.go b/plan/expression_rewriter_test.go new file mode 100644 index 0000000000000..9219b1f8722f3 --- /dev/null +++ b/plan/expression_rewriter_test.go @@ -0,0 +1,119 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plan_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/testleak" +) + +var _ = Suite(&testExpressionRewriterSuite{}) + +type testExpressionRewriterSuite struct { +} + +func (s *testExpressionRewriterSuite) TestCompareSubquery(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists s") + tk.MustExec("create table t(a int, b int)") + tk.MustExec("create table s(a int, b int)") + tk.MustExec("insert into t values(1, null), (2, null)") + + // Test empty checker. + tk.MustQuery("select a != any (select a from s) from t").Check(testkit.Rows( + "0", + "0", + )) + tk.MustQuery("select b != any (select a from s) from t").Check(testkit.Rows( + "0", + "0", + )) + tk.MustQuery("select a = all (select a from s) from t").Check(testkit.Rows( + "1", + "1", + )) + tk.MustQuery("select b = all (select a from s) from t").Check(testkit.Rows( + "1", + "1", + )) + tk.MustQuery("select * from t where a != any (select a from s)").Check(testkit.Rows()) + tk.MustQuery("select * from t where b != any (select a from s)").Check(testkit.Rows()) + tk.MustQuery("select * from t where a = all (select a from s)").Check(testkit.Rows( + "1 ", + "2 ", + )) + tk.MustQuery("select * from t where b = all (select a from s)").Check(testkit.Rows( + "1 ", + "2 ", + )) + // Test outer null checker. + tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows( + "", + "", + )) + tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows( + "", + "", + )) + tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows()) + + tk.MustExec("delete from t where a = 2") + tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows( + "", + )) + tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows( + "", + )) + tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows()) + + // Test inner null checker. + tk.MustExec("insert into t values(null, 1)") + tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows( + "", + "", + )) + tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows( + "", + "", + )) + tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows()) + + tk.MustExec("delete from t where b = 1") + tk.MustExec("insert into t values(null, 2)") + tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows( + "", + "1", + )) + tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows( + "", + "0", + )) + tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows( + " 2", + )) + tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows()) +}