Skip to content

Commit

Permalink
expression: fix incorrect result when the result of casting const val…
Browse files Browse the repository at this point in the history
…ue to duration type is null (#55454)

close #51842
  • Loading branch information
xzhangxian1008 authored Aug 29, 2024
1 parent 4c23efb commit eaa75a8
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 4 deletions.
58 changes: 55 additions & 3 deletions pkg/expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -1603,6 +1603,48 @@ func matchRefineRule3Pattern(conEvalType types.EvalType, exprType *types.FieldTy
(conEvalType == types.ETReal || conEvalType == types.ETDecimal || conEvalType == types.ETInt)
}

// handleDurationTypeComparison handles comparisons between a duration type column and a non-duration type constant.
// If the constant cannot be cast to a duration type and the comparison operator is `<=>`, the expression is rewritten as `0 <=> 1`.
// This is necessary to maintain compatibility with MySQL behavior under the following conditions:
// 1. When a duration type column is compared with a non-duration type constant, MySQL casts the duration column to the non-duration type.
// This cast prevents the use of indexes on the duration column. In TiDB, we instead cast the non-duration type constant to the duration type.
// 2. If the non-duration type constant cannot be successfully cast to a duration type, the cast returns null. A duration type constant, however,
// can always be cast to a non-duration type without returning null.
// 3. If the duration type column's value is null and the non-duration type constant cannot be cast to a duration type, and the comparison operator
// is `<=>` (null equal), then in TiDB, `durationColumn <=> non-durationTypeConstant` evaluates to `null <=> null`, returning true. In MySQL,
// it would evaluate to `null <=> not-null constant`, returning false.
//
// To ensure MySQL compatibility, we need to handle this case specifically. If the non-duration type constant cannot be cast to a duration type,
// we rewrite the expression to always return false by converting it to `0 <=> 1`.
func (c *compareFunctionClass) handleDurationTypeComparison(ctx BuildContext, arg0, arg1 Expression) (_ []Expression, err error) {
// check if a constant value becomes null after being cast to a duration type.
castToDurationIsNull := func(ctx BuildContext, arg Expression) (bool, error) {
f := WrapWithCastAsDuration(ctx, arg)
_, isNull, err := f.EvalDuration(ctx.GetEvalCtx(), chunk.Row{})
if err != nil {
return false, err
}
return isNull, nil
}

arg0Const, arg0IsCon := arg0.(*Constant)
arg1Const, arg1IsCon := arg1.(*Constant)

var isNull bool
if arg0IsCon && arg0Const.DeferredExpr == nil && !arg1IsCon && arg1.GetType(ctx.GetEvalCtx()).GetType() == mysql.TypeDuration {
isNull, err = castToDurationIsNull(ctx, arg0)
} else if arg1IsCon && arg1Const.DeferredExpr == nil && !arg0IsCon && arg0.GetType(ctx.GetEvalCtx()).GetType() == mysql.TypeDuration {
isNull, err = castToDurationIsNull(ctx, arg1)
}
if err != nil {
return nil, err
}
if isNull {
return []Expression{NewZero(), NewOne()}, nil
}
return nil, nil
}

// Since the argument refining of cmp functions can bring some risks to the plan-cache, the optimizer
// needs to decide to whether to skip the refining or skip plan-cache for safety.
// For example, `unsigned_int_col > ?(-1)` can be refined to `True`, but the validation of this result
Expand Down Expand Up @@ -1654,9 +1696,12 @@ func allowCmpArgsRefining4PlanCache(ctx BuildContext, args []Expression) (allowR
}

// refineArgs will rewrite the arguments if the compare expression is
// 1. `int column <cmp> non-int constant` or `non-int constant <cmp> int column`. E.g., `a < 1.1` will be rewritten to `a < 2`.
// 2. It also handles comparing year type with int constant if the int constant falls into a sensible year representation.
// 3. It also handles comparing datetime/timestamp column with numeric constant, try to cast numeric constant as timestamp type, do nothing if failed.
// 1. `int column <cmp> non-int constant` or `non-int constant <cmp> int column`. E.g., `a < 1.1` will be rewritten to `a < 2`.
// 2. It also handles comparing year type with int constant if the int constant falls into a sensible year representation.
// 3. It also handles comparing datetime/timestamp column with numeric constant, try to cast numeric constant as timestamp type, do nothing if failed.
// 4. Handles special cases where a duration type column is compared with a non-duration type constant, particularly when the constant
// cannot be cast to a duration type, ensuring compatibility with MySQL’s behavior by rewriting the expression as `0 <=> 1`.
//
// This refining operation depends on the values of these args, but these values can change when using plan-cache.
// So we have to skip this operation or mark the plan as over-optimized when using plan-cache.
func (c *compareFunctionClass) refineArgs(ctx BuildContext, args []Expression) ([]Expression, error) {
Expand All @@ -1677,6 +1722,13 @@ func (c *compareFunctionClass) refineArgs(ctx BuildContext, args []Expression) (
return nil, err
}

// Handle comparison between a duration type column and a non-duration type constant.
if c.op == opcode.NullEQ {
if result, err := c.handleDurationTypeComparison(ctx, args[0], args[1]); err != nil || result != nil {
return result, err
}
}

if arg0IsCon && !arg1IsCon && matchRefineRule3Pattern(arg0EvalType, arg1Type) {
return c.refineNumericConstantCmpDatetime(ctx, args, arg0, 0), nil
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/integration_test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ go_test(
"main_test.go",
],
flaky = True,
shard_count = 41,
shard_count = 42,
deps = [
"//pkg/config",
"//pkg/domain",
Expand Down
24 changes: 24 additions & 0 deletions pkg/expression/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3547,3 +3547,27 @@ func TestIssue43527(t *testing.T) {
"SELECT @total := @total + d FROM (SELECT d FROM test) AS temp, (SELECT @total := b FROM test) AS T1 where @total >= 100",
).Check(testkit.Rows("200", "300", "400", "500"))
}

func TestIssue51842(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t;")
tk.MustExec("CREATE TABLE t0(c0 DOUBLE);")
tk.MustExec("REPLACE INTO t0(c0) VALUES (0.40194983109852933);")
tk.MustExec("CREATE VIEW v0(c0) AS SELECT CAST(')' AS TIME) FROM t0 WHERE '0.030417148673465677';")
res := tk.MustQuery("SELECT f1 FROM (SELECT NULLIF(v0.c0, 1371581446) AS f1 FROM v0, t0) AS t WHERE f1 <=> 1292367147;").String() // test int
require.Equal(t, 0, len(res))
res = tk.MustQuery("SELECT f1 FROM (SELECT NULLIF(v0.c0, 1371581446) AS f1 FROM v0, t0) AS t WHERE f1 <=> cast(123988.42132 as real);").String() // test real
require.Equal(t, 0, len(res))
res = tk.MustQuery("SELECT f1 FROM (SELECT NULLIF(v0.c0, 1371581446) AS f1 FROM v0, t0) AS t WHERE f1 <=> cast(123988.42132 as decimal);").String() // test decimal
require.Equal(t, 0, len(res))
res = tk.MustQuery("SELECT f1 FROM (SELECT NULLIF(v0.c0, 1371581446) AS f1 FROM v0, t0) AS t WHERE f1 <=> cast('fdasge' as char);").String() // test string
require.Equal(t, 0, len(res))
res = tk.MustQuery("SELECT f1 FROM (SELECT NULLIF(v0.c0, 1371581446) AS f1 FROM v0, t0) AS t WHERE f1 <=> cast('10:10:10' as time);").String() // test time
require.Equal(t, 0, len(res))
res = tk.MustQuery("SELECT f1 FROM (SELECT NULLIF(v0.c0, 1371581446) AS f1 FROM v0, t0) AS t WHERE f1 <=> cast(2024 as year);").String() // test year
require.Equal(t, 0, len(res))
res = tk.MustQuery("SELECT f1 FROM (SELECT NULLIF(v0.c0, 1371581446) AS f1 FROM v0, t0) AS t WHERE f1 <=> cast('2024-1-1 10:10:10' as datetime);").String() // test datetime
require.Equal(t, 0, len(res))
}

0 comments on commit eaa75a8

Please sign in to comment.