diff --git a/pkg/expression/constant_propagation.go b/pkg/expression/constant_propagation.go index 38c4cf3790a63..41ad3637308ed 100644 --- a/pkg/expression/constant_propagation.go +++ b/pkg/expression/constant_propagation.go @@ -32,10 +32,10 @@ var MaxPropagateColsCnt = 100 // nolint:structcheck type basePropConstSolver struct { - colMapper map[int64]int // colMapper maps column to its index - eqList []*Constant // if eqList[i] != nil, it means col_i = eqList[i] - unionSet *disjointset.IntSet // unionSet stores the relations like col_i = col_j - columns []*Column // columns stores all columns appearing in the conditions + colMapper map[int64]int // colMapper maps column to its index + eqList []*Constant // if eqList[i] != nil, it means col_i = eqList[i] + unionSet *disjointset.SimpleIntSet // unionSet stores the relations like col_i = col_j + columns []*Column // columns stores all columns appearing in the conditions ctx exprctx.ExprContext } diff --git a/pkg/util/disjointset/BUILD.bazel b/pkg/util/disjointset/BUILD.bazel index 941410ed9d54b..8578cbc54206b 100644 --- a/pkg/util/disjointset/BUILD.bazel +++ b/pkg/util/disjointset/BUILD.bazel @@ -2,7 +2,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "disjointset", - srcs = ["int_set.go"], + srcs = [ + "int_set.go", + "set.go", + ], importpath = "github.com/pingcap/tidb/pkg/util/disjointset", visibility = ["//visibility:public"], ) @@ -13,6 +16,7 @@ go_test( srcs = [ "int_set_test.go", "main_test.go", + "set_test.go", ], embed = [":disjointset"], flaky = True, diff --git a/pkg/util/disjointset/int_set.go b/pkg/util/disjointset/int_set.go index 05846e3840850..a53b7e6d0a44a 100644 --- a/pkg/util/disjointset/int_set.go +++ b/pkg/util/disjointset/int_set.go @@ -14,27 +14,29 @@ package disjointset -// IntSet is the int disjoint set. -type IntSet struct { +// SimpleIntSet is the int disjoint set. +// It's not designed for sparse case. You should use it when the elements are continuous. +// Time complexity: the union operation is inverse ackermann function, which is very close to O(1). +type SimpleIntSet struct { parent []int } // NewIntSet returns a new int disjoint set. -func NewIntSet(size int) *IntSet { +func NewIntSet(size int) *SimpleIntSet { p := make([]int, size) for i := range p { p[i] = i } - return &IntSet{parent: p} + return &SimpleIntSet{parent: p} } // Union unions two sets in int disjoint set. -func (m *IntSet) Union(a int, b int) { +func (m *SimpleIntSet) Union(a int, b int) { m.parent[m.FindRoot(a)] = m.FindRoot(b) } // FindRoot finds the representative element of the set that `a` belongs to. -func (m *IntSet) FindRoot(a int) int { +func (m *SimpleIntSet) FindRoot(a int) int { if a == m.parent[a] { return a } diff --git a/pkg/util/disjointset/set.go b/pkg/util/disjointset/set.go new file mode 100644 index 0000000000000..08b63aec5dd3c --- /dev/null +++ b/pkg/util/disjointset/set.go @@ -0,0 +1,67 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package disjointset + +// Set is the universal implementation of a disjoint set. +// It's designed for sparse cases or non-integer types. +// If you are dealing with continuous integers, you should use SimpleIntSet to avoid the cost of a hash map. +// We hash the original value to an integer index and then apply the core disjoint set algorithm. +// Time complexity: the union operation has an inverse Ackermann function time complexity, which is very close to O(1). +type Set[T comparable] struct { + parent []int + val2Idx map[T]int + tailIdx int +} + +// NewSet creates a disjoint set. +func NewSet[T comparable](size int) *Set[T] { + return &Set[T]{ + parent: make([]int, 0, size), + val2Idx: make(map[T]int, size), + tailIdx: 0, + } +} +func (s *Set[T]) findRootOriginalVal(a T) int { + idx, ok := s.val2Idx[a] + if !ok { + s.parent = append(s.parent, s.tailIdx) + s.val2Idx[a] = s.tailIdx + s.tailIdx++ + return s.tailIdx - 1 + } + return s.findRoot(idx) +} + +// findRoot is an internal implementation. Call it inside findRootOriginalVal. +func (s *Set[T]) findRoot(a int) int { + if s.parent[a] != a { + s.parent[a] = s.findRoot(s.parent[a]) + } + return s.parent[a] +} + +// InSameGroup checks whether a and b are in the same group. +func (s *Set[T]) InSameGroup(a, b T) bool { + return s.findRootOriginalVal(a) == s.findRootOriginalVal(b) +} + +// Union joins two sets in the disjoint set. +func (s *Set[T]) Union(a, b T) { + rootA := s.findRootOriginalVal(a) + rootB := s.findRootOriginalVal(b) + if rootA != rootB { + s.parent[rootA] = rootB + } +} diff --git a/pkg/util/disjointset/set_test.go b/pkg/util/disjointset/set_test.go new file mode 100644 index 0000000000000..ae7cada175845 --- /dev/null +++ b/pkg/util/disjointset/set_test.go @@ -0,0 +1,49 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package disjointset + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDisjointSet(t *testing.T) { + set := NewSet[string](10) + assert.False(t, set.InSameGroup("a", "b")) + assert.Len(t, set.parent, 2) + set.Union("a", "b") + assert.True(t, set.InSameGroup("a", "b")) + assert.False(t, set.InSameGroup("a", "c")) + assert.Len(t, set.parent, 3) + assert.False(t, set.InSameGroup("b", "c")) + assert.Len(t, set.parent, 3) + set.Union("b", "c") + assert.True(t, set.InSameGroup("a", "c")) + assert.True(t, set.InSameGroup("b", "c")) + set.Union("d", "e") + set.Union("e", "f") + set.Union("f", "g") + assert.Len(t, set.parent, 7) + assert.False(t, set.InSameGroup("a", "d")) + assert.True(t, set.InSameGroup("d", "g")) + assert.False(t, set.InSameGroup("c", "g")) + set.Union("a", "g") + assert.True(t, set.InSameGroup("a", "d")) + assert.True(t, set.InSameGroup("b", "g")) + assert.True(t, set.InSameGroup("c", "f")) + assert.True(t, set.InSameGroup("a", "e")) + assert.True(t, set.InSameGroup("b", "c")) +}