From 776c90fa79ac8a6f6dd42635f57a8116759e6be5 Mon Sep 17 00:00:00 2001 From: Kenan Yao Date: Mon, 21 Jan 2019 16:17:48 +0800 Subject: [PATCH] planner: check null and empty for `!= any(subq)` and `= all(subq)` (#9106) --- cmd/explaintest/r/select.result | 22 ++++++ cmd/explaintest/t/select.test | 6 ++ planner/core/expression_rewriter.go | 60 ++++++++------- planner/core/expression_rewriter_test.go | 94 ++++++++++++++++++++++++ 4 files changed, 157 insertions(+), 25 deletions(-) diff --git a/cmd/explaintest/r/select.result b/cmd/explaintest/r/select.result index 069f8ad4a435d..03a5ceb0a6289 100644 --- a/cmd/explaintest/r/select.result +++ b/cmd/explaintest/r/select.result @@ -358,3 +358,25 @@ Union_7 20000.00 root │ └─TableScan_8 10000.00 cop table:th, partition:p1, range:[-inf,+inf], keep order:false, stats:pseudo └─TableReader_11 10000.00 root data:TableScan_10 └─TableScan_10 10000.00 cop table:th, partition:p2, range:[-inf,+inf], keep order:false, stats:pseudo +drop table if exists t; +create table t(a int, b int); +explain select a != any (select a from t t2) from t t1; +id count task operator info +Projection_9 10000.00 root and(or(or(gt(col_count, 1), ne(t1.a, col_firstrow)), if(ne(agg_col_sum, 0), NULL, 0)), and(ne(agg_col_cnt, 0), if(isnull(t1.a), NULL, 1))) +└─HashLeftJoin_10 10000.00 root inner join, inner:StreamAgg_17 + ├─TableReader_13 10000.00 root data:TableScan_12 + │ └─TableScan_12 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo + └─StreamAgg_17 1.00 root funcs:firstrow(col_0), count(distinct col_1), sum(col_2), count(1) + └─Projection_27 10000.00 root t2.a, t2.a, cast(isnull(t2.a)) + └─TableReader_24 10000.00 root data:TableScan_23 + └─TableScan_23 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo +explain select a = all (select a from t t2) from t t1; +id count task operator info +Projection_9 10000.00 root or(and(and(le(col_count, 1), eq(t1.a, col_firstrow)), if(ne(agg_col_sum, 0), NULL, 1)), or(eq(agg_col_cnt, 0), if(isnull(t1.a), NULL, 0))) +└─HashLeftJoin_10 10000.00 root inner join, inner:StreamAgg_17 + ├─TableReader_13 10000.00 root data:TableScan_12 + │ └─TableScan_12 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo + └─StreamAgg_17 1.00 root funcs:firstrow(col_0), count(distinct col_1), sum(col_2), count(1) + └─Projection_27 10000.00 root t2.a, t2.a, cast(isnull(t2.a)) + └─TableReader_24 10000.00 root data:TableScan_23 + └─TableScan_23 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo diff --git a/cmd/explaintest/t/select.test b/cmd/explaintest/t/select.test index 5070a52061206..80143354ff88e 100644 --- a/cmd/explaintest/t/select.test +++ b/cmd/explaintest/t/select.test @@ -175,3 +175,9 @@ insert into th values (-1,-1),(-2,-2),(-3,-3),(-4,-4),(-5,-5),(-6,-6),(-7,-7),(- desc select * from th where a=-2; desc select * from th; desc select * from th partition (p2,p1); + +# test != any(subq) and = all(subq) +drop table if exists t; +create table t(a int, b int); +explain select a != any (select a from t t2) from t t1; +explain select a = all (select a from t t2) from t t1; diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 6b7667cfe1577..09f66f4efd434 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -429,14 +429,15 @@ func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression. plan4Agg.AggFuncs = []*aggregation.AggFuncDesc{funcMaxOrMin} cond := expression.NewFunctionInternal(er.ctx, cmpFunc, types.NewFieldType(mysql.TypeTiny), lexpr, colMaxOrMin) - er.buildQuantifierPlan(plan4Agg, cond, rexpr, all) + er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, all) } // buildQuantifierPlan adds extra condition for any / all subquery. -func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, cond, rexpr expression.Expression, all bool) { - funcIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr) +func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, cond, lexpr, rexpr expression.Expression, all bool) { + innerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr) + outerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), lexpr) - funcSum := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{funcIsNull}, false) + funcSum := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{innerIsNull}, false) colSum := &expression.Column{ ColName: model.NewCIStr("agg_col_sum"), UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), @@ -444,29 +445,38 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, } plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcSum) plan4Agg.schema.Append(colSum) + innerHasNull := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero) + + // Build `count(1)` aggregation to check if subquery is empty. + funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{expression.One}, false) + colCount := &expression.Column{ + ColName: model.NewCIStr("agg_col_cnt"), + UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: funcCount.RetTp, + } + plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount) + plan4Agg.schema.Append(colCount) if all { - funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{funcIsNull}, false) - colCount := &expression.Column{ - ColName: model.NewCIStr("agg_col_cnt"), - UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: funcCount.RetTp, - } - plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount) - plan4Agg.schema.Append(colCount) // All of the inner record set should not contain null value. So for t.id < all(select s.id from s), it - // should be rewrote to t.id < min(s.id) and if(sum(s.id is null) = 0, true, null). - hasNotNull := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero) - nullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), hasNotNull, expression.One, expression.Null) - cond = expression.ComposeCNFCondition(er.ctx, cond, nullChecker) - // If the set is empty, it should always return true. - checkEmpty := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero) - cond = expression.ComposeDNFCondition(er.ctx, cond, checkEmpty) + // should be rewrote to t.id < min(s.id) and if(sum(s.id is null) != 0, null, true). + innerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.One) + cond = expression.ComposeCNFCondition(er.ctx, cond, innerNullChecker) + // If the subquery is empty, it should always return true. + emptyChecker := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero) + // If outer key is null, and subquery is not empty, it should always return null, even when it is `null = all (1, 2)`. + outerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.Zero) + cond = expression.ComposeDNFCondition(er.ctx, cond, emptyChecker, outerNullChecker) } else { - // For "any" expression, if the record set has null and the cond return false, the result should be NULL. - hasNull := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero) - nullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), hasNull, expression.Null, expression.Zero) - cond = expression.ComposeDNFCondition(er.ctx, cond, nullChecker) + // For "any" expression, if the subquery has null and the cond returns false, the result should be NULL. + // Specifically, `t.id < any (select s.id from s)` would be rewrote to `t.id < max(s.id) or if(sum(s.id is null) != 0, null, false)` + innerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.Zero) + cond = expression.ComposeDNFCondition(er.ctx, cond, innerNullChecker) + // If the subquery is empty, it should always return false. + emptyChecker := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero) + // If outer key is null, and subquery is not empty, it should return null. + outerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.One) + cond = expression.ComposeCNFCondition(er.ctx, cond, emptyChecker, outerNullChecker) } // TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions. @@ -519,7 +529,7 @@ func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np gtFunc := expression.NewFunctionInternal(er.ctx, ast.GT, types.NewFieldType(mysql.TypeTiny), count, expression.One) neCond := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol) cond := expression.ComposeDNFCondition(er.ctx, gtFunc, neCond) - er.buildQuantifierPlan(plan4Agg, cond, rexpr, false) + er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, false) } // handleEQAll handles the case of = all. For example, if the query is t.id = all (select s.id from s), it will be rewrote to @@ -545,7 +555,7 @@ func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np leFunc := expression.NewFunctionInternal(er.ctx, ast.LE, types.NewFieldType(mysql.TypeTiny), count, expression.One) eqCond := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol) cond := expression.ComposeCNFCondition(er.ctx, leFunc, eqCond) - er.buildQuantifierPlan(plan4Agg, cond, rexpr, true) + er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, true) } func (er *expressionRewriter) handleExistSubquery(v *ast.ExistsSubqueryExpr) (ast.Node, bool) { diff --git a/planner/core/expression_rewriter_test.go b/planner/core/expression_rewriter_test.go index 6089ac43f20ea..805c81e7d3449 100644 --- a/planner/core/expression_rewriter_test.go +++ b/planner/core/expression_rewriter_test.go @@ -129,3 +129,97 @@ func (s *testExpressionRewriterSuite) TestDefaultFunction(c *C) { tk.MustExec("update t1 set c = c + default(c)") tk.MustQuery("select c from t1").Check(testkit.Rows("11")) } + +func (s *testExpressionRewriterSuite) TestCompareSubquery(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists s") + tk.MustExec("create table t(a int, b int)") + tk.MustExec("create table s(a int, b int)") + tk.MustExec("insert into t values(1, null), (2, null)") + + // Test empty checker. + tk.MustQuery("select a != any (select a from s) from t").Check(testkit.Rows( + "0", + "0", + )) + tk.MustQuery("select b != any (select a from s) from t").Check(testkit.Rows( + "0", + "0", + )) + tk.MustQuery("select a = all (select a from s) from t").Check(testkit.Rows( + "1", + "1", + )) + tk.MustQuery("select b = all (select a from s) from t").Check(testkit.Rows( + "1", + "1", + )) + tk.MustQuery("select * from t where a != any (select a from s)").Check(testkit.Rows()) + tk.MustQuery("select * from t where b != any (select a from s)").Check(testkit.Rows()) + tk.MustQuery("select * from t where a = all (select a from s)").Check(testkit.Rows( + "1 ", + "2 ", + )) + tk.MustQuery("select * from t where b = all (select a from s)").Check(testkit.Rows( + "1 ", + "2 ", + )) + // Test outer null checker. + tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows( + "", + "", + )) + tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows( + "", + "", + )) + tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows()) + + tk.MustExec("delete from t where a = 2") + tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows( + "", + )) + tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows( + "", + )) + tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows()) + + // Test inner null checker. + tk.MustExec("insert into t values(null, 1)") + tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows( + "", + "", + )) + tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows( + "", + "", + )) + tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows()) + + tk.MustExec("delete from t where b = 1") + tk.MustExec("insert into t values(null, 2)") + tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows( + "", + "1", + )) + tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows( + "", + "0", + )) + tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows( + " 2", + )) + tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows()) +}