From 1a24c032126dce4a79ea14b51108dbf4caf02a03 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Wed, 19 Jun 2024 20:10:17 +0800 Subject: [PATCH] expression: correct the erroneous scalar function equivalence check (#54067) close pingcap/tidb#53726 --- pkg/expression/scalar_function.go | 3 ++ pkg/expression/util_test.go | 2 +- pkg/parser/types/field_type.go | 2 +- .../core/issuetest/planner_issue_test.go | 28 +++++++++++++++++++ 4 files changed, 33 insertions(+), 2 deletions(-) diff --git a/pkg/expression/scalar_function.go b/pkg/expression/scalar_function.go index 216c770e4446a..f007e43b6cc85 100644 --- a/pkg/expression/scalar_function.go +++ b/pkg/expression/scalar_function.go @@ -365,6 +365,9 @@ func (sf *ScalarFunction) Equal(ctx EvalContext, e Expression) bool { if sf.FuncName.L != fun.FuncName.L { return false } + if !sf.RetType.Equal(fun.RetType) { + return false + } return sf.Function.equal(ctx, fun.Function) } diff --git a/pkg/expression/util_test.go b/pkg/expression/util_test.go index d5271860822e3..a0b5ed3654348 100644 --- a/pkg/expression/util_test.go +++ b/pkg/expression/util_test.go @@ -259,7 +259,7 @@ func TestSubstituteCorCol2Constant(t *testing.T) { ret, err = SubstituteCorCol2Constant(ctx, plus3) require.NoError(t, err) ans3 := newFunctionWithMockCtx(ast.Plus, ans1, col1) - require.True(t, ret.Equal(ctx, ans3)) + require.False(t, ret.Equal(ctx, ans3)) } func TestPushDownNot(t *testing.T) { diff --git a/pkg/parser/types/field_type.go b/pkg/parser/types/field_type.go index 2e80bbf3c7d2b..2befced6393bb 100644 --- a/pkg/parser/types/field_type.go +++ b/pkg/parser/types/field_type.go @@ -289,7 +289,7 @@ func (ft *FieldType) Equal(other *FieldType) bool { // because flen for them is useless. // The decimal field can be ignored if the type is int or string. tpEqual := (ft.GetType() == other.GetType()) || (ft.GetType() == mysql.TypeVarchar && other.GetType() == mysql.TypeVarString) || (ft.GetType() == mysql.TypeVarString && other.GetType() == mysql.TypeVarchar) - flenEqual := ft.flen == other.flen || (ft.EvalType() == ETReal && ft.decimal == UnspecifiedLength) + flenEqual := ft.flen == other.flen || (ft.EvalType() == ETReal && ft.decimal == UnspecifiedLength) || ft.EvalType() == ETJson ignoreDecimal := ft.EvalType() == ETInt || ft.EvalType() == ETString partialEqual := tpEqual && (ignoreDecimal || ft.decimal == other.decimal) && diff --git a/pkg/planner/core/issuetest/planner_issue_test.go b/pkg/planner/core/issuetest/planner_issue_test.go index 68abab69dc621..b0849ae5fca0f 100644 --- a/pkg/planner/core/issuetest/planner_issue_test.go +++ b/pkg/planner/core/issuetest/planner_issue_test.go @@ -58,3 +58,31 @@ func TestIssue43461(t *testing.T) { require.NotEqual(t, is.Columns, ts.Columns) } + +func Test53726(t *testing.T) { + // test for RemoveUnnecessaryFirstRow + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t7(c int); ") + tk.MustExec("insert into t7 values (575932053), (-258025139);") + tk.MustQuery("select distinct cast(c as decimal), cast(c as signed) from t7"). + Sort().Check(testkit.Rows("-258025139 -258025139", "575932053 575932053")) + tk.MustQuery("explain select distinct cast(c as decimal), cast(c as signed) from t7"). + Check(testkit.Rows( + "HashAgg_8 8000.00 root group by:Column#7, Column#8, funcs:firstrow(Column#7)->Column#3, funcs:firstrow(Column#8)->Column#4", + "└─TableReader_9 8000.00 root data:HashAgg_4", + " └─HashAgg_4 8000.00 cop[tikv] group by:cast(test.t7.c, bigint(22) BINARY), cast(test.t7.c, decimal(10,0) BINARY), ", + " └─TableFullScan_7 10000.00 cop[tikv] table:t7 keep order:false, stats:pseudo")) + + tk.MustExec("analyze table t7") + tk.MustQuery("select distinct cast(c as decimal), cast(c as signed) from t7"). + Sort(). + Check(testkit.Rows("-258025139 -258025139", "575932053 575932053")) + tk.MustQuery("explain select distinct cast(c as decimal), cast(c as signed) from t7"). + Check(testkit.Rows( + "HashAgg_6 2.00 root group by:Column#13, Column#14, funcs:firstrow(Column#11)->Column#3, funcs:firstrow(Column#12)->Column#4", + "└─Projection_12 2.00 root cast(test.t7.c, decimal(10,0) BINARY)->Column#11, cast(test.t7.c, bigint(22) BINARY)->Column#12, cast(test.t7.c, decimal(10,0) BINARY)->Column#13, cast(test.t7.c, bigint(22) BINARY)->Column#14", + " └─TableReader_11 2.00 root data:TableFullScan_10", + " └─TableFullScan_10 2.00 cop[tikv] table:t7 keep order:false")) +}