Skip to content

Commit

Permalink
util: move disjoint set to util package (#7950)
Browse files Browse the repository at this point in the history
  • Loading branch information
winoros authored and eurekaka committed Oct 18, 2018
1 parent 19f5648 commit bfc12cd
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 35 deletions.
46 changes: 11 additions & 35 deletions expression/constant_propagation.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,41 +20,19 @@ import (
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/disjointset"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)

// MaxPropagateColsCnt is the maximum number of columns that can participate in propagation.
var MaxPropagateColsCnt = 100

// multiEqualSet is a disjoint-set (union-find) structure over the integers
// 0..l-1, where l is the length passed to init.
type multiEqualSet struct {
	// parent[i] is the parent of element i; a root is its own parent.
	parent []int
}

// init (re)builds the set with l elements, each one starting out as the
// root of its own singleton set.
func (m *multiEqualSet) init(l int) {
	parents := make([]int, l)
	for idx := 0; idx < l; idx++ {
		parents[idx] = idx
	}
	m.parent = parents
}

// addRelation merges the set containing a with the set containing b by
// attaching the root of a's set underneath the root of b's set.
func (m *multiEqualSet) addRelation(a int, b int) {
	rootA := m.findRoot(a)
	rootB := m.findRoot(b)
	m.parent[rootA] = rootB
}

// findRoot returns the representative element of the set that a belongs to.
// Every element visited on the way is re-pointed directly at that
// representative, so later lookups along the same path are shorter.
func (m *multiEqualSet) findRoot(a int) int {
	// First pass: walk up until we reach the element that is its own parent.
	root := a
	for m.parent[root] != root {
		root = m.parent[root]
	}
	// Second pass: point every element on the walked path at the root.
	for cur := a; cur != root; {
		next := m.parent[cur]
		m.parent[cur] = root
		cur = next
	}
	return root
}

type basePropConstSolver struct {
colMapper map[int64]int // colMapper maps column to its index
eqList []*Constant // if eqList[i] != nil, it means col_i = eqList[i]
unionSet *multiEqualSet // unionSet stores the relations like col_i = col_j
columns []*Column // columns stores all columns appearing in the conditions
colMapper map[int64]int // colMapper maps column to its index
eqList []*Constant // if eqList[i] != nil, it means col_i = eqList[i]
unionSet *disjointset.IntSet // unionSet stores the relations like col_i = col_j
columns []*Column // columns stores all columns appearing in the conditions
ctx sessionctx.Context
}

Expand Down Expand Up @@ -208,16 +186,15 @@ func (s *propConstSolver) propagateConstantEQ() {
// We maintain a unionSet that represents the equivalence relation between columns.
func (s *propConstSolver) propagateColumnEQ() {
visited := make([]bool, len(s.conditions))
s.unionSet = &multiEqualSet{}
s.unionSet.init(len(s.columns))
s.unionSet = disjointset.NewIntSet(len(s.columns))
for i := range s.conditions {
if fun, ok := s.conditions[i].(*ScalarFunction); ok && fun.FuncName.L == ast.EQ {
lCol, lOk := fun.GetArgs()[0].(*Column)
rCol, rOk := fun.GetArgs()[1].(*Column)
if lOk && rOk {
lID := s.getColID(lCol)
rID := s.getColID(rCol)
s.unionSet.addRelation(lID, rID)
s.unionSet.Union(lID, rID)
visited[i] = true
}
}
Expand All @@ -227,7 +204,7 @@ func (s *propConstSolver) propagateColumnEQ() {
for i, coli := range s.columns {
for j := i + 1; j < len(s.columns); j++ {
// unionSet doesn't have iterate(), so we use a two-layer loop to enumerate the col_i = col_j relations
if s.unionSet.findRoot(i) != s.unionSet.findRoot(j) {
if s.unionSet.FindRoot(i) != s.unionSet.FindRoot(j) {
continue
}
colj := s.columns[j]
Expand Down Expand Up @@ -489,8 +466,7 @@ func (s *propOuterJoinConstSolver) deriveConds(outerCol, innerCol *Column, schem
// Derived new expressions must be appended into join condition, not filter condition.
func (s *propOuterJoinConstSolver) propagateColumnEQ() {
visited := make([]bool, len(s.joinConds)+len(s.filterConds))
s.unionSet = &multiEqualSet{}
s.unionSet.init(len(s.columns))
s.unionSet = disjointset.NewIntSet(len(s.columns))
var outerCol, innerCol *Column
// Only consider column equal condition in joinConds.
// If we have column equal in filter condition, the outer join should have been simplified already.
Expand All @@ -499,7 +475,7 @@ func (s *propOuterJoinConstSolver) propagateColumnEQ() {
if outerCol != nil {
outerID := s.getColID(outerCol)
innerID := s.getColID(innerCol)
s.unionSet.addRelation(outerID, innerID)
s.unionSet.Union(outerID, innerID)
visited[i] = true
}
}
Expand All @@ -508,7 +484,7 @@ func (s *propOuterJoinConstSolver) propagateColumnEQ() {
for i, coli := range s.columns {
for j := i + 1; j < len(s.columns); j++ {
// unionSet doesn't have iterate(), so we use a two-layer loop to enumerate the col_i = col_j relations.
if s.unionSet.findRoot(i) != s.unionSet.findRoot(j) {
if s.unionSet.FindRoot(i) != s.unionSet.FindRoot(j) {
continue
}
colj := s.columns[j]
Expand Down
42 changes: 42 additions & 0 deletions util/disjointset/int_set.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package disjointset

// IntSet is a disjoint-set (union-find) structure over the integers
// 0..size-1, where size is the value given to NewIntSet.
type IntSet struct {
	// parent[i] is the parent of element i; a root is its own parent.
	parent []int
}

// NewIntSet returns a new int disjoint set holding `size` elements, each of
// which starts out as the root of its own singleton set.
func NewIntSet(size int) *IntSet {
	set := &IntSet{parent: make([]int, size)}
	for idx := 0; idx < size; idx++ {
		set.parent[idx] = idx
	}
	return set
}

// Union merges the set containing `a` with the set containing `b` by
// attaching the root of a's set underneath the root of b's set.
func (m *IntSet) Union(a int, b int) {
	rootA := m.FindRoot(a)
	rootB := m.FindRoot(b)
	m.parent[rootA] = rootB
}

// FindRoot finds the representative element of the set that `a` belongs to.
// Every element visited on the way is re-pointed directly at that
// representative, so later lookups along the same path are shorter.
func (m *IntSet) FindRoot(a int) int {
	// First pass: walk up until we reach the element that is its own parent.
	root := a
	for m.parent[root] != root {
		root = m.parent[root]
	}
	// Second pass: point every element on the walked path at the root.
	for cur := a; cur != root; {
		next := m.parent[cur]
		m.parent[cur] = root
		cur = next
	}
	return root
}
52 changes: 52 additions & 0 deletions util/disjointset/int_set_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package disjointset

import (
"testing"

. "github.com/pingcap/check"
)

// testDisjointSetSuite groups the disjoint set tests; it carries no state.
type testDisjointSetSuite struct{}

// Register the suite with the check package (via Suite) at package
// initialization time.
var _ = Suite(&testDisjointSetSuite{})

// TestT is the `go test` entry point: it sets CustomVerboseFlag and then
// hands the *testing.T over to TestingT.
// NOTE(review): presumably TestingT runs every suite registered through
// Suite, as in gocheck — confirm against the check package.
func TestT(t *testing.T) {
	CustomVerboseFlag = true
	TestingT(t)
}

// TestIntDisjointSet builds a 10-element IntSet, verifies that every element
// starts as its own parent, applies a fixed series of unions, and then checks
// that elements expected to be merged report the same root.
func (s *testDisjointSetSuite) TestIntDisjointSet(c *C) {
	const size = 10
	set := NewIntSet(size)

	// A fresh set must consist of singletons only.
	c.Assert(len(set.parent), Equals, size)
	for idx, parent := range set.parent {
		c.Assert(parent, Equals, idx)
	}

	// Merge into three groups: {0,1,3,5}, {2,4,6,9} and {7,8}.
	unions := [][2]int{{0, 1}, {1, 3}, {4, 2}, {2, 6}, {3, 5}, {7, 8}, {9, 6}}
	for _, pair := range unions {
		set.Union(pair[0], pair[1])
	}

	// Each pair below must resolve to the same representative.
	sameSet := [][2]int{{0, 1}, {3, 1}, {5, 1}, {2, 4}, {6, 4}, {9, 2}, {7, 8}}
	for _, pair := range sameSet {
		c.Assert(set.FindRoot(pair[0]), Equals, set.FindRoot(pair[1]))
	}
}

0 comments on commit bfc12cd

Please sign in to comment.