Skip to content

Commit

Permalink
planner: introduce hashEquals interface for expression.Expression (#5…
Browse files Browse the repository at this point in the history
  • Loading branch information
AilinKid authored Sep 9, 2024
1 parent b5ec2e3 commit 4ab1765
Show file tree
Hide file tree
Showing 10 changed files with 250 additions and 7 deletions.
36 changes: 31 additions & 5 deletions pkg/expression/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (

var (
_ base.HashEquals = &Column{}
_ base.HashEquals = &CorrelatedColumn{}
)

// CorrelatedColumn stands for a column in a correlated sub query.
Expand Down Expand Up @@ -246,6 +247,31 @@ func (col *CorrelatedColumn) RemapColumn(m map[int64]*Column) (Expression, error
}, nil
}

// Hash64 implements HashEquals.<0th> interface.
func (col *CorrelatedColumn) Hash64(h base.Hasher) {
// correlatedColumn flag here is used to distinguish correlatedColumn and Column.
h.HashByte(correlatedColumn)
col.Column.Hash64(h)
// since col.Datum is filled in the runtime, we can't use it to calculate hash now, correlatedColumn flag + column is enough.
}

// Equals implements HashEquals.<1st> interface.
func (col *CorrelatedColumn) Equals(other any) bool {
if other == nil {
return false
}
var col2 *CorrelatedColumn
switch x := other.(type) {
case CorrelatedColumn:
col2 = &x
case *CorrelatedColumn:
col2 = x
default:
return false
}
return col.Column.Equals(&col2.Column)
}

// Column represents a column.
type Column struct {
RetType *types.FieldType `plan-cache-clone:"shallow"`
Expand Down Expand Up @@ -458,11 +484,11 @@ func (col *Column) Hash64(h base.Hasher) {
h.HashInt64(col.ID)
h.HashInt64(col.UniqueID)
h.HashInt(col.Index)
if col.VirtualExpr != nil {
if col.VirtualExpr == nil {
h.HashByte(base.NilFlag)
} else {
h.HashByte(base.NotNilFlag)
//col.VirtualExpr.Hash64(h)
col.VirtualExpr.Hash64(h)
}
h.HashString(col.OrigName)
h.HashBool(col.IsHidden)
Expand All @@ -488,12 +514,12 @@ func (col *Column) Equals(other any) bool {
}
// when step into here, we could ensure that col1.RetType and col2.RetType are same type.
// and we should ensure col1.RetType and col2.RetType is not nil ourselves.
ftEqual := col.RetType == nil && col2.RetType == nil || col.RetType != nil && col2.RetType != nil && col.RetType.Equal(col2.RetType)
return ftEqual &&
ok := col.RetType == nil && col2.RetType == nil || col.RetType != nil && col2.RetType != nil && col.RetType.Equal(col2.RetType)
ok = ok && (col.VirtualExpr == nil && col2.VirtualExpr == nil || col.VirtualExpr != nil && col2.VirtualExpr != nil && col.VirtualExpr.Equals(col2.VirtualExpr))
return ok &&
col.ID == col2.ID &&
col.UniqueID == col2.UniqueID &&
col.Index == col2.Index &&
//col.VirtualExpr.Equals(col2.VirtualExpr) &&
col.OrigName == col2.OrigName &&
col.IsHidden == col2.IsHidden &&
col.IsPrefix == col2.IsPrefix &&
Expand Down
29 changes: 27 additions & 2 deletions pkg/expression/column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,8 +415,7 @@ func TestColumnHashEquals(t *testing.T) {
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
require.False(t, col1.Equals(col2))

// diff VirtualExpr
// TODO: add HashEquals for VirtualExpr
// diff VirtualExpr see TestColumnHashEuqals4VirtualExpr

// diff OrigName
col2.Index = col1.Index
Expand Down Expand Up @@ -468,3 +467,29 @@ func TestColumnHashEquals(t *testing.T) {
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
require.False(t, col1.Equals(col2))
}

func TestColumnHashEuqals4VirtualExpr(t *testing.T) {
col1 := &Column{UniqueID: 1, VirtualExpr: NewZero()}
col2 := &Column{UniqueID: 1, VirtualExpr: nil}
hasher1 := base.NewHashEqualer()
hasher2 := base.NewHashEqualer()
col1.Hash64(hasher1)
col2.Hash64(hasher2)
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
require.False(t, col1.Equals(col2))

col2.VirtualExpr = NewZero()
hasher2.Reset()
col2.Hash64(hasher2)
require.Equal(t, hasher1.Sum64(), hasher2.Sum64())
require.True(t, col1.Equals(col2))

col1.VirtualExpr = nil
col2.VirtualExpr = nil
hasher1.Reset()
hasher2.Reset()
col1.Hash64(hasher1)
col2.Hash64(hasher2)
require.Equal(t, hasher1.Sum64(), hasher2.Sum64())
require.True(t, col1.Equals(col2))
}
47 changes: 47 additions & 0 deletions pkg/expression/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

perrors "github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/codec"
Expand All @@ -29,6 +30,8 @@ import (
"go.uber.org/zap"
)

var _ base.HashEquals = &Constant{}

// NewOne stands for a number 1.
func NewOne() *Constant {
retT := types.NewFieldType(mysql.TypeTiny)
Expand Down Expand Up @@ -502,6 +505,50 @@ func (c *Constant) CanonicalHashCode() []byte {
return c.getHashCode(true)
}

// Hash64 implements HashEquals.<0th> interface.
func (c *Constant) Hash64(h base.Hasher) {
if c.RetType == nil {
h.HashByte(base.NilFlag)
} else {
h.HashByte(base.NotNilFlag)
c.RetType.Hash64(h)
}
c.collationInfo.Hash64(h)
if c.DeferredExpr != nil {
c.DeferredExpr.Hash64(h)
return
}
if c.ParamMarker != nil {
h.HashByte(parameterFlag)
h.HashInt64(int64(c.ParamMarker.order))
return
}
intest.Assert(c.DeferredExpr == nil && c.ParamMarker == nil)
h.HashByte(constantFlag)
c.Value.Hash64(h)
}

// Equals implements HashEquals.<1st> interface.
func (c *Constant) Equals(other any) bool {
if other == nil {
return false
}
var c2 *Constant
switch x := other.(type) {
case *Constant:
c2 = x
case Constant:
c2 = &x
default:
return false
}
ok := c.RetType == nil && c2.RetType == nil || c.RetType != nil && c2.RetType != nil && c.RetType.Equals(c2.RetType)
ok = ok && c.collationInfo.Equals(c2.collationInfo)
ok = ok && (c.DeferredExpr == nil && c2.DeferredExpr == nil || c.DeferredExpr != nil && c2.DeferredExpr != nil && c.DeferredExpr.Equals(c2.DeferredExpr))
ok = ok && (c.ParamMarker == nil && c2.ParamMarker == nil || c.ParamMarker != nil && c2.ParamMarker != nil && c.ParamMarker.order == c2.ParamMarker.order)
return ok && c.Value.Equals(c2.Value)
}

func (c *Constant) getHashCode(canonical bool) []byte {
if len(c.hashcode) > 0 {
return c.hashcode
Expand Down
28 changes: 28 additions & 0 deletions pkg/expression/constant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
exprctx "github.com/pingcap/tidb/pkg/expression/context"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/mock"
Expand Down Expand Up @@ -545,3 +546,30 @@ func TestSpecificConstant(t *testing.T) {
require.Equal(t, null.RetType.GetFlen(), 1)
require.Equal(t, null.RetType.GetDecimal(), 0)
}

func TestConstantHashEquals(t *testing.T) {
// Test for Hash64 interface
cst1 := &Constant{Value: types.NewIntDatum(2333), RetType: newIntFieldType()}
cst2 := &Constant{Value: types.NewIntDatum(2333), RetType: newIntFieldType()}
hasher1 := base.NewHashEqualer()
hasher2 := base.NewHashEqualer()
cst1.Hash64(hasher1)
cst2.Hash64(hasher2)
require.Equal(t, hasher1.Sum64(), hasher2.Sum64())
require.True(t, cst1.Equals(cst2))

// test cst2 datum changes.
cst2.Value = types.NewIntDatum(2334)
hasher2.Reset()
cst2.Hash64(hasher2)
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
require.False(t, cst1.Equals(cst2))

// test cst2 type changes.
cst2.Value = types.NewIntDatum(2333)
cst2.RetType = newStringFieldType()
hasher2.Reset()
cst2.Hash64(hasher2)
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
require.False(t, cst1.Equals(cst2))
}
3 changes: 3 additions & 0 deletions pkg/expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/opcode"
"github.com/pingcap/tidb/pkg/parser/terror"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/generatedexpr"
Expand All @@ -42,6 +43,7 @@ const (
scalarFunctionFlag byte = 3
parameterFlag byte = 4
ScalarSubQFlag byte = 5
correlatedColumn byte = 6
)

// EvalSimpleAst evaluates a simple ast expression directly.
Expand Down Expand Up @@ -170,6 +172,7 @@ const (
type Expression interface {
VecExpr
CollationInfo
base.HashEquals

Traverse(TraverseAction) Expression

Expand Down
48 changes: 48 additions & 0 deletions pkg/expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/terror"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand All @@ -35,6 +36,8 @@ import (
"github.com/pingcap/tidb/pkg/util/intest"
)

var _ base.HashEquals = &ScalarFunction{}

// ScalarFunction is the function that returns a value.
type ScalarFunction struct {
FuncName model.CIStr
Expand Down Expand Up @@ -673,6 +676,51 @@ func simpleCanonicalizedHashCode(sf *ScalarFunction) {
}
}

// Hash64 implements HashEquals.<0th> interface.
func (sf *ScalarFunction) Hash64(h base.Hasher) {
h.HashByte(scalarFunctionFlag)
h.HashString(sf.FuncName.L)
if sf.RetType == nil {
h.HashByte(base.NilFlag)
} else {
h.HashByte(base.NotNilFlag)
sf.RetType.Hash64(h)
}
// hash the arg length to avoid hash collision.
h.HashInt(len(sf.GetArgs()))
for _, arg := range sf.GetArgs() {
arg.Hash64(h)
}
}

// Equals implements HashEquals.<1th> interface.
func (sf *ScalarFunction) Equals(other any) bool {
if other == nil {
return false
}
var sf2 *ScalarFunction
switch x := other.(type) {
case *ScalarFunction:
sf2 = x
case ScalarFunction:
sf2 = &x
default:
return false
}
ok := sf.FuncName.L == sf2.FuncName.L
ok = ok && (sf.RetType == nil && sf2.RetType == nil || sf.RetType != nil && sf2.RetType != nil && sf.RetType.Equals(sf2.RetType))
if len(sf.GetArgs()) != len(sf2.GetArgs()) {
return false
}
for i, arg := range sf.GetArgs() {
ok = ok && arg.Equals(sf2.GetArgs()[i])
if !ok {
return false
}
}
return ok
}

// ReHashCode is used after we change the argument in place.
func ReHashCode(sf *ScalarFunction) {
sf.hashcode = sf.hashcode[:0]
Expand Down
38 changes: 38 additions & 0 deletions pkg/expression/scalar_function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/mock"
Expand Down Expand Up @@ -147,3 +148,40 @@ func TestScalarFuncs2Exprs(t *testing.T) {
require.True(t, exprs[i].Equal(ctx, funcs[i]))
}
}

func TestScalarFunctionHash64Equals(t *testing.T) {
a := &Column{
UniqueID: 1,
RetType: types.NewFieldType(mysql.TypeDouble),
}
sf0, _ := newFunctionWithMockCtx(ast.LT, a, NewZero()).(*ScalarFunction)
sf1, _ := newFunctionWithMockCtx(ast.LT, a, NewZero()).(*ScalarFunction)
hasher1 := base.NewHashEqualer()
hasher2 := base.NewHashEqualer()
sf0.Hash64(hasher1)
sf1.Hash64(hasher2)
require.Equal(t, hasher1.Sum64(), hasher2.Sum64())
require.True(t, sf0.Equals(sf1))

// change the func name
sf2, _ := newFunctionWithMockCtx(ast.GT, a, NewZero()).(*ScalarFunction)
hasher2.Reset()
sf2.Hash64(hasher2)
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
require.False(t, sf0.Equals(sf2))

// change the args
sf3, _ := newFunctionWithMockCtx(ast.LT, a, NewOne()).(*ScalarFunction)
hasher2.Reset()
sf3.Hash64(hasher2)
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
require.False(t, sf0.Equals(sf3))

// change the ret type
sf4, _ := newFunctionWithMockCtx(ast.LT, a, NewZero()).(*ScalarFunction)
sf4.RetType = types.NewFieldType(mysql.TypeLong)
hasher2.Reset()
sf4.Hash64(hasher2)
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
require.False(t, sf0.Equals(sf4))
}
3 changes: 3 additions & 0 deletions pkg/expression/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand Down Expand Up @@ -661,3 +662,5 @@ func (m *MockExpr) MemoryUsage() (sum int64) {
func (m *MockExpr) Traverse(action TraverseAction) Expression {
return action.Transform(m)
}
func (m *MockExpr) Hash64(_ base.Hasher) {}
func (m *MockExpr) Equals(_ any) bool { return false }
1 change: 1 addition & 0 deletions pkg/planner/core/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ go_library(
"//pkg/parser/terror",
"//pkg/parser/types",
"//pkg/planner/cardinality",
"//pkg/planner/cascades/base",
"//pkg/planner/context",
"//pkg/planner/core/base",
"//pkg/planner/core/constraint",
Expand Down
Loading

0 comments on commit 4ab1765

Please sign in to comment.