Skip to content

Commit

Permalink
planner: check null and empty for != any(subq) and = all(subq) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
eurekaka authored and alivxxx committed Jan 21, 2019
1 parent c371e66 commit 776c90f
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 25 deletions.
22 changes: 22 additions & 0 deletions cmd/explaintest/r/select.result
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,25 @@ Union_7 20000.00 root
│ └─TableScan_8 10000.00 cop table:th, partition:p1, range:[-inf,+inf], keep order:false, stats:pseudo
└─TableReader_11 10000.00 root data:TableScan_10
└─TableScan_10 10000.00 cop table:th, partition:p2, range:[-inf,+inf], keep order:false, stats:pseudo
drop table if exists t;
create table t(a int, b int);
explain select a != any (select a from t t2) from t t1;
id count task operator info
Projection_9 10000.00 root and(or(or(gt(col_count, 1), ne(t1.a, col_firstrow)), if(ne(agg_col_sum, 0), NULL, 0)), and(ne(agg_col_cnt, 0), if(isnull(t1.a), NULL, 1)))
└─HashLeftJoin_10 10000.00 root inner join, inner:StreamAgg_17
├─TableReader_13 10000.00 root data:TableScan_12
│ └─TableScan_12 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo
└─StreamAgg_17 1.00 root funcs:firstrow(col_0), count(distinct col_1), sum(col_2), count(1)
└─Projection_27 10000.00 root t2.a, t2.a, cast(isnull(t2.a))
└─TableReader_24 10000.00 root data:TableScan_23
└─TableScan_23 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo
explain select a = all (select a from t t2) from t t1;
id count task operator info
Projection_9 10000.00 root or(and(and(le(col_count, 1), eq(t1.a, col_firstrow)), if(ne(agg_col_sum, 0), NULL, 1)), or(eq(agg_col_cnt, 0), if(isnull(t1.a), NULL, 0)))
└─HashLeftJoin_10 10000.00 root inner join, inner:StreamAgg_17
├─TableReader_13 10000.00 root data:TableScan_12
│ └─TableScan_12 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo
└─StreamAgg_17 1.00 root funcs:firstrow(col_0), count(distinct col_1), sum(col_2), count(1)
└─Projection_27 10000.00 root t2.a, t2.a, cast(isnull(t2.a))
└─TableReader_24 10000.00 root data:TableScan_23
└─TableScan_23 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo
6 changes: 6 additions & 0 deletions cmd/explaintest/t/select.test
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,9 @@ insert into th values (-1,-1),(-2,-2),(-3,-3),(-4,-4),(-5,-5),(-6,-6),(-7,-7),(-
desc select * from th where a=-2;
desc select * from th;
desc select * from th partition (p2,p1);

# test != any(subq) and = all(subq)
drop table if exists t;
create table t(a int, b int);
explain select a != any (select a from t t2) from t t1;
explain select a = all (select a from t t2) from t t1;
60 changes: 35 additions & 25 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,44 +429,54 @@ func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression.
plan4Agg.AggFuncs = []*aggregation.AggFuncDesc{funcMaxOrMin}

cond := expression.NewFunctionInternal(er.ctx, cmpFunc, types.NewFieldType(mysql.TypeTiny), lexpr, colMaxOrMin)
er.buildQuantifierPlan(plan4Agg, cond, rexpr, all)
er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, all)
}

// buildQuantifierPlan adds extra condition for any / all subquery.
func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, cond, rexpr expression.Expression, all bool) {
funcIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr)
func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, cond, lexpr, rexpr expression.Expression, all bool) {
innerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr)
outerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), lexpr)

funcSum := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{funcIsNull}, false)
funcSum := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{innerIsNull}, false)
colSum := &expression.Column{
ColName: model.NewCIStr("agg_col_sum"),
UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(),
RetType: funcSum.RetTp,
}
plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcSum)
plan4Agg.schema.Append(colSum)
innerHasNull := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero)

// Build `count(1)` aggregation to check if subquery is empty.
funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{expression.One}, false)
colCount := &expression.Column{
ColName: model.NewCIStr("agg_col_cnt"),
UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(),
RetType: funcCount.RetTp,
}
plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount)
plan4Agg.schema.Append(colCount)

if all {
funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{funcIsNull}, false)
colCount := &expression.Column{
ColName: model.NewCIStr("agg_col_cnt"),
UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(),
RetType: funcCount.RetTp,
}
plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount)
plan4Agg.schema.Append(colCount)
// All of the inner record set should not contain null value. So for t.id < all(select s.id from s), it
// should be rewrote to t.id < min(s.id) and if(sum(s.id is null) = 0, true, null).
hasNotNull := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero)
nullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), hasNotNull, expression.One, expression.Null)
cond = expression.ComposeCNFCondition(er.ctx, cond, nullChecker)
// If the set is empty, it should always return true.
checkEmpty := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero)
cond = expression.ComposeDNFCondition(er.ctx, cond, checkEmpty)
// should be rewrote to t.id < min(s.id) and if(sum(s.id is null) != 0, null, true).
innerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.One)
cond = expression.ComposeCNFCondition(er.ctx, cond, innerNullChecker)
// If the subquery is empty, it should always return true.
emptyChecker := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero)
// If outer key is null, and subquery is not empty, it should always return null, even when it is `null = all (1, 2)`.
outerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.Zero)
cond = expression.ComposeDNFCondition(er.ctx, cond, emptyChecker, outerNullChecker)
} else {
// For "any" expression, if the record set has null and the cond return false, the result should be NULL.
hasNull := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero)
nullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), hasNull, expression.Null, expression.Zero)
cond = expression.ComposeDNFCondition(er.ctx, cond, nullChecker)
// For "any" expression, if the subquery has null and the cond returns false, the result should be NULL.
// Specifically, `t.id < any (select s.id from s)` would be rewrote to `t.id < max(s.id) or if(sum(s.id is null) != 0, null, false)`
innerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.Zero)
cond = expression.ComposeDNFCondition(er.ctx, cond, innerNullChecker)
// If the subquery is empty, it should always return false.
emptyChecker := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero)
// If outer key is null, and subquery is not empty, it should return null.
outerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.One)
cond = expression.ComposeCNFCondition(er.ctx, cond, emptyChecker, outerNullChecker)
}

// TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions.
Expand Down Expand Up @@ -519,7 +529,7 @@ func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np
gtFunc := expression.NewFunctionInternal(er.ctx, ast.GT, types.NewFieldType(mysql.TypeTiny), count, expression.One)
neCond := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol)
cond := expression.ComposeDNFCondition(er.ctx, gtFunc, neCond)
er.buildQuantifierPlan(plan4Agg, cond, rexpr, false)
er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, false)
}

// handleEQAll handles the case of = all. For example, if the query is t.id = all (select s.id from s), it will be rewrote to
Expand All @@ -545,7 +555,7 @@ func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np
leFunc := expression.NewFunctionInternal(er.ctx, ast.LE, types.NewFieldType(mysql.TypeTiny), count, expression.One)
eqCond := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol)
cond := expression.ComposeCNFCondition(er.ctx, leFunc, eqCond)
er.buildQuantifierPlan(plan4Agg, cond, rexpr, true)
er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, true)
}

func (er *expressionRewriter) handleExistSubquery(v *ast.ExistsSubqueryExpr) (ast.Node, bool) {
Expand Down
94 changes: 94 additions & 0 deletions planner/core/expression_rewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,97 @@ func (s *testExpressionRewriterSuite) TestDefaultFunction(c *C) {
tk.MustExec("update t1 set c = c + default(c)")
tk.MustQuery("select c from t1").Check(testkit.Rows("11"))
}

func (s *testExpressionRewriterSuite) TestCompareSubquery(c *C) {
defer testleak.AfterTest(c)()
store, dom, err := newStoreWithBootstrap()
c.Assert(err, IsNil)
tk := testkit.NewTestKit(c, store)
defer func() {
dom.Close()
store.Close()
}()
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("drop table if exists s")
tk.MustExec("create table t(a int, b int)")
tk.MustExec("create table s(a int, b int)")
tk.MustExec("insert into t values(1, null), (2, null)")

// Test empty checker.
tk.MustQuery("select a != any (select a from s) from t").Check(testkit.Rows(
"0",
"0",
))
tk.MustQuery("select b != any (select a from s) from t").Check(testkit.Rows(
"0",
"0",
))
tk.MustQuery("select a = all (select a from s) from t").Check(testkit.Rows(
"1",
"1",
))
tk.MustQuery("select b = all (select a from s) from t").Check(testkit.Rows(
"1",
"1",
))
tk.MustQuery("select * from t where a != any (select a from s)").Check(testkit.Rows())
tk.MustQuery("select * from t where b != any (select a from s)").Check(testkit.Rows())
tk.MustQuery("select * from t where a = all (select a from s)").Check(testkit.Rows(
"1 <nil>",
"2 <nil>",
))
tk.MustQuery("select * from t where b = all (select a from s)").Check(testkit.Rows(
"1 <nil>",
"2 <nil>",
))
// Test outer null checker.
tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows(
"<nil>",
"<nil>",
))
tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows(
"<nil>",
"<nil>",
))
tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows())
tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows())

tk.MustExec("delete from t where a = 2")
tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows(
"<nil>",
))
tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows(
"<nil>",
))
tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows())
tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows())

// Test inner null checker.
tk.MustExec("insert into t values(null, 1)")
tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows(
"<nil>",
"<nil>",
))
tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows(
"<nil>",
"<nil>",
))
tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows())
tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows())

tk.MustExec("delete from t where b = 1")
tk.MustExec("insert into t values(null, 2)")
tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows(
"<nil>",
"1",
))
tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows(
"<nil>",
"0",
))
tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows(
"<nil> 2",
))
tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows())
}

0 comments on commit 776c90f

Please sign in to comment.