Skip to content

Commit

Permalink
disjoinset: add generic impl (pingcap#54917)
Browse files Browse the repository at this point in the history
  • Loading branch information
winoros authored and hawkingrei committed Aug 1, 2024
1 parent 7dc9d3d commit 960b0c8
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 11 deletions.
8 changes: 4 additions & 4 deletions pkg/expression/constant_propagation.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ var MaxPropagateColsCnt = 100

// nolint:structcheck
type basePropConstSolver struct {
colMapper map[int64]int // colMapper maps column to its index
eqList []*Constant // if eqList[i] != nil, it means col_i = eqList[i]
unionSet *disjointset.IntSet // unionSet stores the relations like col_i = col_j
columns []*Column // columns stores all columns appearing in the conditions
colMapper map[int64]int // colMapper maps column to its index
eqList []*Constant // if eqList[i] != nil, it means col_i = eqList[i]
unionSet *disjointset.SimpleIntSet // unionSet stores the relations like col_i = col_j
columns []*Column // columns stores all columns appearing in the conditions
ctx exprctx.ExprContext
}

Expand Down
6 changes: 5 additions & 1 deletion pkg/util/disjointset/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

go_library(
name = "disjointset",
srcs = ["int_set.go"],
srcs = [
"int_set.go",
"set.go",
],
importpath = "github.com/pingcap/tidb/pkg/util/disjointset",
visibility = ["//visibility:public"],
)
Expand All @@ -13,6 +16,7 @@ go_test(
srcs = [
"int_set_test.go",
"main_test.go",
"set_test.go",
],
embed = [":disjointset"],
flaky = True,
Expand Down
14 changes: 8 additions & 6 deletions pkg/util/disjointset/int_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,29 @@

package disjointset

// IntSet is the int disjoint set.
type IntSet struct {
// SimpleIntSet is the int disjoint set.
// It's not designed for sparse case. You should use it when the elements are continuous.
// Time complexity: the union operation is inverse ackermann function, which is very close to O(1).
type SimpleIntSet struct {
parent []int
}

// NewIntSet returns a new int disjoint set.
func NewIntSet(size int) *IntSet {
func NewIntSet(size int) *SimpleIntSet {
p := make([]int, size)
for i := range p {
p[i] = i
}
return &IntSet{parent: p}
return &SimpleIntSet{parent: p}
}

// Union unions two sets in int disjoint set.
func (m *IntSet) Union(a int, b int) {
func (m *SimpleIntSet) Union(a int, b int) {
m.parent[m.FindRoot(a)] = m.FindRoot(b)
}

// FindRoot finds the representative element of the set that `a` belongs to.
func (m *IntSet) FindRoot(a int) int {
func (m *SimpleIntSet) FindRoot(a int) int {
if a == m.parent[a] {
return a
}
Expand Down
67 changes: 67 additions & 0 deletions pkg/util/disjointset/set.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package disjointset

// Set is the universal implementation of a disjoint set.
// It's designed for sparse cases or non-integer types.
// If you are dealing with continuous integers, you should use SimpleIntSet to avoid the cost of a hash map.
// We hash the original value to an integer index and then apply the core disjoint set algorithm.
// Time complexity: the union operation has an inverse Ackermann function time complexity, which is very close to O(1).
type Set[T comparable] struct {
parent []int
val2Idx map[T]int
tailIdx int
}

// NewSet creates a disjoint set.
func NewSet[T comparable](size int) *Set[T] {
return &Set[T]{
parent: make([]int, 0, size),
val2Idx: make(map[T]int, size),
tailIdx: 0,
}
}
func (s *Set[T]) findRootOriginalVal(a T) int {
idx, ok := s.val2Idx[a]
if !ok {
s.parent = append(s.parent, s.tailIdx)
s.val2Idx[a] = s.tailIdx
s.tailIdx++
return s.tailIdx - 1
}
return s.findRoot(idx)
}

// findRoot is an internal implementation. Call it inside findRootOriginalVal.
func (s *Set[T]) findRoot(a int) int {
if s.parent[a] != a {
s.parent[a] = s.findRoot(s.parent[a])
}
return s.parent[a]
}

// InSameGroup checks whether a and b are in the same group.
func (s *Set[T]) InSameGroup(a, b T) bool {
return s.findRootOriginalVal(a) == s.findRootOriginalVal(b)
}

// Union joins two sets in the disjoint set.
func (s *Set[T]) Union(a, b T) {
rootA := s.findRootOriginalVal(a)
rootB := s.findRootOriginalVal(b)
if rootA != rootB {
s.parent[rootA] = rootB
}
}
49 changes: 49 additions & 0 deletions pkg/util/disjointset/set_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package disjointset

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestDisjointSet(t *testing.T) {
set := NewSet[string](10)
assert.False(t, set.InSameGroup("a", "b"))
assert.Len(t, set.parent, 2)
set.Union("a", "b")
assert.True(t, set.InSameGroup("a", "b"))
assert.False(t, set.InSameGroup("a", "c"))
assert.Len(t, set.parent, 3)
assert.False(t, set.InSameGroup("b", "c"))
assert.Len(t, set.parent, 3)
set.Union("b", "c")
assert.True(t, set.InSameGroup("a", "c"))
assert.True(t, set.InSameGroup("b", "c"))
set.Union("d", "e")
set.Union("e", "f")
set.Union("f", "g")
assert.Len(t, set.parent, 7)
assert.False(t, set.InSameGroup("a", "d"))
assert.True(t, set.InSameGroup("d", "g"))
assert.False(t, set.InSameGroup("c", "g"))
set.Union("a", "g")
assert.True(t, set.InSameGroup("a", "d"))
assert.True(t, set.InSameGroup("b", "g"))
assert.True(t, set.InSameGroup("c", "f"))
assert.True(t, set.InSameGroup("a", "e"))
assert.True(t, set.InSameGroup("b", "c"))
}

0 comments on commit 960b0c8

Please sign in to comment.