Skip to content

Commit

Permalink
planner: maintain functional dependency for joins (pingcap#5)
Browse files Browse the repository at this point in the history
Co-authored-by: ailinkid <[email protected]>
  • Loading branch information
winoros and AilinKid authored Mar 1, 2022
1 parent 716584a commit d4ef9af
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 19 deletions.
4 changes: 2 additions & 2 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4269,10 +4269,10 @@ func (ds *DataSource) ExtractFD() *fd.FDSet {
notnullColsUniqueIDs := extractNotNullFromConds(ds.allConds, ds)

// extract the constant cols from selection conditions.
constUniqueIDs := extractConstantCols(ds.allConds, ds, fds)
constUniqueIDs := extractConstantCols(ds.allConds, ds.SCtx(), fds)

// extract equivalence cols.
equivUniqueIDs := extractEquivalenceCols(ds.allConds, ds, fds)
equivUniqueIDs := extractEquivalenceCols(ds.allConds, ds.SCtx(), fds)

// apply conditions to FD.
fds.MakeNotNull(notnullColsUniqueIDs)
Expand Down
121 changes: 107 additions & 14 deletions planner/core/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,99 @@ func (p *LogicalJoin) Shallow() *LogicalJoin {
return join.Init(p.ctx, p.blockOffset)
}

// ExtractFD implements the interface LogicalPlan.
func (p *LogicalJoin) ExtractFD() *fd.FDSet {
switch p.JoinType {
case InnerJoin:
return p.extractFDForInnerJoin()
case LeftOuterJoin, RightOuterJoin:
return p.extractFDForOuterJoin()
case SemiJoin:
return p.extractFDForSemiJoin()
default:
return &fd.FDSet{HashCodeToUniqueID: make(map[string]int)}
}
}

func (p *LogicalJoin) extractFDForSemiJoin() *fd.FDSet {
// 1: since semi join will keep the part or all rows of the outer table, it's outer FD can be saved.
// 2: the un-projected column will be left for the upper layer projection or already be pruned from bottom up.
outerFD, _ := p.children[0].ExtractFD(), p.children[1].ExtractFD()
fds := outerFD

eqCondSlice := expression.ScalarFuncs2Exprs(p.EqualConditions)
allConds := append(eqCondSlice, p.OtherConditions...)
notNullColsFromFilters := extractNotNullFromConds(allConds, p)

constUniqueIDs := extractConstantCols(p.LeftConditions, p.SCtx(), fds)

fds.MakeNotNull(notNullColsFromFilters)
fds.AddConstants(constUniqueIDs)
p.fdSet = fds
return fds
}

func (p *LogicalJoin) extractFDForInnerJoin() *fd.FDSet {
leftFD, rightFD := p.children[0].ExtractFD(), p.children[1].ExtractFD()
fds := leftFD
fds.MakeCartesianProduct(rightFD)

eqCondSlice := expression.ScalarFuncs2Exprs(p.EqualConditions)
allConds := append(eqCondSlice, p.OtherConditions...)
notNullColsFromFilters := extractNotNullFromConds(allConds, p)

constUniqueIDs := extractConstantCols(eqCondSlice, p.SCtx(), fds)

equivUniqueIDs := extractEquivalenceCols(eqCondSlice, p.SCtx(), fds)

fds.MakeNotNull(notNullColsFromFilters)
fds.AddConstants(constUniqueIDs)
for _, equiv := range equivUniqueIDs {
fds.AddEquivalence(equiv[0], equiv[1])
}
p.fdSet = fds
return fds
}

func (p *LogicalJoin) extractFDForOuterJoin() *fd.FDSet {
outerFD, innerFD := p.children[0].ExtractFD(), p.children[1].ExtractFD()
innerCondition := p.RightConditions
outerCols, innerCols := fd.NewFastIntSet(), fd.NewFastIntSet()
for _, col := range p.children[0].Schema().Columns {
outerCols.Insert(int(col.UniqueID))
}
for _, col := range p.children[1].Schema().Columns {
innerCols.Insert(int(col.UniqueID))
}
if p.JoinType == RightOuterJoin {
innerFD, outerFD = outerFD, innerFD
innerCondition = p.LeftConditions
innerCols, outerCols = outerCols, innerCols
}

eqCondSlice := expression.ScalarFuncs2Exprs(p.EqualConditions)
allConds := append(eqCondSlice, p.OtherConditions...)
allConds = append(allConds, innerCondition...)
notNullColsFromFilters := extractNotNullFromConds(allConds, p)

filterFD := &fd.FDSet{HashCodeToUniqueID: make(map[string]int)}

constUniqueIDs := extractConstantCols(eqCondSlice, p.SCtx(), filterFD)

equivUniqueIDs := extractEquivalenceCols(eqCondSlice, p.SCtx(), filterFD)

filterFD.AddConstants(constUniqueIDs)
for _, equiv := range equivUniqueIDs {
filterFD.AddEquivalence(equiv[0], equiv[1])
}
filterFD.MakeNotNull(notNullColsFromFilters)

fds := outerFD
fds.MakeOuterJoin(innerFD, filterFD, outerCols, innerCols)
p.fdSet = fds
return fds
}

// GetJoinKeys extracts join keys(columns) from EqualConditions. It returns left join keys, right
// join keys and an `isNullEQ` array which means the `joinKey[i]` is a `NullEQ` function. The `hasNullEQ`
// means whether there is a `NullEQ` of a join key.
Expand Down Expand Up @@ -657,34 +750,34 @@ func extractNotNullFromConds(Conditions []expression.Expression, p LogicalPlan)
return notnullColsUniqueIDs
}

func extractConstantCols(Conditions []expression.Expression, p LogicalPlan, fds *fd.FDSet) fd.FastIntSet {
func extractConstantCols(Conditions []expression.Expression, sctx sessionctx.Context, fds *fd.FDSet) fd.FastIntSet {
// extract constant cols
// eg: where a=1 and b is null and (1+c)=5.
// TODO: Some columns can only be determined to be constant from multiple constraints (e.g. x <= 1 AND x >= 1)
var (
constObjs []expression.Expression
constUniqueIDs = fd.NewFastIntSet()
)
constObjs = expression.ExtractConstantEqColumnsOrScalar(p.SCtx(), constObjs, Conditions)
constObjs = expression.ExtractConstantEqColumnsOrScalar(sctx, constObjs, Conditions)
for _, constObj := range constObjs {
switch x := constObj.(type) {
case *expression.Column:
constUniqueIDs.Insert(int(x.UniqueID))
case *expression.ScalarFunction:
hashCode := string(x.HashCode(p.SCtx().GetSessionVars().StmtCtx))
hashCode := string(x.HashCode(sctx.GetSessionVars().StmtCtx))
if uniqueID, ok := fds.IsHashCodeRegistered(hashCode); ok {
constUniqueIDs.Insert(uniqueID)
} else {
scalarUniqueID := int(p.SCtx().GetSessionVars().AllocPlanColumnID())
fds.RegisterUniqueID(string(x.HashCode(p.SCtx().GetSessionVars().StmtCtx)), scalarUniqueID)
scalarUniqueID := int(sctx.GetSessionVars().AllocPlanColumnID())
fds.RegisterUniqueID(string(x.HashCode(sctx.GetSessionVars().StmtCtx)), scalarUniqueID)
constUniqueIDs.Insert(scalarUniqueID)
}
}
}
return constUniqueIDs
}

func extractEquivalenceCols(Conditions []expression.Expression, p LogicalPlan, fds *fd.FDSet) [][]fd.FastIntSet {
func extractEquivalenceCols(Conditions []expression.Expression, sctx sessionctx.Context, fds *fd.FDSet) [][]fd.FastIntSet {
var equivObjsPair [][]expression.Expression
equivObjsPair = expression.ExtractEquivalenceColumns(equivObjsPair, Conditions)
equivUniqueIDs := make([][]fd.FastIntSet, 0, len(equivObjsPair))
Expand All @@ -698,12 +791,12 @@ func extractEquivalenceCols(Conditions []expression.Expression, p LogicalPlan, f
case *expression.Column:
lhsUniqueID = int(x.UniqueID)
case *expression.ScalarFunction:
hashCode := string(x.HashCode(p.SCtx().GetSessionVars().StmtCtx))
hashCode := string(x.HashCode(sctx.GetSessionVars().StmtCtx))
if uniqueID, ok := fds.IsHashCodeRegistered(hashCode); ok {
lhsUniqueID = uniqueID
} else {
scalarUniqueID := int(p.SCtx().GetSessionVars().AllocPlanColumnID())
fds.RegisterUniqueID(string(x.HashCode(p.SCtx().GetSessionVars().StmtCtx)), scalarUniqueID)
scalarUniqueID := int(sctx.GetSessionVars().AllocPlanColumnID())
fds.RegisterUniqueID(string(x.HashCode(sctx.GetSessionVars().StmtCtx)), scalarUniqueID)
lhsUniqueID = scalarUniqueID
}
}
Expand All @@ -712,12 +805,12 @@ func extractEquivalenceCols(Conditions []expression.Expression, p LogicalPlan, f
case *expression.Column:
rhsUniqueID = int(x.UniqueID)
case *expression.ScalarFunction:
hashCode := string(x.HashCode(p.SCtx().GetSessionVars().StmtCtx))
hashCode := string(x.HashCode(sctx.GetSessionVars().StmtCtx))
if uniqueID, ok := fds.IsHashCodeRegistered(hashCode); ok {
rhsUniqueID = uniqueID
} else {
scalarUniqueID := int(p.SCtx().GetSessionVars().AllocPlanColumnID())
fds.RegisterUniqueID(string(x.HashCode(p.SCtx().GetSessionVars().StmtCtx)), scalarUniqueID)
scalarUniqueID := int(sctx.GetSessionVars().AllocPlanColumnID())
fds.RegisterUniqueID(string(x.HashCode(sctx.GetSessionVars().StmtCtx)), scalarUniqueID)
rhsUniqueID = scalarUniqueID
}
}
Expand All @@ -743,10 +836,10 @@ func (p *LogicalSelection) ExtractFD() *fd.FDSet {
notnullColsUniqueIDs.UnionWith(extractNotNullFromConds(p.Conditions, p))

// extract the constant cols from selection conditions.
constUniqueIDs := extractConstantCols(p.Conditions, p, fds)
constUniqueIDs := extractConstantCols(p.Conditions, p.SCtx(), fds)

// extract equivalence cols.
equivUniqueIDs := extractEquivalenceCols(p.Conditions, p, fds)
equivUniqueIDs := extractEquivalenceCols(p.Conditions, p.SCtx(), fds)

// apply operator's characteristic's FD setting.
fds.MakeNotNull(notnullColsUniqueIDs)
Expand Down
4 changes: 4 additions & 0 deletions planner/core/stringer.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ func fdToString(in LogicalPlan, strs []string, idxs []int) ([]string, []int) {
}
case *DataSource:
strs = append(strs, "{"+x.fdSet.String()+"}")
case *LogicalApply:
strs = append(strs, "{"+x.fdSet.String()+"}")
case *LogicalJoin:
strs = append(strs, "{"+x.fdSet.String()+"}")
default:
}
return strs, idxs
Expand Down
99 changes: 99 additions & 0 deletions planner/functional_dependency/extract_fd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,102 @@ func TestFDSet_ExtractFD(t *testing.T) {
ass.Equal(tt.fd, plannercore.FDToString(p.(plannercore.LogicalPlan)), comment)
}
}

func TestFDSet_ExtractFDForApply(t *testing.T) {
t.Parallel()
ass := assert.New(t)

store, clean := testkit.CreateMockStore(t)
defer clean()
par := parser.New()
par.SetParserConfig(parser.ParserConfig{EnableWindowFunction: true, EnableStrictDoubleTypeCheck: true})

tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("CREATE TABLE X (a INT PRIMARY KEY, b INT, c INT, d INT, e INT)")
tk.MustExec("CREATE UNIQUE INDEX uni ON X (b, c)")
tk.MustExec("CREATE TABLE Y (m INT, n INT, p INT, q INT, PRIMARY KEY (m, n))")

tests := []struct {
sql string
best string
fd string
}{
{
sql: "select * from X where exists (select * from Y where m=a limit 1)",
// For this Apply, it's essentially a semi join, for every `a` in X, do the inner loop once.
// +- datasource(x)
// +- limit
// +- datasource(Y)
best: "Apply{DataScan(X)->DataScan(Y)->Limit}->Projection",
// Since semi join will keep the **all** rows of the outer table, it's FD can be derived.
fd: "{(1)-->(2-5), (2,3)~~>(1,4,5)} >>> {(1)-->(2-5), (2,3)~~>(1,4,5)}",
},
{
sql: "select a, b from X where exists (select * from Y where m=a limit 1)",
// For this Apply, it's essentially a semi join, for every `a` in X, do the inner loop once.
// +- datasource(x)
// +- limit
// +- datasource(Y)
best: "Apply{DataScan(X)->DataScan(Y)->Limit}->Projection", // semi join
// Since semi join will keep the **part** rows of the outer table, it's FD can be derived.
fd: "{(1)-->(2-5), (2,3)~~>(1,4,5)} >>> {(1)-->(2)}",
},
{
// Limit will refuse to de-correlate apply to join while this won't.
sql: "select * from X where exists (select * from Y where m=a and p=1)",
best: "Join{DataScan(X)->DataScan(Y)}(test.x.a,test.y.m)->Projection", // semi join
fd: "{(1)-->(2-5), (2,3)~~>(1,4,5)} >>> {(1)-->(2-5), (2,3)~~>(1,4,5)}",
},
{
sql: "select * from X where exists (select * from Y where m=a and p=q)",
best: "Join{DataScan(X)->DataScan(Y)}(test.x.a,test.y.m)->Projection", // semi join
fd: "{(1)-->(2-5), (2,3)~~>(1,4,5)} >>> {(1)-->(2-5), (2,3)~~>(1,4,5)}",
},
{
sql: "select * from X where exists (select * from Y where m=a and b=1)",
best: "Join{DataScan(X)->DataScan(Y)}(test.x.a,test.y.m)->Projection", // semi join
// b=1 is semi join's left condition which should be conserved.
fd: "{(1)-->(3-5), (2,3)~~>(1,4,5), ()-->(2)} >>> {(1)-->(3-5), (2,3)~~>(1,4,5), ()-->(2)}",
},
{
sql: "select * from (select b,c,d,e from X) X1 where exists (select * from Y where p=q and n=1) ",
best: "Dual->Projection",
fd: "{}",
},
{
sql: "select * from (select b,c,d,e from X) X1 where exists (select * from Y where p=b and n=1) ",
best: "Join{DataScan(X)->DataScan(Y)}(test.x.b,test.y.p)->Projection", // semi join
fd: "{(1)-->(2-5), (2,3)~~>(1,4,5)} >>> {(2,3)~~>(4,5)}",
},
{
sql: "select * from X where exists (select m, p, q from Y where n=a and p=1)",
best: "Join{DataScan(X)->DataScan(Y)}(test.x.a,test.y.n)->Projection",
// p=1 is semi join's right condition which should **NOT** be conserved.
fd: "{(1)-->(2-5), (2,3)~~>(1,4,5)} >>> {(1)-->(2-5), (2,3)~~>(1,4,5)}",
},
}

ctx := context.TODO()
is := testGetIS(ass, tk.Session())
for i, tt := range tests {
comment := fmt.Sprintf("case:%v sql:%s", i, tt.sql)
stmt, err := par.ParseOneStmt(tt.sql, "", "")
ass.Nil(err, comment)
tk.Session().GetSessionVars().PlanID = 0
tk.Session().GetSessionVars().PlanColumnID = 0
err = plannercore.Preprocess(tk.Session(), stmt, plannercore.WithPreprocessorReturn(&plannercore.PreprocessorReturn{InfoSchema: is}))
ass.Nil(err)
tk.Session().PrepareTSFuture(ctx)
builder, _ := plannercore.NewPlanBuilder().Init(tk.Session(), is, &hint.BlockHintProcessor{})
// extract FD to every OP
p, err := builder.Build(ctx, stmt)
ass.Nil(err)
p, err = plannercore.LogicalOptimizeTest(ctx, builder.GetOptFlag(), p.(plannercore.LogicalPlan))
ass.Nil(err)
ass.Equal(tt.best, plannercore.ToString(p), comment)
// extract FD to every OP
p.(plannercore.LogicalPlan).ExtractFD()
ass.Equal(tt.fd, plannercore.FDToString(p.(plannercore.LogicalPlan)), comment)
}
}
45 changes: 43 additions & 2 deletions planner/functional_dependency/fd_graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,8 +500,49 @@ func (s *FDSet) MakeCartesianProduct(rhs *FDSet) {
}
}

func (s *FDSet) MakeLeftOuter(lhs, filterFDs *FDSet, lCols, rCols, notNullCols FastIntSet) {
// TODO:
// MakeApply maintain the FD relationship between outer and inner table after Apply OP is done.
// Since Apply is implemented by join, it seems the fd can be extracted through its inner join directly.
func (s *FDSet) MakeApply(inner *FDSet) {
}

// MakeOuterJoin generates the records the fdSet of the outer join.
// As we know, the outer join would generate null extended rows compared with inner join.
// So we cannot directly do the same thing with the inner join. This function deals with the special cases of the outer join.
func (s *FDSet) MakeOuterJoin(innerFDs, filterFDs *FDSet, outerCols, innerCols FastIntSet) {
for _, edge := range innerFDs.fdEdges {
// We don't maintain the equiv edges and lax edges currently.
if edge.equiv || !edge.strict {
continue
}
// If the one of the column from the inner child's functional dependency's left side is not null, this FD
// can be remained.
// This is because that the outer join would generate null-extended rows. So if at least one row from the left side
// is not null. We can guarantee that the there's no same part between the original rows and the generated rows.
// So the null extended rows would not break the original functional dependency.
if edge.from.SubsetOf(innerFDs.NotNullCols) {
s.addFunctionalDependency(edge.from, edge.to, edge.strict, edge.equiv)
} else if edge.from.SubsetOf(filterFDs.NotNullCols) {
// If we can make sure the filters of the join would filter out all nulls of this FD's left side
// and this FD is from the join's inner child. This FD can be remained.
// This is because the outer join filters out the null values. The generated null-extended rows would not
// find the same row from the original rows of the inner child. So it won't break the original functional dependency.
s.addFunctionalDependency(edge.from, edge.to, edge.strict, edge.equiv)
}
}
for _, edge := range filterFDs.fdEdges {
// We don't maintain the equiv edges and the lax edges currently.
if edge.equiv || !edge.strict {
continue
}
if edge.from.SubsetOf(innerCols) && edge.to.SubsetOf(innerCols) && edge.from.SubsetOf(filterFDs.NotNullCols) {
// The functional dependency generated from the join filter would be reserved if it meets the following conditions:
// 1. All columns from this functional dependency are the columns from the inner side.
// 2. The join keys can filter out the null values from the left side of the FD.
// This is the same with the above cases. If the join filters can filter out the null values of the FD's left side,
// We won't find a same row between the original rows of the inner side and the generated null-extended rows.
s.addFunctionalDependency(edge.from, edge.to, edge.strict, edge.equiv)
}
}
}

func (s FDSet) AllCols() FastIntSet {
Expand Down
6 changes: 5 additions & 1 deletion planner/functional_dependency/fd_graph_ported_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestFuncDeps_ColsAreKey(t *testing.T) {
loj = *abcde
loj.MakeCartesianProduct(mnpq)
loj.AddConstants(NewFastIntSet(3))
loj.MakeLeftOuter(abcde, &FDSet{}, preservedCols, nullExtendedCols, NewFastIntSet(1, 10, 11))
loj.MakeOuterJoin(nil, &FDSet{}, preservedCols, nullExtendedCols)
loj.AddEquivalence(NewFastIntSet(1), NewFastIntSet(10))

testcases := []struct {
Expand Down Expand Up @@ -330,3 +330,7 @@ func makeJoinFD(ass *assert.Assertions) *FDSet {
testColsAreLaxKey(ass, join, NewFastIntSet(2, 3, 11), join.AllCols(), false)
return join
}

func TestFuncDeps_OuterJoin(t *testing.T) {

}

0 comments on commit d4ef9af

Please sign in to comment.