diff --git a/expression/builtin_math.go b/expression/builtin_math.go index 04c59608d3bbd..68ee63a974114 100644 --- a/expression/builtin_math.go +++ b/expression/builtin_math.go @@ -110,6 +110,7 @@ var ( _ builtinFunc = &builtinTruncateIntSig{} _ builtinFunc = &builtinTruncateRealSig{} _ builtinFunc = &builtinTruncateDecimalSig{} + _ builtinFunc = &builtinTruncateUintSig{} ) type absFunctionClass struct { @@ -1737,7 +1738,11 @@ func (c *truncateFunctionClass) getFunction(ctx sessionctx.Context, args []Expre var sig builtinFunc switch argTp { case types.ETInt: - sig = &builtinTruncateIntSig{bf} + if mysql.HasUnsignedFlag(args[0].GetType().Flag) { + sig = &builtinTruncateUintSig{bf} + } else { + sig = &builtinTruncateIntSig{bf} + } case types.ETReal: sig = &builtinTruncateRealSig{bf} case types.ETDecimal: @@ -1826,6 +1831,39 @@ func (b *builtinTruncateIntSig) evalInt(row chunk.Row) (int64, bool, error) { return 0, isNull, errors.Trace(err) } - floatX := float64(x) - return int64(types.Truncate(floatX, int(d))), false, nil + if d >= 0 { + return x, false, nil + } + shift := int64(math.Pow10(int(-d))) + return x / shift * shift, false, nil +} + +func (b *builtinTruncateUintSig) Clone() builtinFunc { + newSig := &builtinTruncateUintSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +type builtinTruncateUintSig struct { + baseBuiltinFunc +} + +// evalInt evals a TRUNCATE(X,D). +// See https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_truncate +func (b *builtinTruncateUintSig) evalInt(row chunk.Row) (int64, bool, error) { + x, isNull, err := b.args[0].EvalInt(b.ctx, row) + if isNull || err != nil { + return 0, isNull, errors.Trace(err) + } + uintx := uint64(x) + + d, isNull, err := b.args[1].EvalInt(b.ctx, row) + if isNull || err != nil { + return 0, isNull, errors.Trace(err) + } + if d >= 0 { + return x, false, nil + } + shift := uint64(math.Pow10(int(-d))) + return int64(uintx / shift * shift), false, nil } diff --git a/expression/builtin_math_test.go b/expression/builtin_math_test.go index a0c291603a5c3..db01a7864dfe5 100644 --- a/expression/builtin_math_test.go +++ b/expression/builtin_math_test.go @@ -485,6 +485,9 @@ func (s *testEvaluatorSuite) TestTruncate(c *C) { {[]interface{}{newDec("23.298"), -100}, newDec("0")}, {[]interface{}{newDec("23.298"), 100}, newDec("23.298")}, {[]interface{}{nil, 2}, nil}, + {[]interface{}{uint64(9223372036854775808), -10}, 9223372030000000000}, + {[]interface{}{9223372036854775807, -7}, 9223372036850000000}, + {[]interface{}{uint64(18446744073709551615), -10}, uint64(18446744070000000000)}, } Dtbl := tblToDtbl(tbl) diff --git a/expression/integration_test.go b/expression/integration_test.go index c81203e1518b3..7ee4d4c4e26c0 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -477,6 +477,8 @@ func (s *testIntegrationSuite) TestMathBuiltin(c *C) { result.Check(testkit.Rows("100 123 123 120")) result = tk.MustQuery("SELECT truncate(123.456, -2), truncate(123.456, 2), truncate(123.456, 1), truncate(123.456, 3), truncate(1.23, 100), truncate(123456E-3, 2);") result.Check(testkit.Rows("100 123.45 123.4 123.456 1.230000000000000000000000000000 123.45")) + result = tk.MustQuery("SELECT truncate(9223372036854775807, -7), truncate(9223372036854775808, -10), truncate(cast(-1 as unsigned), -10);") + result.Check(testkit.Rows("9223372036850000000 9223372030000000000 18446744070000000000")) tk.MustExec(`drop table if exists t;`) tk.MustExec(`create table t(a date, b datetime, c timestamp, d varchar(20));`)