diff --git a/cmd/explaintest/r/explain_easy.result b/cmd/explaintest/r/explain_easy.result index c7deb8c06aebb..c9b307b1c9bf5 100644 --- a/cmd/explaintest/r/explain_easy.result +++ b/cmd/explaintest/r/explain_easy.result @@ -307,6 +307,50 @@ Projection_11 10000.00 root 9_aux_0 │ └─TableScan_27 10000.00 cop table:t, keep order:false, stats:pseudo └─TableReader_32 1.00 root data:TableScan_31 └─TableScan_31 1.00 cop table:t1, range: decided by [s.c], keep order:false, stats:pseudo +insert into t values(1, 1, 1), (2, 2 ,2), (3, 3, 3), (4, 3, 4),(5,3,5); +analyze table t; +explain select t.c in (select count(*) from t s, t t1 where s.b = t.a and s.b = 3 and s.a = t1.a) from t; +id count task operator info +Projection_11 5.00 root 9_aux_0 +└─Apply_13 5.00 root left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, count(*)) + ├─TableReader_15 5.00 root data:TableScan_14 + │ └─TableScan_14 5.00 cop table:t, range:[-inf,+inf], keep order:false + └─StreamAgg_20 1.00 root funcs:count(1) + └─IndexJoin_49 2.40 root inner join, inner:TableReader_48, outer key:s.a, inner key:t1.a + ├─IndexReader_41 2.40 root index:Selection_40 + │ └─Selection_40 2.40 cop eq(3, test.t.a) + │ └─IndexScan_39 3.00 cop table:s, index:b, range:[3,3], keep order:false + └─TableReader_48 0.80 root data:Selection_47 + └─Selection_47 0.80 cop eq(3, test.t.a) + └─TableScan_46 1.00 cop table:t1, range: decided by [s.a], keep order:false +explain select t.c in (select count(*) from t s left join t t1 on s.a = t1.a where 3 = t.a and s.b = 3) from t; +id count task operator info +Projection_10 5.00 root 9_aux_0 +└─Apply_12 5.00 root left outer semi join, inner:StreamAgg_19, other cond:eq(test.t.c, count(*)) + ├─TableReader_14 5.00 root data:TableScan_13 + │ └─TableScan_13 5.00 cop table:t, range:[-inf,+inf], keep order:false + └─StreamAgg_19 1.00 root funcs:count(1) + └─IndexJoin_43 2.40 root left outer join, inner:TableReader_42, outer key:s.a, inner key:t1.a + ├─IndexReader_35 2.40 root index:Selection_34 + │ └─Selection_34 2.40 cop eq(3, test.t.a) + │ └─IndexScan_33 3.00 cop table:s, index:b, range:[3,3], keep order:false + └─TableReader_42 0.80 root data:Selection_41 + └─Selection_41 0.80 cop eq(3, test.t.a) + └─TableScan_40 1.00 cop table:t1, range: decided by [s.a], keep order:false +explain select t.c in (select count(*) from t s right join t t1 on s.a = t1.a where 3 = t.a and t1.b = 3) from t; +id count task operator info +Projection_10 5.00 root 9_aux_0 +└─Apply_12 5.00 root left outer semi join, inner:StreamAgg_19, other cond:eq(test.t.c, count(*)) + ├─TableReader_14 5.00 root data:TableScan_13 + │ └─TableScan_13 5.00 cop table:t, range:[-inf,+inf], keep order:false + └─StreamAgg_19 1.00 root funcs:count(1) + └─IndexJoin_43 2.40 root right outer join, inner:TableReader_42, outer key:t1.a, inner key:s.a + ├─TableReader_42 0.80 root data:Selection_41 + │ └─Selection_41 0.80 cop eq(3, test.t.a) + │ └─TableScan_40 1.00 cop table:s, range: decided by [t1.a], keep order:false + └─IndexReader_35 2.40 root index:Selection_34 + └─Selection_34 2.40 cop eq(3, test.t.a) + └─IndexScan_33 3.00 cop table:t1, index:b, range:[3,3], keep order:false drop table if exists t; create table t(a int unsigned); explain select t.a = '123455' from t; diff --git a/cmd/explaintest/t/explain_easy.test b/cmd/explaintest/t/explain_easy.test index 9ef0bc6503efc..f053c3f64d752 100644 --- a/cmd/explaintest/t/explain_easy.test +++ b/cmd/explaintest/t/explain_easy.test @@ -52,6 +52,12 @@ explain select t.c in (select count(*) from t s ignore index(idx), t t1 where s. explain select t.c in (select count(*) from t s use index(idx), t t1 where s.b = t.a and s.a = t1.a) from t; explain select t.c in (select count(*) from t s use index(idx), t t1 where s.b = t.a and s.c = t1.a) from t; +insert into t values(1, 1, 1), (2, 2 ,2), (3, 3, 3), (4, 3, 4),(5,3,5); +analyze table t; +explain select t.c in (select count(*) from t s, t t1 where s.b = t.a and s.b = 3 and s.a = t1.a) from t; +explain select t.c in (select count(*) from t s left join t t1 on s.a = t1.a where 3 = t.a and s.b = 3) from t; +explain select t.c in (select count(*) from t s right join t t1 on s.a = t1.a where 3 = t.a and t1.b = 3) from t; + drop table if exists t; create table t(a int unsigned); explain select t.a = '123455' from t; diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 0a5b061076eb7..80021de840a36 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -176,12 +176,42 @@ func (b *planBuilder) buildResultSetNode(node ast.ResultSetNode) (p LogicalPlan, } } +// pushDownConstExpr checks if the condition is from filter condition, if true, push it down to both +// children of join, whatever the join type is; if false, push it down to inner child of outer join, +// and both children of non-outer-join. +func (p *LogicalJoin) pushDownConstExpr(expr expression.Expression, leftCond []expression.Expression, + rightCond []expression.Expression, filterCond bool) ([]expression.Expression, []expression.Expression) { + switch p.JoinType { + case LeftOuterJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: + if filterCond { + leftCond = append(leftCond, expr) + // Append the expr to right join condition instead of `rightCond`, to make it able to be + // pushed down to children of join. + p.RightConditions = append(p.RightConditions, expr) + } else { + rightCond = append(rightCond, expr) + } + case RightOuterJoin: + if filterCond { + rightCond = append(rightCond, expr) + p.LeftConditions = append(p.LeftConditions, expr) + } else { + leftCond = append(leftCond, expr) + } + case SemiJoin, AntiSemiJoin, InnerJoin: + leftCond = append(leftCond, expr) + rightCond = append(rightCond, expr) + } + return leftCond, rightCond +} + // extractOnCondition divide conditions in CNF of join node into 4 groups. // These conditions can be where conditions, join conditions, or collection of both. // If deriveLeft/deriveRight is set, we would try to derive more conditions for left/right plan. -func extractOnCondition(conditions []expression.Expression, left LogicalPlan, right LogicalPlan, - deriveLeft bool, deriveRight bool) (eqCond []*expression.ScalarFunction, leftCond []expression.Expression, +func (p *LogicalJoin) extractOnCondition(conditions []expression.Expression, deriveLeft bool, deriveRight bool) ( + eqCond []*expression.ScalarFunction, leftCond []expression.Expression, rightCond []expression.Expression, otherCond []expression.Expression) { + left, right := p.children[0], p.children[1] for _, expr := range conditions { binop, ok := expr.(*expression.ScalarFunction) if ok && binop.FuncName.L == ast.EQ { @@ -205,6 +235,12 @@ func extractOnCondition(conditions []expression.Expression, left LogicalPlan, ri } } columns := expression.ExtractColumns(expr) + // `columns` may be empty, if the condition is like `correlated_column op constant`, or `constant`, + // push this kind of constant condition down according to join type. + if len(columns) == 0 { + leftCond, rightCond = p.pushDownConstExpr(expr, leftCond, rightCond, deriveLeft || deriveRight) + continue + } allFromLeft, allFromRight := true, true for _, col := range columns { if !left.Schema().Contains(col) { diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index a360c7378f4fb..f0b846e3cc50f 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -173,7 +173,7 @@ func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expres } func (p *LogicalJoin) attachOnConds(onConds []expression.Expression) { - eq, left, right, other := extractOnCondition(onConds, p.children[0].(LogicalPlan), p.children[1].(LogicalPlan), false, false) + eq, left, right, other := p.extractOnCondition(onConds, false, false) p.EqualConditions = append(eq, p.EqualConditions...) p.LeftConditions = append(left, p.LeftConditions...) p.RightConditions = append(right, p.RightConditions...) diff --git a/planner/core/rule_predicate_push_down.go b/planner/core/rule_predicate_push_down.go index 5dd7ed040a87a..5f7569063a502 100644 --- a/planner/core/rule_predicate_push_down.go +++ b/planner/core/rule_predicate_push_down.go @@ -118,8 +118,6 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret newJoin := e.resultJoin return newJoin.PredicatePushDown(predicates) } - leftPlan := p.children[0] - rightPlan := p.children[1] var equalCond []*expression.ScalarFunction var leftPushCond, rightPushCond, otherCond, leftCond, rightCond []expression.Expression switch p.JoinType { @@ -127,7 +125,7 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret // Handle where conditions predicates = expression.ExtractFiltersFromDNFs(p.ctx, predicates) // Only derive left where condition, because right where condition cannot be pushed down - equalCond, leftPushCond, rightPushCond, otherCond = extractOnCondition(predicates, leftPlan, rightPlan, true, false) + equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(predicates, true, false) leftCond = leftPushCond // Handle join conditions, only derive right join condition, because left join condition cannot be pushed down _, derivedRightJoinCond := deriveOtherConditions(p, false, true) @@ -139,7 +137,7 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret // Handle where conditions predicates = expression.ExtractFiltersFromDNFs(p.ctx, predicates) // Only derive right where condition, because left where condition cannot be pushed down - equalCond, leftPushCond, rightPushCond, otherCond = extractOnCondition(predicates, leftPlan, rightPlan, false, true) + equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(predicates, false, true) rightCond = rightPushCond // Handle join conditions, only derive left join condition, because right join condition cannot be pushed down derivedLeftJoinCond, _ := deriveOtherConditions(p, true, false) @@ -164,7 +162,7 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret return ret, dual } } - equalCond, leftPushCond, rightPushCond, otherCond = extractOnCondition(tempCond, leftPlan, rightPlan, true, true) + equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(tempCond, true, true) p.LeftConditions = nil p.RightConditions = nil p.EqualConditions = equalCond @@ -172,8 +170,8 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret leftCond = leftPushCond rightCond = rightPushCond } - leftRet, lCh := leftPlan.PredicatePushDown(leftCond) - rightRet, rCh := rightPlan.PredicatePushDown(rightCond) + leftRet, lCh := p.children[0].PredicatePushDown(leftCond) + rightRet, rCh := p.children[1].PredicatePushDown(rightCond) addSelection(p, lCh, leftRet, 0) addSelection(p, rCh, rightRet, 1) p.updateEQCond()