From 0116b993029acef554e90b42eb0b04fdb8bd1a66 Mon Sep 17 00:00:00 2001 From: Arenatlx <314806019@qq.com> Date: Fri, 13 Sep 2024 18:24:10 +0800 Subject: [PATCH] planner: add group/memo/groupExpression. (#55825) ref pingcap/tidb#51664 --- pkg/planner/cascades/memo/BUILD.bazel | 39 ++++++ pkg/planner/cascades/memo/group.go | 115 ++++++++++++++++++ .../cascades/memo/group_and_expr_test.go | 75 ++++++++++++ pkg/planner/cascades/memo/group_expr.go | 106 ++++++++++++++++ .../cascades/memo/group_id_generator.go | 30 +++++ .../cascades/memo/group_id_generator_test.go | 48 ++++++++ pkg/planner/cascades/memo/memo.go | 107 ++++++++++++++++ .../operator/logicalop/logical_projection.go | 42 +++---- 8 files changed, 537 insertions(+), 25 deletions(-) create mode 100644 pkg/planner/cascades/memo/BUILD.bazel create mode 100644 pkg/planner/cascades/memo/group.go create mode 100644 pkg/planner/cascades/memo/group_and_expr_test.go create mode 100644 pkg/planner/cascades/memo/group_expr.go create mode 100644 pkg/planner/cascades/memo/group_id_generator.go create mode 100644 pkg/planner/cascades/memo/group_id_generator_test.go create mode 100644 pkg/planner/cascades/memo/memo.go diff --git a/pkg/planner/cascades/memo/BUILD.bazel b/pkg/planner/cascades/memo/BUILD.bazel new file mode 100644 index 0000000000000..88670682bca8a --- /dev/null +++ b/pkg/planner/cascades/memo/BUILD.bazel @@ -0,0 +1,39 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "memo", + srcs = [ + "group.go", + "group_expr.go", + "group_id_generator.go", + "memo.go", + ], + importpath = "github.com/pingcap/tidb/pkg/planner/cascades/memo", + visibility = ["//visibility:public"], + deps = [ + "//pkg/planner/cascades/base", + "//pkg/planner/core/base", + "//pkg/planner/pattern", + "//pkg/planner/property", + "//pkg/sessionctx", + "//pkg/util/intest", + ], +) + +go_test( + name = "memo_test", + timeout = "short", + srcs = [ + "group_and_expr_test.go", + "group_id_generator_test.go", + ], + embed = [":memo"], + flaky = True, + shard_count = 3, + deps = [ + "//pkg/expression", + "//pkg/planner/cascades/base", + "//pkg/planner/core/operator/logicalop", + "@com_github_stretchr_testify//require", + ], +) diff --git a/pkg/planner/cascades/memo/group.go b/pkg/planner/cascades/memo/group.go new file mode 100644 index 0000000000000..c5a387a1fc861 --- /dev/null +++ b/pkg/planner/cascades/memo/group.go @@ -0,0 +1,115 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memo + +import ( + "container/list" + + "github.com/pingcap/tidb/pkg/planner/cascades/base" + "github.com/pingcap/tidb/pkg/planner/pattern" + "github.com/pingcap/tidb/pkg/planner/property" +) + +var _ base.HashEquals = &Group{} + +// Group is basic infra to store all the logically equivalent expressions +// for one logical operator in current context. +type Group struct { + // groupID indicates the uniqueness of this group, also for encoding. + groupID GroupID + + // logicalExpressions indicates the logical equiv classes for this group. + logicalExpressions *list.List + + // operand2FirstExpr is used to locate to the first same type logical expression + // in list above instead of traverse them all. + operand2FirstExpr map[pattern.Operand]*list.Element + + // hash2GroupExpr is used to de-duplication in the list. + hash2GroupExpr map[uint64]*list.Element + + // logicalProp indicates the logical property. + logicalProp *property.LogicalProperty + + // explored indicates whether this group has been explored. + explored bool +} + +// ******************************************* start of HashEqual methods ******************************************* + +// Hash64 implements the HashEquals.<0th> interface. +func (g *Group) Hash64(h base.Hasher) { + h.HashUint64(uint64(g.groupID)) +} + +// Equals implements the HashEquals.<1st> interface. +func (g *Group) Equals(other any) bool { + if other == nil { + return false + } + switch x := other.(type) { + case *Group: + return g.groupID == x.groupID + case Group: + return g.groupID == x.groupID + default: + return false + } +} + +// ******************************************* end of HashEqual methods ******************************************* + +// Exists checks whether a Group expression existed in a Group. +func (g *Group) Exists(hash64u uint64) bool { + _, ok := g.hash2GroupExpr[hash64u] + return ok +} + +// Insert adds a GroupExpression to the Group. +func (g *Group) Insert(e *GroupExpression) bool { + if e == nil { + return false + } + // GroupExpressions hash should be initialized within Init(xxx) method. + hash64 := e.Sum64() + if g.Exists(hash64) { + return false + } + operand := pattern.GetOperand(e.logicalPlan) + var newEquiv *list.Element + mark, ok := g.operand2FirstExpr[operand] + if ok { + // cluster same operands together. + newEquiv = g.logicalExpressions.InsertAfter(e, mark) + } else { + // otherwise, put it at the end. + newEquiv = g.logicalExpressions.PushBack(e) + g.operand2FirstExpr[operand] = newEquiv + } + g.hash2GroupExpr[hash64] = newEquiv + e.group = g + return true +} + +// NewGroup creates a new Group with given logical prop. +func NewGroup(prop *property.LogicalProperty) *Group { + g := &Group{ + logicalExpressions: list.New(), + hash2GroupExpr: make(map[uint64]*list.Element), + operand2FirstExpr: make(map[pattern.Operand]*list.Element), + logicalProp: prop, + } + return g +} diff --git a/pkg/planner/cascades/memo/group_and_expr_test.go b/pkg/planner/cascades/memo/group_and_expr_test.go new file mode 100644 index 0000000000000..4fd379964cbda --- /dev/null +++ b/pkg/planner/cascades/memo/group_and_expr_test.go @@ -0,0 +1,75 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memo + +import ( + "testing" + + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/planner/cascades/base" + "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" + "github.com/stretchr/testify/require" +) + +func TestGroupHashEquals(t *testing.T) { + hasher1 := base.NewHashEqualer() + hasher2 := base.NewHashEqualer() + a := Group{groupID: 1} + b := Group{groupID: 1} + a.Hash64(hasher1) + b.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + require.True(t, a.Equals(b)) + require.True(t, a.Equals(&b)) + + // change the id. + b.groupID = 2 + hasher2.Reset() + b.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, a.Equals(b)) + require.False(t, a.Equals(&b)) +} + +func TestGroupExpressionHashEquals(t *testing.T) { + hasher1 := base.NewHashEqualer() + hasher2 := base.NewHashEqualer() + child1 := &Group{groupID: 1} + child2 := &Group{groupID: 2} + a := GroupExpression{ + group: &Group{groupID: 3}, + inputs: []*Group{child1, child2}, + logicalPlan: &logicalop.LogicalProjection{Exprs: []expression.Expression{expression.NewOne()}}, + } + b := GroupExpression{ + // root group should change the hash. + group: &Group{groupID: 4}, + inputs: []*Group{child1, child2}, + logicalPlan: &logicalop.LogicalProjection{Exprs: []expression.Expression{expression.NewOne()}}, + } + a.Hash64(hasher1) + b.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + require.True(t, a.Equals(b)) + require.True(t, a.Equals(&b)) + + // change the children order, like join commutative. + b.inputs = []*Group{child2, child1} + hasher2.Reset() + b.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, a.Equals(b)) + require.False(t, a.Equals(&b)) +} diff --git a/pkg/planner/cascades/memo/group_expr.go b/pkg/planner/cascades/memo/group_expr.go new file mode 100644 index 0000000000000..97ca867ad1fef --- /dev/null +++ b/pkg/planner/cascades/memo/group_expr.go @@ -0,0 +1,106 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memo + +import ( + base2 "github.com/pingcap/tidb/pkg/planner/cascades/base" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/pattern" + "github.com/pingcap/tidb/pkg/util/intest" +) + +// GroupExpression is a single expression from the equivalent list classes inside a group. +// it is a node in the expression tree, while it takes groups as inputs. This kind of loose +// coupling between Group and GroupExpression is the key to the success of the memory compact +// of representing a forest. +type GroupExpression struct { + // group is the Group that this GroupExpression belongs to. + group *Group + + // inputs stores the Groups that this GroupExpression based on. + inputs []*Group + + // logicalPlan is internal logical expression stands for this groupExpr. + logicalPlan base.LogicalPlan + + // hash64 is the unique fingerprint of the GroupExpression. + hash64 uint64 +} + +// Sum64 returns the cached hash64 of the GroupExpression. +func (e *GroupExpression) Sum64() uint64 { + intest.Assert(e.hash64 != 0, "hash64 should not be 0") + return e.hash64 +} + +// Hash64 implements the Hash64 interface. +func (e *GroupExpression) Hash64(h base2.Hasher) { + // logical plan hash. + e.logicalPlan.Hash64(h) + // children group hash. + for _, child := range e.inputs { + child.Hash64(h) + } +} + +// Equals implements the Equals interface. +func (e *GroupExpression) Equals(other any) bool { + if other == nil { + return false + } + var e2 *GroupExpression + switch x := other.(type) { + case *GroupExpression: + e2 = x + case GroupExpression: + e2 = &x + default: + return false + } + if len(e.inputs) != len(e2.inputs) { + return false + } + if pattern.GetOperand(e.logicalPlan) != pattern.GetOperand(e2.logicalPlan) { + return false + } + // current logical operator meta cmp, logical plan don't care logicalPlan's children. + // when we convert logicalPlan to GroupExpression, we will set children to nil. + if !e.logicalPlan.Equals(e2.logicalPlan) { + return false + } + // if one of the children is different, then the two GroupExpressions are different. + for i, one := range e.inputs { + if !one.Equals(e2.inputs[i]) { + return false + } + } + return true +} + +// NewGroupExpression creates a new GroupExpression with the given logical plan and children. +func NewGroupExpression(lp base.LogicalPlan, inputs []*Group) *GroupExpression { + return &GroupExpression{ + group: nil, + inputs: inputs, + logicalPlan: lp, + hash64: 0, + } +} + +// Init initializes the GroupExpression with the given group and hasher. +func (e *GroupExpression) Init(h base2.Hasher) { + e.Hash64(h) + e.hash64 = h.Sum64() +} diff --git a/pkg/planner/cascades/memo/group_id_generator.go b/pkg/planner/cascades/memo/group_id_generator.go new file mode 100644 index 0000000000000..32aa90bfefc1c --- /dev/null +++ b/pkg/planner/cascades/memo/group_id_generator.go @@ -0,0 +1,30 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memo + +// GroupID is the unique id for a group. +type GroupID uint64 + +// GroupIDGenerator is used to generate group id. +type GroupIDGenerator struct { + id uint64 +} + +// NextGroupID generates the next group id. +// It is not thread-safe, since memo optimizing is also in one thread. +func (gi *GroupIDGenerator) NextGroupID() GroupID { + gi.id++ + return GroupID(gi.id) +} diff --git a/pkg/planner/cascades/memo/group_id_generator_test.go b/pkg/planner/cascades/memo/group_id_generator_test.go new file mode 100644 index 0000000000000..fb2a644ca3a8b --- /dev/null +++ b/pkg/planner/cascades/memo/group_id_generator_test.go @@ -0,0 +1,48 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memo + +import ( + "math" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGroupIDGenerator_NextGroupID(t *testing.T) { + g := GroupIDGenerator{} + got := g.NextGroupID() + require.Equal(t, GroupID(1), got) + got = g.NextGroupID() + require.Equal(t, GroupID(2), got) + got = g.NextGroupID() + require.Equal(t, GroupID(3), got) + + // adjust the id. + g.id = 100 + got = g.NextGroupID() + require.Equal(t, GroupID(101), got) + got = g.NextGroupID() + require.Equal(t, GroupID(102), got) + got = g.NextGroupID() + require.Equal(t, GroupID(103), got) + + g.id = math.MaxUint64 + got = g.NextGroupID() + // rewire to 0. + require.Equal(t, GroupID(0), got) + got = g.NextGroupID() + require.Equal(t, GroupID(1), got) +} diff --git a/pkg/planner/cascades/memo/memo.go b/pkg/planner/cascades/memo/memo.go new file mode 100644 index 0000000000000..89758ee0e63e2 --- /dev/null +++ b/pkg/planner/cascades/memo/memo.go @@ -0,0 +1,107 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memo + +import ( + "container/list" + "sync" + + base2 "github.com/pingcap/tidb/pkg/planner/cascades/base" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util/intest" +) + +// Memo is the main structure of the memo package. +type Memo struct { + // ctx is the context of the memo. + sCtx sessionctx.Context + + // groupIDGen is the incremental group id for internal usage. + groupIDGen GroupIDGenerator + + // rootGroup is the root group of the memo. + rootGroup *Group + + // groups is the list of all groups in the memo. + groups *list.List + + // groupID2Group is the map from group id to group. + groupID2Group map[GroupID]*list.Element + + // hash2GroupExpr is the map from hash to group expression. + hash2GroupExpr map[uint64]*list.Element + + // hasherPool is the pool of hasher. + hasherPool *sync.Pool +} + +// NewMemo creates a new memo. +func NewMemo(ctx sessionctx.Context) *Memo { + return &Memo{ + sCtx: ctx, + groupIDGen: GroupIDGenerator{id: 0}, + groups: list.New(), + groupID2Group: make(map[GroupID]*list.Element), + hasherPool: &sync.Pool{New: func() any { return base2.NewHashEqualer() }}, + } +} + +// CopyIn copies a logical plan into the memo with format as GroupExpression. +func (m *Memo) CopyIn(target *Group, lp base.LogicalPlan) (*GroupExpression, bool) { + // Group the children first. + childGroups := make([]*Group, 0, len(lp.Children())) + for _, child := range lp.Children() { + // todo: child.getGroupExpression.GetGroup directly + groupExpr, ok := m.CopyIn(nil, child) + group := groupExpr.group + intest.Assert(ok) + intest.Assert(group != nil) + intest.Assert(group != target) + childGroups = append(childGroups, group) + } + + hasher := m.hasherPool.Get().(base2.Hasher) + hasher.Reset() + groupExpr := NewGroupExpression(lp, childGroups) + groupExpr.Init(hasher) + m.hasherPool.Put(hasher) + + ok := m.insertGroupExpression(groupExpr, target) + // todo: new group need to derive the logical property. + return groupExpr, ok +} + +// @bool indicates whether the groupExpr is inserted to a new group. +func (m *Memo) insertGroupExpression(groupExpr *GroupExpression, target *Group) bool { + // for group merge, here groupExpr is the new groupExpr with undetermined belonged group. + // we need to use groupExpr hash to find whether there is same groupExpr existed before. + // if existed and the existed groupExpr.Group is not same with target, we should merge them up. + // todo: merge group + if target == nil { + target = m.NewGroup() + m.groups.PushBack(target) + m.groupID2Group[target.groupID] = m.groups.Back() + } + target.Insert(groupExpr) + return true +} + +// NewGroup creates a new group. +func (m *Memo) NewGroup() *Group { + group := NewGroup(nil) + group.groupID = m.groupIDGen.NextGroupID() + return group +} diff --git a/pkg/planner/core/operator/logicalop/logical_projection.go b/pkg/planner/core/operator/logicalop/logical_projection.go index dc3d9a6b990d0..15bc95a0f9448 100644 --- a/pkg/planner/core/operator/logicalop/logical_projection.go +++ b/pkg/planner/core/operator/logicalop/logical_projection.go @@ -60,18 +60,10 @@ func (p LogicalProjection) Init(ctx base.PlanContext, qbOffset int) *LogicalProj // Hash64 implements the base.Hash64.<0th> interface. func (p *LogicalProjection) Hash64(h base2.Hasher) { - // todo: LogicalSchemaProducer should implement HashEquals interface, otherwise, its self elements - // like schema and names are lost. - p.LogicalSchemaProducer.Hash64(h) - // todo: if we change the logicalProjection's Expr definition as:Exprs []memo.ScalarOperator[any], - // we should use like below: - // for _, one := range p.Exprs { - // one.Hash64(one) - // } - // otherwise, we would use the belowing code. - //for _, one := range p.Exprs { - // one.Hash64(h) - //} + h.HashInt(len(p.Exprs)) + for _, one := range p.Exprs { + one.Hash64(h) + } h.HashBool(p.CalculateNoDelay) h.HashBool(p.Proj4Expand) } @@ -81,24 +73,24 @@ func (p *LogicalProjection) Equals(other any) bool { if other == nil { return false } - proj, ok := other.(*LogicalProjection) - if !ok { + var p2 *LogicalProjection + switch x := other.(type) { + case *LogicalProjection: + p2 = x + case LogicalProjection: + p2 = &x + default: return false } - // todo: LogicalSchemaProducer should implement HashEquals interface, otherwise, its self elements - // like schema and names are lost. - if !p.LogicalSchemaProducer.Equals(&proj.LogicalSchemaProducer) { + if len(p.Exprs) != len(p2.Exprs) { return false } - //for i, one := range p.Exprs { - // if !one.(memo.ScalarOperator[any]).Equals(other.Exprs[i]) { - // return false - // } - //} - if p.CalculateNoDelay != proj.CalculateNoDelay { - return false + for i, one := range p.Exprs { + if !one.Equals(p2.Exprs[i]) { + return false + } } - return p.Proj4Expand == proj.Proj4Expand + return p.CalculateNoDelay == p2.CalculateNoDelay && p.Proj4Expand == p2.Proj4Expand } // *************************** start implementation of Plan interface **********************************