Skip to content

Commit

Permalink
plan: check null and empty for != any(subq) and = all(subq) (#9106)…
Browse files Browse the repository at this point in the history
… (#9404)
  • Loading branch information
eurekaka authored and jackysp committed Feb 21, 2019
1 parent a4113cd commit 6be3788
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 26 deletions.
62 changes: 36 additions & 26 deletions plan/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,14 +382,15 @@ func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression.
plan4Agg.AggFuncs = []*aggregation.AggFuncDesc{funcMaxOrMin}

cond := expression.NewFunctionInternal(er.ctx, cmpFunc, types.NewFieldType(mysql.TypeTiny), lexpr, colMaxOrMin.Clone())
er.buildQuantifierPlan(plan4Agg, cond, rexpr, all)
er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, all)
}

// buildQuantifierPlan adds extra condition for any / all subquery.
func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, cond, rexpr expression.Expression, all bool) {
funcIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr.Clone())
func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, cond, lexpr, rexpr expression.Expression, all bool) {
innerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr)
outerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), lexpr)

funcSum := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{funcIsNull}, false)
funcSum := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{innerIsNull}, false)
colSum := &expression.Column{
ColName: model.NewCIStr("agg_col_sum"),
FromID: plan4Agg.id,
Expand All @@ -398,30 +399,39 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation,
}
plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcSum)
plan4Agg.schema.Append(colSum)
innerHasNull := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero)

// Build `count(1)` aggregation to check if subquery is empty.
funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{expression.One}, false)
colCount := &expression.Column{
ColName: model.NewCIStr("agg_col_cnt"),
FromID: plan4Agg.id,
Position: plan4Agg.schema.Len(),
RetType: funcCount.RetTp,
}
plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount)
plan4Agg.schema.Append(colCount)

if all {
funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{funcIsNull.Clone()}, false)
colCount := &expression.Column{
ColName: model.NewCIStr("agg_col_cnt"),
FromID: plan4Agg.id,
Position: plan4Agg.schema.Len(),
RetType: funcCount.RetTp,
}
plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount)
plan4Agg.schema.Append(colCount)
// All of the inner record set should not contain null value. So for t.id < all(select s.id from s), it
// should be rewrote to t.id < min(s.id) and if(sum(s.id is null) = 0, true, null).
hasNotNull := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colSum.Clone(), expression.Zero)
nullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), hasNotNull, expression.One, expression.Null)
cond = expression.ComposeCNFCondition(er.ctx, cond, nullChecker)
// If the set is empty, it should always return true.
checkEmpty := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount.Clone(), expression.Zero)
cond = expression.ComposeDNFCondition(er.ctx, cond, checkEmpty)
// should be rewrote to t.id < min(s.id) and if(sum(s.id is null) != 0, null, true).
innerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.One)
cond = expression.ComposeCNFCondition(er.ctx, cond, innerNullChecker)
// If the subquery is empty, it should always return true.
emptyChecker := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero)
// If outer key is null, and subquery is not empty, it should always return null, even when it is `null = all (1, 2)`.
outerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.Zero)
cond = expression.ComposeDNFCondition(er.ctx, cond, emptyChecker, outerNullChecker)
} else {
// For "any" expression, if the record set has null and the cond return false, the result should be NULL.
hasNull := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum.Clone(), expression.Zero)
nullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), hasNull, expression.Null, expression.Zero)
cond = expression.ComposeDNFCondition(er.ctx, cond, nullChecker)
// For "any" expression, if the subquery has null and the cond returns false, the result should be NULL.
// Specifically, `t.id < any (select s.id from s)` would be rewrote to `t.id < max(s.id) or if(sum(s.id is null) != 0, null, false)`
innerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.Zero)
cond = expression.ComposeDNFCondition(er.ctx, cond, innerNullChecker)
// If the subquery is empty, it should always return false.
emptyChecker := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero)
// If outer key is null, and subquery is not empty, it should return null.
outerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.One)
cond = expression.ComposeCNFCondition(er.ctx, cond, emptyChecker, outerNullChecker)
}

// TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions.
Expand Down Expand Up @@ -477,7 +487,7 @@ func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np
gtFunc := expression.NewFunctionInternal(er.ctx, ast.GT, types.NewFieldType(mysql.TypeTiny), count.Clone(), expression.One)
neCond := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol.Clone())
cond := expression.ComposeDNFCondition(er.ctx, gtFunc, neCond)
er.buildQuantifierPlan(plan4Agg, cond, rexpr, false)
er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, false)
}

// handleEQAll handles the case of = all. For example, if the query is t.id = all (select s.id from s), it will be rewrote to
Expand Down Expand Up @@ -505,7 +515,7 @@ func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np
leFunc := expression.NewFunctionInternal(er.ctx, ast.LE, types.NewFieldType(mysql.TypeTiny), count.Clone(), expression.One)
eqCond := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol.Clone())
cond := expression.ComposeCNFCondition(er.ctx, leFunc, eqCond)
er.buildQuantifierPlan(plan4Agg, cond, rexpr, true)
er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, true)
}

func (er *expressionRewriter) handleExistSubquery(v *ast.ExistsSubqueryExpr) (ast.Node, bool) {
Expand Down
119 changes: 119 additions & 0 deletions plan/expression_rewriter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package plan_test

import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/util/testleak"
)

var _ = Suite(&testExpressionRewriterSuite{})

type testExpressionRewriterSuite struct {
}

func (s *testExpressionRewriterSuite) TestCompareSubquery(c *C) {
defer testleak.AfterTest(c)()
store, dom, err := newStoreWithBootstrap()
c.Assert(err, IsNil)
tk := testkit.NewTestKit(c, store)
defer func() {
dom.Close()
store.Close()
}()
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("drop table if exists s")
tk.MustExec("create table t(a int, b int)")
tk.MustExec("create table s(a int, b int)")
tk.MustExec("insert into t values(1, null), (2, null)")

// Test empty checker.
tk.MustQuery("select a != any (select a from s) from t").Check(testkit.Rows(
"0",
"0",
))
tk.MustQuery("select b != any (select a from s) from t").Check(testkit.Rows(
"0",
"0",
))
tk.MustQuery("select a = all (select a from s) from t").Check(testkit.Rows(
"1",
"1",
))
tk.MustQuery("select b = all (select a from s) from t").Check(testkit.Rows(
"1",
"1",
))
tk.MustQuery("select * from t where a != any (select a from s)").Check(testkit.Rows())
tk.MustQuery("select * from t where b != any (select a from s)").Check(testkit.Rows())
tk.MustQuery("select * from t where a = all (select a from s)").Check(testkit.Rows(
"1 <nil>",
"2 <nil>",
))
tk.MustQuery("select * from t where b = all (select a from s)").Check(testkit.Rows(
"1 <nil>",
"2 <nil>",
))
// Test outer null checker.
tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows(
"<nil>",
"<nil>",
))
tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows(
"<nil>",
"<nil>",
))
tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows())
tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows())

tk.MustExec("delete from t where a = 2")
tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows(
"<nil>",
))
tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows(
"<nil>",
))
tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows())
tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows())

// Test inner null checker.
tk.MustExec("insert into t values(null, 1)")
tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows(
"<nil>",
"<nil>",
))
tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows(
"<nil>",
"<nil>",
))
tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows())
tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows())

tk.MustExec("delete from t where b = 1")
tk.MustExec("insert into t values(null, 2)")
tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows(
"<nil>",
"1",
))
tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows(
"<nil>",
"0",
))
tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows(
"<nil> 2",
))
tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows())
}

0 comments on commit 6be3788

Please sign in to comment.