Skip to content

Commit

Permalink
expression: StaticExprContext/EvalContext support to load state fro…
Browse files Browse the repository at this point in the history
…m system variables (#55800)

close #55799
  • Loading branch information
lcwangchao authored Sep 4, 2024
1 parent e1e9e16 commit e77d4a1
Show file tree
Hide file tree
Showing 7 changed files with 455 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pkg/expression/contextstatic/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ go_test(
],
embed = [":contextstatic"],
flaky = True,
shard_count = 11,
shard_count = 13,
deps = [
"//pkg/errctx",
"//pkg/expression/context",
Expand Down
62 changes: 62 additions & 0 deletions pkg/expression/contextstatic/evalctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package contextstatic

import (
"math"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -418,6 +420,66 @@ func (ctx *StaticEvalContext) GetWarnHandler() contextutil.WarnHandler {
return ctx.warnHandler
}

// LoadSystemVars loads system variables and returns a new `StaticEvalContext` with system variables loaded.
func (ctx *StaticEvalContext) LoadSystemVars(sysVars map[string]string) (*StaticEvalContext, error) {
sessionVars, err := newSessionVarsWithSystemVariables(sysVars)
if err != nil {
return nil, err
}
return ctx.loadSessionVarsInternal(sessionVars, sysVars), nil
}

func (ctx *StaticEvalContext) loadSessionVarsInternal(
sessionVars *variable.SessionVars, sysVars map[string]string,
) *StaticEvalContext {
opts := make([]StaticEvalCtxOption, 0, 8)
for name, val := range sysVars {
name = strings.ToLower(name)
switch name {
case variable.TimeZone:
opts = append(opts, WithLocation(sessionVars.Location()))
case variable.SQLModeVar:
opts = append(opts, WithSQLMode(sessionVars.SQLMode))
case variable.Timestamp:
opts = append(opts, WithCurrentTime(ctx.currentTimeFuncFromStringVal(val)))
case variable.MaxAllowedPacket:
opts = append(opts, WithMaxAllowedPacket(sessionVars.MaxAllowedPacket))
case variable.TiDBRedactLog:
opts = append(opts, WithEnableRedactLog(sessionVars.EnableRedactLog))
case variable.DefaultWeekFormat:
opts = append(opts, WithDefaultWeekFormatMode(val))
case variable.DivPrecisionIncrement:
opts = append(opts, WithDivPrecisionIncrement(sessionVars.DivPrecisionIncrement))
}
}
return ctx.Apply(opts...)
}

func (ctx *StaticEvalContext) currentTimeFuncFromStringVal(val string) func() (time.Time, error) {
return func() (time.Time, error) {
if val == variable.DefTimestamp {
return time.Now(), nil
}

ts, err := types.StrToFloat(types.StrictContext, val, false)
if err != nil {
return time.Time{}, err
}
seconds, fractionalSeconds := math.Modf(ts)
return time.Unix(int64(seconds), int64(fractionalSeconds*float64(time.Second))), nil
}
}

func newSessionVarsWithSystemVariables(vars map[string]string) (*variable.SessionVars, error) {
sessionVars := variable.NewSessionVars(nil)
for name, val := range vars {
if err := sessionVars.SetSystemVar(name, val); err != nil {
return nil, err
}
}
return sessionVars, nil
}

// MakeEvalContextStatic converts the `exprctx.StaticConvertibleEvalContext` to `StaticEvalContext`.
func MakeEvalContextStatic(ctx exprctx.StaticConvertibleEvalContext) *StaticEvalContext {
typeCtx := ctx.TypeCtx()
Expand Down
160 changes: 160 additions & 0 deletions pkg/expression/contextstatic/evalctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package contextstatic

import (
"fmt"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -532,3 +533,162 @@ func TestMakeEvalContextStatic(t *testing.T) {
// Now, it didn't copy any optional properties.
require.Equal(t, context.OptionalEvalPropKeySet(0), staticObj.GetOptionalPropSet())
}

func TestEvalCtxLoadSystemVars(t *testing.T) {
vars := []struct {
name string
val string
field string
assert func(ctx *StaticEvalContext, vars *variable.SessionVars)
}{
{
name: "time_zone",
val: "Europe/Berlin",
field: "$.typeCtx.loc",
assert: func(ctx *StaticEvalContext, vars *variable.SessionVars) {
require.Equal(t, "Europe/Berlin", ctx.Location().String())
require.Equal(t, vars.Location().String(), ctx.Location().String())
},
},
{
name: "sql_mode",
val: "ALLOW_INVALID_DATES,ONLY_FULL_GROUP_BY",
field: "$.sqlMode",
assert: func(ctx *StaticEvalContext, vars *variable.SessionVars) {
require.Equal(t, mysql.ModeAllowInvalidDates|mysql.ModeOnlyFullGroupBy, ctx.SQLMode())
require.Equal(t, vars.SQLMode, ctx.SQLMode())
},
},
{
name: "timestamp",
val: "1234567890.123456",
field: "$.currentTime",
assert: func(ctx *StaticEvalContext, vars *variable.SessionVars) {
currentTime, err := ctx.CurrentTime()
require.NoError(t, err)
require.Equal(t, int64(1234567890123456), currentTime.UnixMicro())
require.Equal(t, vars.Location().String(), currentTime.Location().String())
},
},
{
name: strings.ToUpper("max_allowed_packet"), // test for settings an upper case variable
val: "524288",
field: "$.maxAllowedPacket",
assert: func(ctx *StaticEvalContext, vars *variable.SessionVars) {
require.Equal(t, uint64(524288), ctx.GetMaxAllowedPacket())
require.Equal(t, vars.MaxAllowedPacket, ctx.GetMaxAllowedPacket())
},
},
{
name: strings.ToUpper("tidb_redact_log"), // test for settings an upper case variable
val: "on",
field: "$.enableRedactLog",
assert: func(ctx *StaticEvalContext, vars *variable.SessionVars) {
require.Equal(t, "ON", ctx.GetTiDBRedactLog())
require.Equal(t, vars.EnableRedactLog, ctx.GetTiDBRedactLog())
},
},
{
name: "default_week_format",
val: "5",
field: "$.defaultWeekFormatMode",
assert: func(ctx *StaticEvalContext, vars *variable.SessionVars) {
require.Equal(t, "5", ctx.GetDefaultWeekFormatMode())
mode, ok := vars.GetSystemVar(variable.DefaultWeekFormat)
require.True(t, ok)
require.Equal(t, mode, ctx.GetDefaultWeekFormatMode())
},
},
{
name: "div_precision_increment",
val: "12",
field: "$.divPrecisionIncrement",
assert: func(ctx *StaticEvalContext, vars *variable.SessionVars) {
require.Equal(t, 12, ctx.GetDivPrecisionIncrement())
require.Equal(t, vars.DivPrecisionIncrement, ctx.GetDivPrecisionIncrement())
},
},
}

// nonVarRelatedFields means the fields not related to any system variables.
// To make sure that all the variables which affect the context state are covered in the above test list,
// we need to test all inner fields except those in `nonVarRelatedFields` are changed after `LoadSystemVars`.
nonVarRelatedFields := []string{
"$.warnHandler",
"$.typeCtx.flags",
"$.typeCtx.warnHandler",
"$.errCtx",
"$.currentDB",
"$.requestVerificationFn",
"$.requestDynamicVerificationFn",
"$.paramList",
"$.props",
}

// varsRelatedFields means the fields related to
varsRelatedFields := make([]string, 0, len(vars))
varsMap := make(map[string]string)
sessionVars := variable.NewSessionVars(nil)
for _, sysVar := range vars {
varsMap[sysVar.name] = sysVar.val
if sysVar.field != "" {
varsRelatedFields = append(varsRelatedFields, sysVar.field)
}
require.NoError(t, sessionVars.SetSystemVar(sysVar.name, sysVar.val))
}

defaultEvalCtx := NewStaticEvalContext()
ctx, err := defaultEvalCtx.LoadSystemVars(varsMap)
require.NoError(t, err)
require.Greater(t, ctx.CtxID(), defaultEvalCtx.CtxID())

// Check all fields except these in `nonVarRelatedFields` are changed after `LoadSystemVars` to make sure
// all system variables related fields are covered in the test list.
deeptest.AssertRecursivelyNotEqual(
t,
defaultEvalCtx.staticEvalCtxState,
ctx.staticEvalCtxState,
deeptest.WithIgnorePath(nonVarRelatedFields),
deeptest.WithPointerComparePath([]string{"$.currentTime"}),
)

// We need to compare the new context again with an empty one to make sure those values are set from sys vars,
// not inherited from the empty go value.
deeptest.AssertRecursivelyNotEqual(
t,
staticEvalCtxState{},
ctx.staticEvalCtxState,
deeptest.WithIgnorePath(nonVarRelatedFields),
deeptest.WithPointerComparePath([]string{"$.currentTime"}),
)

// Check all system vars unrelated fields are not changed after `LoadSystemVars`.
deeptest.AssertDeepClonedEqual(
t,
defaultEvalCtx.staticEvalCtxState,
ctx.staticEvalCtxState,
deeptest.WithIgnorePath(append(
varsRelatedFields,
// Do not check warnHandler in `typeCtx` and `errCtx` because they should be changed to even if
// they are not related to any system variable.
"$.typeCtx.warnHandler",
"$.errCtx.warnHandler",
)),
// LoadSystemVars only does shallow copy for `EvalContext` so we just need to compare the pointers.
deeptest.WithPointerComparePath(nonVarRelatedFields),
)

for _, sysVar := range vars {
sysVar.assert(ctx, sessionVars)
}

// additional check about @@timestamp
// setting to `variable.DefTimestamp` should return the current timestamp
ctx, err = defaultEvalCtx.LoadSystemVars(map[string]string{
"timestamp": variable.DefTimestamp,
})
require.NoError(t, err)
tm, err := ctx.CurrentTime()
require.NoError(t, err)
require.InDelta(t, time.Now().Unix(), tm.Unix(), 5)
}
42 changes: 42 additions & 0 deletions pkg/expression/contextstatic/exprctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package contextstatic

import (
"strings"

exprctx "github.com/pingcap/tidb/pkg/expression/context"
"github.com/pingcap/tidb/pkg/parser/charset"
"github.com/pingcap/tidb/pkg/parser/mysql"
Expand Down Expand Up @@ -317,3 +319,43 @@ func MakeExprContextStatic(ctx exprctx.StaticConvertibleExprContext) *StaticExpr
WithGroupConcatMaxLen(ctx.GetGroupConcatMaxLen()),
)
}

// LoadSystemVars loads system variables and returns a new `StaticEvalContext` with system variables loaded.
func (ctx *StaticExprContext) LoadSystemVars(sysVars map[string]string) (*StaticExprContext, error) {
sessionVars, err := newSessionVarsWithSystemVariables(sysVars)
if err != nil {
return nil, err
}
return ctx.loadSessionVarsInternal(sessionVars, sysVars), nil
}

func (ctx *StaticExprContext) loadSessionVarsInternal(
sessionVars *variable.SessionVars, sysVars map[string]string,
) *StaticExprContext {
opts := make([]StaticExprCtxOption, 0, 8)
opts = append(opts, WithEvalCtx(ctx.evalCtx.loadSessionVarsInternal(sessionVars, sysVars)))
for name := range sysVars {
name = strings.ToLower(name)
switch name {
case variable.CharacterSetConnection, variable.CollationConnection:
opts = append(opts, WithCharset(sessionVars.GetCharsetInfo()))
case variable.DefaultCollationForUTF8MB4:
opts = append(opts, WithDefaultCollationForUTF8MB4(sessionVars.DefaultCollationForUTF8MB4))
case variable.BlockEncryptionMode:
blockMode, ok := sessionVars.GetSystemVar(variable.BlockEncryptionMode)
intest.Assert(ok)
if ok {
opts = append(opts, WithBlockEncryptionMode(blockMode))
}
case variable.TiDBSysdateIsNow:
opts = append(opts, WithSysDateIsNow(sessionVars.SysdateIsNow))
case variable.TiDBEnableNoopFuncs:
opts = append(opts, WithNoopFuncsMode(sessionVars.NoopFuncsMode))
case variable.WindowingUseHighPrecision:
opts = append(opts, WithWindowingUseHighPrecision(sessionVars.WindowingUseHighPrecision))
case variable.GroupConcatMaxLen:
opts = append(opts, WithGroupConcatMaxLen(sessionVars.GroupConcatMaxLen))
}
}
return ctx.Apply(opts...)
}
Loading

0 comments on commit e77d4a1

Please sign in to comment.