diff --git a/pkg/expression/builtin_other.go b/pkg/expression/builtin_other.go index d5e3f42515557..5def84911dd47 100644 --- a/pkg/expression/builtin_other.go +++ b/pkg/expression/builtin_other.go @@ -1052,7 +1052,11 @@ func (b *builtinGetRealVarSig) evalReal(ctx EvalContext, row chunk.Row) (float64 } varName = strings.ToLower(varName) if v, ok := sessionVars.GetUserVarVal(varName); ok { - return v.GetFloat64(), false, nil + d, err := v.ToFloat64(typeCtx(ctx)) + if err != nil { + return 0, false, err + } + return d, false, nil } return 0, true, nil } @@ -1092,7 +1096,11 @@ func (b *builtinGetDecimalVarSig) evalDecimal(ctx EvalContext, row chunk.Row) (* } varName = strings.ToLower(varName) if v, ok := sessionVars.GetUserVarVal(varName); ok { - return v.GetMysqlDecimal(), false, nil + d, err := v.ToDecimal(typeCtx(ctx)) + if err != nil { + return nil, false, err + } + return d, false, nil } return nil, true, nil } diff --git a/pkg/expression/builtin_other_test.go b/pkg/expression/builtin_other_test.go index 7df999b8b6c09..7e9ac375964ea 100644 --- a/pkg/expression/builtin_other_test.go +++ b/pkg/expression/builtin_other_test.go @@ -169,6 +169,33 @@ func TestGetVar(t *testing.T) { } } +func TestTypeConversion(t *testing.T) { + ctx := createContext(t) + // Set value as int64 + key := "a" + val := int64(3) + ctx.GetSessionVars().SetUserVarVal(key, types.NewDatum(val)) + tp := types.NewFieldType(mysql.TypeLonglong) + ctx.GetSessionVars().SetUserVarType(key, tp) + + args := []any{"a"} + // To Decimal. + tp = types.NewFieldType(mysql.TypeNewDecimal) + fn, err := BuildGetVarFunction(ctx, datumsToConstants(types.MakeDatums(args...))[0], tp) + require.NoError(t, err) + d, err := fn.Eval(ctx, chunk.Row{}) + require.NoError(t, err) + des := types.NewDecFromInt(3) + require.Equal(t, des, d.GetValue()) + // To Float. + tp = types.NewFieldType(mysql.TypeDouble) + fn, err = BuildGetVarFunction(ctx, datumsToConstants(types.MakeDatums(args...))[0], tp) + require.NoError(t, err) + d, err = fn.Eval(ctx, chunk.Row{}) + require.NoError(t, err) + require.Equal(t, float64(3), d.GetValue()) +} + func TestValues(t *testing.T) { ctx := createContext(t) fc := &valuesFunctionClass{baseFunctionClass{ast.Values, 0, 0}, 1, types.NewFieldType(mysql.TypeVarchar)} diff --git a/pkg/expression/builtin_other_vec.go b/pkg/expression/builtin_other_vec.go index bbcf8b9783a26..aeff84ed2ca04 100644 --- a/pkg/expression/builtin_other_vec.go +++ b/pkg/expression/builtin_other_vec.go @@ -386,6 +386,8 @@ func (b *builtinGetRealVarSig) vectorized() bool { return true } +// NOTE: get/set variable vectorized eval was disabled. See more in +// https://github.com/pingcap/tidb/pull/8412 func (b *builtinGetRealVarSig) vecEvalReal(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() buf0, err := b.bufAllocator.get() @@ -406,7 +408,11 @@ func (b *builtinGetRealVarSig) vecEvalReal(ctx EvalContext, input *chunk.Chunk, } varName := strings.ToLower(buf0.GetString(i)) if v, ok := sessionVars.GetUserVarVal(varName); ok { - f64s[i] = v.GetFloat64() + d, err := v.ToFloat64(typeCtx(ctx)) + if err != nil { + return err + } + f64s[i] = d continue } result.SetNull(i, true) @@ -418,6 +424,8 @@ func (b *builtinGetDecimalVarSig) vectorized() bool { return true } +// NOTE: get/set variable vectorized eval was disabled. See more in +// https://github.com/pingcap/tidb/pull/8412 func (b *builtinGetDecimalVarSig) vecEvalDecimal(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() buf0, err := b.bufAllocator.get() @@ -438,7 +446,11 @@ func (b *builtinGetDecimalVarSig) vecEvalDecimal(ctx EvalContext, input *chunk.C } varName := strings.ToLower(buf0.GetString(i)) if v, ok := sessionVars.GetUserVarVal(varName); ok { - decs[i] = *v.GetMysqlDecimal() + d, err := v.ToDecimal(typeCtx(ctx)) + if err != nil { + return err + } + decs[i] = *d continue } result.SetNull(i, true) diff --git a/pkg/expression/integration_test/BUILD.bazel b/pkg/expression/integration_test/BUILD.bazel index 91d96dd1cf6d8..b7b517b8762fb 100644 --- a/pkg/expression/integration_test/BUILD.bazel +++ b/pkg/expression/integration_test/BUILD.bazel @@ -8,7 +8,7 @@ go_test( "main_test.go", ], flaky = True, - shard_count = 24, + shard_count = 25, deps = [ "//pkg/config", "//pkg/domain", diff --git a/pkg/expression/integration_test/integration_test.go b/pkg/expression/integration_test/integration_test.go index 6d011e3eee624..34850e12c208d 100644 --- a/pkg/expression/integration_test/integration_test.go +++ b/pkg/expression/integration_test/integration_test.go @@ -2976,3 +2976,27 @@ func TestTiDBRowChecksumBuiltin(t *testing.T) { tk.MustGetDBError("select tidb_row_checksum() from t", expression.ErrNotSupportedYet) tk.MustGetDBError("select tidb_row_checksum() from t where id > 0", expression.ErrNotSupportedYet) } + +func TestIssue43527(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table test (a datetime, b bigint, c decimal(10, 2), d float)") + tk.MustExec("insert into test values('2010-10-10 10:10:10', 100, 100.01, 100)") + // Decimal. + tk.MustQuery( + "SELECT @total := @total + c FROM (SELECT c FROM test) AS temp, (SELECT @total := 200) AS T1", + ).Check(testkit.Rows("300.01")) + // Float. + tk.MustQuery( + "SELECT @total := @total + d FROM (SELECT d FROM test) AS temp, (SELECT @total := 200) AS T1", + ).Check(testkit.Rows("300")) + tk.MustExec("insert into test values('2010-10-10 10:10:10', 100, 100.01, 100)") + // Vectorized. + // NOTE: Because https://github.com/pingcap/tidb/pull/8412 disabled the vectorized execution of get or set variable, + // the following test case will not be executed in vectorized mode. + // It will be executed in the normal mode. + tk.MustQuery( + "SELECT @total := @total + d FROM (SELECT d FROM test) AS temp, (SELECT @total := b FROM test) AS T1 where @total >= 100", + ).Check(testkit.Rows("200", "300", "400", "500")) +}