Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: expression.BuildSimpleExpr supports to build ParamMarker #55493

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pkg/expression/contextstatic/exprctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,11 @@ func (ctx *StaticExprContext) GetEvalCtx() exprctx.EvalContext {
return ctx.evalCtx
}

// GetStaticEvalCtx returns the inner `StaticEvalContext`.
func (ctx *StaticExprContext) GetStaticEvalCtx() *StaticEvalContext {
return ctx.evalCtx
}

// GetCharsetInfo implements the `ExprContext.GetCharsetInfo`.
func (ctx *StaticExprContext) GetCharsetInfo() (string, string) {
return ctx.charset, ctx.collation
Expand Down
9 changes: 4 additions & 5 deletions pkg/expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/opcode"
"github.com/pingcap/tidb/pkg/parser/terror"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
driver "github.com/pingcap/tidb/pkg/types/parser_driver"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand Down Expand Up @@ -1196,8 +1195,8 @@ func DatumToConstant(d types.Datum, tp byte, flag uint) *Constant {
}

// ParamMarkerExpression generate a getparam function expression.
func ParamMarkerExpression(ctx variable.SessionVarsProvider, v *driver.ParamMarkerExpr, needParam bool) (*Constant, error) {
useCache := ctx.GetSessionVars().StmtCtx.UseCache()
func ParamMarkerExpression(ctx BuildContext, v *driver.ParamMarkerExpr, needParam bool) (*Constant, error) {
useCache := ctx.IsUseCache()
tp := types.NewFieldType(mysql.TypeUnspecified)
types.InferParamTypeFromDatum(&v.Datum, tp)
value := &Constant{Value: v.Datum, RetType: tp}
Expand Down Expand Up @@ -1251,11 +1250,11 @@ func ConstructPositionExpr(p *driver.ParamMarkerExpr) *ast.PositionExpr {
}

// PosFromPositionExpr generates a position value from PositionExpr.
func PosFromPositionExpr(ctx BuildContext, vars variable.SessionVarsProvider, v *ast.PositionExpr) (int, bool, error) {
func PosFromPositionExpr(ctx BuildContext, v *ast.PositionExpr) (int, bool, error) {
if v.P == nil {
return v.N, false, nil
}
value, err := ParamMarkerExpression(vars, v.P.(*driver.ParamMarkerExpr), false)
value, err := ParamMarkerExpression(ctx, v.P.(*driver.ParamMarkerExpr), false)
if err != nil {
return 0, true, err
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/planner/core/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ go_test(
"//pkg/domain",
"//pkg/expression",
"//pkg/expression/aggregation",
"//pkg/expression/context",
"//pkg/expression/contextstatic",
"//pkg/infoschema",
"//pkg/kv",
"//pkg/metrics",
Expand Down
28 changes: 15 additions & 13 deletions pkg/planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1445,19 +1445,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
}
er.ctxStackAppend(value, types.EmptyName)
case *driver.ParamMarkerExpr:
withPlanCtx(func(planCtx *exprRewriterPlanCtx) {
var value *expression.Constant
value, er.err = expression.ParamMarkerExpression(planCtx.builder.ctx, v, false)
if er.err != nil {
return
}
initConstantRepertoire(er.sctx.GetEvalCtx(), value)
er.adjustUTF8MB4Collation(value.RetType)
if er.err != nil {
return
}
er.ctxStackAppend(value, types.EmptyName)
})
er.toParamMarker(v)
case *ast.VariableExpr:
withPlanCtx(func(planCtx *exprRewriterPlanCtx) {
er.rewriteVariable(planCtx, v)
Expand Down Expand Up @@ -2407,6 +2395,20 @@ func (er *expressionRewriter) toTable(v *ast.TableName) {
er.ctxStackAppend(val, types.EmptyName)
}

func (er *expressionRewriter) toParamMarker(v *driver.ParamMarkerExpr) {
var value *expression.Constant
value, er.err = expression.ParamMarkerExpression(er.sctx, v, false)
if er.err != nil {
return
}
initConstantRepertoire(er.sctx.GetEvalCtx(), value)
er.adjustUTF8MB4Collation(value.RetType)
if er.err != nil {
return
}
er.ctxStackAppend(value, types.EmptyName)
}

func (er *expressionRewriter) clause() clauseCode {
if er.planCtx != nil {
return er.planCtx.builder.curClause
Expand Down
57 changes: 33 additions & 24 deletions pkg/planner/core/expression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ import (

"github.com/pingcap/tidb/pkg/domain"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/expression/context"
"github.com/pingcap/tidb/pkg/expression/contextstatic"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/charset"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/testkit/testutil"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand Down Expand Up @@ -404,39 +407,36 @@ func TestBuildExpression(t *testing.T) {
},
}

ctx := MockContext()
defer func() {
domain.GetDomain(ctx).StatsHandle().Close()
}()

ctx := contextstatic.NewStaticExprContext()
evalCtx := ctx.GetStaticEvalCtx()
cols, names, err := expression.ColumnInfos2ColumnsAndNames(ctx, model.NewCIStr(""), tbl.Name, tbl.Cols(), tbl)
require.NoError(t, err)
schema := expression.NewSchema(cols...)

// normal build
ctx.GetSessionVars().PlanColumnID.Store(0)
ctx = ctx.Apply(contextstatic.WithColumnIDAllocator(context.NewSimplePlanColumnIDAllocator(0)))
expr, err := buildExpr(t, ctx, "(1+a)*(3+b)", expression.WithTableInfo("", tbl))
require.NoError(t, err)
ctx.GetSessionVars().PlanColumnID.Store(0)
ctx = ctx.Apply(contextstatic.WithColumnIDAllocator(context.NewSimplePlanColumnIDAllocator(0)))
expr2, err := expression.ParseSimpleExpr(ctx, "(1+a)*(3+b)", expression.WithTableInfo("", tbl))
require.NoError(t, err)
require.True(t, expr.Equal(ctx, expr2))
val, _, err := expr.EvalInt(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.True(t, expr.Equal(evalCtx, expr2))
val, _, err := expr.EvalInt(evalCtx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, int64(10), val)
val, _, err = expr.EvalInt(ctx, chunk.MutRowFromValues("", 3, 4).ToRow())
val, _, err = expr.EvalInt(evalCtx, chunk.MutRowFromValues("", 3, 4).ToRow())
require.NoError(t, err)
require.Equal(t, int64(28), val)
val, _, err = expr2.EvalInt(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
val, _, err = expr2.EvalInt(evalCtx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, int64(10), val)
val, _, err = expr2.EvalInt(ctx, chunk.MutRowFromValues("", 3, 4).ToRow())
val, _, err = expr2.EvalInt(evalCtx, chunk.MutRowFromValues("", 3, 4).ToRow())
require.NoError(t, err)
require.Equal(t, int64(28), val)

expr, err = buildExpr(t, ctx, "(1+a)*(3+b)", expression.WithInputSchemaAndNames(schema, names, nil))
require.NoError(t, err)
val, _, err = expr.EvalInt(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
val, _, err = expr.EvalInt(evalCtx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, int64(10), val)

Expand All @@ -452,52 +452,61 @@ func TestBuildExpression(t *testing.T) {
// use WithAllowCastArray to allow casting to array
expr, err = buildExpr(t, ctx, `cast(json_extract('{"a": [1, 2, 3]}', '$.a') as signed array)`, expression.WithAllowCastArray(true))
require.NoError(t, err)
j, _, err := expr.EvalJSON(ctx, chunk.Row{})
j, _, err := expr.EvalJSON(evalCtx, chunk.Row{})
require.NoError(t, err)
require.Equal(t, types.JSONTypeCodeArray, j.TypeCode)
require.Equal(t, "[1, 2, 3]", j.String())

// default expr
expr, err = buildExpr(t, ctx, "default(id)", expression.WithTableInfo("", tbl))
require.NoError(t, err)
s, _, err := expr.EvalString(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
s, _, err := expr.EvalString(evalCtx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, 36, len(s), s)

expr, err = buildExpr(t, ctx, "default(id)", expression.WithInputSchemaAndNames(schema, names, tbl))
require.NoError(t, err)
s, _, err = expr.EvalString(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
s, _, err = expr.EvalString(evalCtx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, 36, len(s), s)

expr, err = buildExpr(t, ctx, "default(b)", expression.WithTableInfo("", tbl))
require.NoError(t, err)
d, err := expr.Eval(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
d, err := expr.Eval(evalCtx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, types.NewDatum(int64(123)), d)

// WithCastExprTo
expr, err = buildExpr(t, ctx, "1+2+3")
require.NoError(t, err)
require.Equal(t, mysql.TypeLonglong, expr.GetType(ctx).GetType())
require.Equal(t, mysql.TypeLonglong, expr.GetType(evalCtx).GetType())
castTo := types.NewFieldType(mysql.TypeVarchar)
expr, err = buildExpr(t, ctx, "1+2+3", expression.WithCastExprTo(castTo))
require.NoError(t, err)
require.Equal(t, mysql.TypeVarchar, expr.GetType(ctx).GetType())
v, err := expr.Eval(ctx, chunk.Row{})
require.Equal(t, mysql.TypeVarchar, expr.GetType(evalCtx).GetType())
v, err := expr.Eval(evalCtx, chunk.Row{})
require.NoError(t, err)
require.Equal(t, types.KindString, v.Kind())
require.Equal(t, "6", v.GetString())

// param marker
params := variable.NewPlanCacheParamList()
params.Append(types.NewIntDatum(5))
evalCtx = evalCtx.Apply(contextstatic.WithParamList(params))
ctx = ctx.Apply(contextstatic.WithEvalCtx(evalCtx))
expr, err = buildExpr(t, ctx, "a + ?", expression.WithTableInfo("", tbl))
require.NoError(t, err)
require.Equal(t, mysql.TypeLonglong, expr.GetType(evalCtx).GetType())
v, err = expr.Eval(evalCtx, chunk.MutRowFromValues(1, 2, 3).ToRow())
require.NoError(t, err)
require.Equal(t, types.KindInt64, v.Kind())
require.Equal(t, int64(7), v.GetInt64())

// should report error for default expr when source table not provided
_, err = buildExpr(t, ctx, "default(b)", expression.WithInputSchemaAndNames(schema, names, nil))
require.EqualError(t, err, "Unsupported expr *ast.DefaultExpr when source table not provided")

// subquery not supported
_, err = buildExpr(t, ctx, "a + (select b from t)", expression.WithTableInfo("", tbl))
require.EqualError(t, err, "node '*ast.SubqueryExpr' is not allowed when building an expression without planner")

// param marker not supported
_, err = buildExpr(t, ctx, "a + ?", expression.WithTableInfo("", tbl))
require.EqualError(t, err, "node '*driver.ParamMarkerExpr' is not allowed when building an expression without planner")
}
23 changes: 11 additions & 12 deletions pkg/planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (a *aggOrderByResolver) Enter(inNode ast.Node) (ast.Node, bool) {
a.exprDepth++
if n, ok := inNode.(*driver.ParamMarkerExpr); ok {
if a.exprDepth == 1 {
_, isNull, isExpectedType := getUintFromNode(a.ctx, n, false)
_, isNull, isExpectedType := getUintFromNode(a.ctx.GetExprCtx(), n, false)
// For constant uint expression in top level, it should be treated as position expression.
if !isNull && isExpectedType {
return expression.ConstructPositionExpr(n), true
Expand All @@ -109,7 +109,7 @@ func (a *aggOrderByResolver) Enter(inNode ast.Node) (ast.Node, bool) {

func (a *aggOrderByResolver) Leave(inNode ast.Node) (ast.Node, bool) {
if v, ok := inNode.(*ast.PositionExpr); ok {
pos, isNull, err := expression.PosFromPositionExpr(a.ctx.GetExprCtx(), a.ctx, v)
pos, isNull, err := expression.PosFromPositionExpr(a.ctx.GetExprCtx(), v)
if err != nil {
a.err = err
}
Expand Down Expand Up @@ -2006,7 +2006,7 @@ CheckReferenced:
// getUintFromNode gets uint64 value from ast.Node.
// For ordinary statement, node should be uint64 constant value.
// For prepared statement, node is string. We should convert it to uint64.
func getUintFromNode(ctx base.PlanContext, n ast.Node, mustInt64orUint64 bool) (uVal uint64, isNull bool, isExpectedType bool) {
func getUintFromNode(ctx expression.BuildContext, n ast.Node, mustInt64orUint64 bool) (uVal uint64, isNull bool, isExpectedType bool) {
var val any
switch v := n.(type) {
case *driver.ValueExpr:
Expand All @@ -2024,7 +2024,7 @@ func getUintFromNode(ctx base.PlanContext, n ast.Node, mustInt64orUint64 bool) (
if err != nil {
return 0, false, false
}
str, isNull, err := expression.GetStringFromConstant(ctx.GetExprCtx().GetEvalCtx(), param)
str, isNull, err := expression.GetStringFromConstant(ctx.GetEvalCtx(), param)
if err != nil {
return 0, false, false
}
Expand All @@ -2043,8 +2043,7 @@ func getUintFromNode(ctx base.PlanContext, n ast.Node, mustInt64orUint64 bool) (
return uint64(v), false, true
}
case string:
ctx := ctx.GetSessionVars().StmtCtx.TypeCtx()
uVal, err := types.StrToUint(ctx, v, false)
uVal, err := types.StrToUint(ctx.GetEvalCtx().TypeCtx(), v, false)
if err != nil {
return 0, false, false
}
Expand All @@ -2068,7 +2067,7 @@ func CheckParamTypeInt64orUint64(param *driver.ParamMarkerExpr) (bool, uint64) {
return false, 0
}

func extractLimitCountOffset(ctx base.PlanContext, limit *ast.Limit) (count uint64,
func extractLimitCountOffset(ctx expression.BuildContext, limit *ast.Limit) (count uint64,
offset uint64, err error) {
var isExpectedType bool
if limit.Count != nil {
Expand All @@ -2092,7 +2091,7 @@ func (b *PlanBuilder) buildLimit(src base.LogicalPlan, limit *ast.Limit) (base.L
offset, count uint64
err error
)
if count, offset, err = extractLimitCountOffset(b.ctx, limit); err != nil {
if count, offset, err = extractLimitCountOffset(b.ctx.GetExprCtx(), limit); err != nil {
return nil, err
}

Expand Down Expand Up @@ -2845,7 +2844,7 @@ func (g *gbyResolver) Enter(inNode ast.Node) (ast.Node, bool) {
case *driver.ParamMarkerExpr:
g.isParam = true
if g.exprDepth == 1 && !n.UseAsValueInGbyByClause {
_, isNull, isExpectedType := getUintFromNode(g.ctx, n, false)
_, isNull, isExpectedType := getUintFromNode(g.ctx.GetExprCtx(), n, false)
// For constant uint expression in top level, it should be treated as position expression.
if !isNull && isExpectedType {
return expression.ConstructPositionExpr(n), true
Expand Down Expand Up @@ -2892,7 +2891,7 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) {
return inNode, false
}
case *ast.PositionExpr:
pos, isNull, err := expression.PosFromPositionExpr(g.ctx.GetExprCtx(), g.ctx, v)
pos, isNull, err := expression.PosFromPositionExpr(g.ctx.GetExprCtx(), v)
if err != nil {
g.err = plannererrors.ErrUnknown.GenWithStackByArgs()
}
Expand Down Expand Up @@ -6069,7 +6068,7 @@ func (b *PlanBuilder) buildWindowFunctionFrameBound(_ context.Context, spec *ast
if bound.Type == ast.CurrentRow {
return bound, nil
}
numRows, _, _ := getUintFromNode(b.ctx, boundClause.Expr, false)
numRows, _, _ := getUintFromNode(b.ctx.GetExprCtx(), boundClause.Expr, false)
bound.Num = numRows
return bound, nil
}
Expand Down Expand Up @@ -6391,7 +6390,7 @@ func (b *PlanBuilder) checkOriginWindowFrameBound(bound *ast.FrameBound, spec *a
if bound.Unit != ast.TimeUnitInvalid {
return plannererrors.ErrWindowRowsIntervalUse.GenWithStackByArgs(getWindowName(spec.Name.O))
}
_, isNull, isExpectedType := getUintFromNode(b.ctx, bound.Expr, false)
_, isNull, isExpectedType := getUintFromNode(b.ctx.GetExprCtx(), bound.Expr, false)
if isNull || !isExpectedType {
return plannererrors.ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O))
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4011,7 +4011,7 @@ func (b PlanBuilder) getInsertColExpr(ctx context.Context, insertPlan *Insert, m
RetType: &x.Type,
}
case *driver.ParamMarkerExpr:
outExpr, err = expression.ParamMarkerExpression(b.ctx, x, false)
outExpr, err = expression.ParamMarkerExpression(b.ctx.GetExprCtx(), x, false)
default:
b.curClause = fieldList
// subquery in insert values should not reference upper scope
Expand Down
Loading