diff --git a/cmd/explaintest/r/select.result b/cmd/explaintest/r/select.result index 9f3926269dff7..ef3f1be3fd2b9 100644 --- a/cmd/explaintest/r/select.result +++ b/cmd/explaintest/r/select.result @@ -332,3 +332,25 @@ desc select sysdate(), sleep(1), sysdate(); id count task operator info Projection_3 1.00 root sysdate(), sleep(1), sysdate() └─TableDual_4 1.00 root rows:1 +drop table if exists t; +create table t(a int, b int); +explain select a != any (select a from t t2) from t t1; +id count task operator info +Projection_9 10000.00 root and(or(or(gt(col_count, 1), ne(t1.a, col_firstrow)), if(ne(agg_col_sum, 0), NULL, 0)), and(ne(agg_col_cnt, 0), if(isnull(t1.a), NULL, 1))) +└─HashLeftJoin_10 10000.00 root inner join, inner:StreamAgg_17 + ├─TableReader_13 10000.00 root data:TableScan_12 + │ └─TableScan_12 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo + └─StreamAgg_17 1.00 root funcs:firstrow(col_0), count(distinct col_1), sum(col_2), count(1) + └─Projection_27 10000.00 root t2.a, t2.a, cast(isnull(t2.a)) + └─TableReader_24 10000.00 root data:TableScan_23 + └─TableScan_23 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo +explain select a = all (select a from t t2) from t t1; +id count task operator info +Projection_9 10000.00 root or(and(and(le(col_count, 1), eq(t1.a, col_firstrow)), if(ne(agg_col_sum, 0), NULL, 1)), or(eq(agg_col_cnt, 0), if(isnull(t1.a), NULL, 0))) +└─HashLeftJoin_10 10000.00 root inner join, inner:StreamAgg_17 + ├─TableReader_13 10000.00 root data:TableScan_12 + │ └─TableScan_12 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo + └─StreamAgg_17 1.00 root funcs:firstrow(col_0), count(distinct col_1), sum(col_2), count(1) + └─Projection_27 10000.00 root t2.a, t2.a, cast(isnull(t2.a)) + └─TableReader_24 10000.00 root data:TableScan_23 + └─TableScan_23 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo diff --git a/cmd/explaintest/t/select.test b/cmd/explaintest/t/select.test index d062a7e047a4d..68abcfe508f41 100644 --- a/cmd/explaintest/t/select.test +++ b/cmd/explaintest/t/select.test @@ -165,3 +165,9 @@ desc select * from t where a = 1; desc select * from t where a = '1'; desc select sysdate(), sleep(1), sysdate(); + +# test != any(subq) and = all(subq) +drop table if exists t; +create table t(a int, b int); +explain select a != any (select a from t t2) from t t1; +explain select a = all (select a from t t2) from t t1; diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 1c5f8f498ae28..6d8d24c6284d6 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -410,14 +410,15 @@ func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression. plan4Agg.AggFuncs = []*aggregation.AggFuncDesc{funcMaxOrMin} cond := expression.NewFunctionInternal(er.ctx, cmpFunc, types.NewFieldType(mysql.TypeTiny), lexpr, colMaxOrMin) - er.buildQuantifierPlan(plan4Agg, cond, rexpr, all) + er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, all) } // buildQuantifierPlan adds extra condition for any / all subquery. -func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, cond, rexpr expression.Expression, all bool) { - funcIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr) +func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, cond, lexpr, rexpr expression.Expression, all bool) { + innerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr) + outerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), lexpr) - funcSum := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{funcIsNull}, false) + funcSum := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{innerIsNull}, false) colSum := &expression.Column{ ColName: model.NewCIStr("agg_col_sum"), UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), @@ -425,29 +426,38 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, } plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcSum) plan4Agg.schema.Append(colSum) + innerHasNull := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero) + + // Build `count(1)` aggregation to check if subquery is empty. + funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{expression.One}, false) + colCount := &expression.Column{ + ColName: model.NewCIStr("agg_col_cnt"), + UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: funcCount.RetTp, + } + plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount) + plan4Agg.schema.Append(colCount) if all { - funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{funcIsNull}, false) - colCount := &expression.Column{ - ColName: model.NewCIStr("agg_col_cnt"), - UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: funcCount.RetTp, - } - plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount) - plan4Agg.schema.Append(colCount) // All of the inner record set should not contain null value. So for t.id < all(select s.id from s), it - // should be rewrote to t.id < min(s.id) and if(sum(s.id is null) = 0, true, null). - hasNotNull := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero) - nullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), hasNotNull, expression.One, expression.Null) - cond = expression.ComposeCNFCondition(er.ctx, cond, nullChecker) - // If the set is empty, it should always return true. - checkEmpty := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero) - cond = expression.ComposeDNFCondition(er.ctx, cond, checkEmpty) + // should be rewrote to t.id < min(s.id) and if(sum(s.id is null) != 0, null, true). + innerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.One) + cond = expression.ComposeCNFCondition(er.ctx, cond, innerNullChecker) + // If the subquery is empty, it should always return true. + emptyChecker := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero) + // If outer key is null, and subquery is not empty, it should always return null, even when it is `null = all (1, 2)`. + outerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.Zero) + cond = expression.ComposeDNFCondition(er.ctx, cond, emptyChecker, outerNullChecker) } else { - // For "any" expression, if the record set has null and the cond return false, the result should be NULL. - hasNull := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero) - nullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), hasNull, expression.Null, expression.Zero) - cond = expression.ComposeDNFCondition(er.ctx, cond, nullChecker) + // For "any" expression, if the subquery has null and the cond returns false, the result should be NULL. + // Specifically, `t.id < any (select s.id from s)` would be rewrote to `t.id < max(s.id) or if(sum(s.id is null) != 0, null, false)` + innerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.Zero) + cond = expression.ComposeDNFCondition(er.ctx, cond, innerNullChecker) + // If the subquery is empty, it should always return false. + emptyChecker := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero) + // If outer key is null, and subquery is not empty, it should return null. + outerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.One) + cond = expression.ComposeCNFCondition(er.ctx, cond, emptyChecker, outerNullChecker) } // TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions. @@ -500,7 +510,7 @@ func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np gtFunc := expression.NewFunctionInternal(er.ctx, ast.GT, types.NewFieldType(mysql.TypeTiny), count, expression.One) neCond := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol) cond := expression.ComposeDNFCondition(er.ctx, gtFunc, neCond) - er.buildQuantifierPlan(plan4Agg, cond, rexpr, false) + er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, false) } // handleEQAll handles the case of = all. For example, if the query is t.id = all (select s.id from s), it will be rewrote to @@ -526,7 +536,7 @@ func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np leFunc := expression.NewFunctionInternal(er.ctx, ast.LE, types.NewFieldType(mysql.TypeTiny), count, expression.One) eqCond := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol) cond := expression.ComposeCNFCondition(er.ctx, leFunc, eqCond) - er.buildQuantifierPlan(plan4Agg, cond, rexpr, true) + er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, true) } func (er *expressionRewriter) handleExistSubquery(v *ast.ExistsSubqueryExpr) (ast.Node, bool) { diff --git a/planner/core/expression_rewriter_test.go b/planner/core/expression_rewriter_test.go index bedd1328446ca..61b2624c64e76 100644 --- a/planner/core/expression_rewriter_test.go +++ b/planner/core/expression_rewriter_test.go @@ -58,3 +58,97 @@ func (s *testExpressionRewriterSuite) TestBinaryOpFunction(c *C) { tk.MustQuery("SELECT * FROM t WHERE (a,b,c) <= (1,2,3) order by b").Check(testkit.Rows("1 1 ", "1 2 3")) tk.MustQuery("SELECT * FROM t WHERE (a,b,c) > (1,2,3) order by b").Check(testkit.Rows("1 3 ")) } + +func (s *testExpressionRewriterSuite) TestCompareSubquery(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists s") + tk.MustExec("create table t(a int, b int)") + tk.MustExec("create table s(a int, b int)") + tk.MustExec("insert into t values(1, null), (2, null)") + + // Test empty checker. + tk.MustQuery("select a != any (select a from s) from t").Check(testkit.Rows( + "0", + "0", + )) + tk.MustQuery("select b != any (select a from s) from t").Check(testkit.Rows( + "0", + "0", + )) + tk.MustQuery("select a = all (select a from s) from t").Check(testkit.Rows( + "1", + "1", + )) + tk.MustQuery("select b = all (select a from s) from t").Check(testkit.Rows( + "1", + "1", + )) + tk.MustQuery("select * from t where a != any (select a from s)").Check(testkit.Rows()) + tk.MustQuery("select * from t where b != any (select a from s)").Check(testkit.Rows()) + tk.MustQuery("select * from t where a = all (select a from s)").Check(testkit.Rows( + "1 ", + "2 ", + )) + tk.MustQuery("select * from t where b = all (select a from s)").Check(testkit.Rows( + "1 ", + "2 ", + )) + // Test outer null checker. + tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows( + "", + "", + )) + tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows( + "", + "", + )) + tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows()) + + tk.MustExec("delete from t where a = 2") + tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows( + "", + )) + tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows( + "", + )) + tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows()) + + // Test inner null checker. + tk.MustExec("insert into t values(null, 1)") + tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows( + "", + "", + )) + tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows( + "", + "", + )) + tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows()) + + tk.MustExec("delete from t where b = 1") + tk.MustExec("insert into t values(null, 2)") + tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows( + "", + "1", + )) + tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows( + "", + "0", + )) + tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows( + " 2", + )) + tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows()) +}