From 65d740f5f84a1e3db86bb59339fc4ed2ed07b00c Mon Sep 17 00:00:00 2001 From: EricZequan <110292382+EricZequan@users.noreply.github.com> Date: Sun, 29 Sep 2024 19:51:19 +0800 Subject: [PATCH] planner: Display truncate vector in EXPLAIN (#55934) ref pingcap/tidb#54245 --- pkg/executor/importer/import.go | 12 ++-- pkg/executor/testdata/prepare_suite_out.json | 24 +++---- pkg/expression/constant.go | 4 +- pkg/expression/explain.go | 12 +--- pkg/expression/integration_test/BUILD.bazel | 3 +- .../integration_test/integration_test.go | 62 +++++++++++++++++++ pkg/expression/util_test.go | 1 + pkg/types/datum.go | 31 ++++++++++ pkg/types/vector.go | 33 ++++++++++ 9 files changed, 153 insertions(+), 29 deletions(-) diff --git a/pkg/executor/importer/import.go b/pkg/executor/importer/import.go index c4556b39f4790..6ef74ba91f0a3 100644 --- a/pkg/executor/importer/import.go +++ b/pkg/executor/importer/import.go @@ -16,6 +16,7 @@ package importer import ( "context" + "fmt" "io" "math" "net/url" @@ -810,13 +811,14 @@ func (p *Plan) initParameters(plan *plannercore.ImportInto) error { setClause = sb.String() } optionMap := make(map[string]any, len(plan.Options)) - var evalCtx expression.EvalContext - if plan.SCtx() != nil { - evalCtx = plan.SCtx().GetExprCtx().GetEvalCtx() - } for _, opt := range plan.Options { if opt.Value != nil { - val := opt.Value.StringWithCtx(evalCtx, errors.RedactLogDisable) + // The option attached to the import statement here are all + // parameters entered by the user. TiDB will process the + // parameters entered by the user as constant. so we can + // directly convert it to constant. 
+ cons := opt.Value.(*expression.Constant) + val := fmt.Sprintf("%v", cons.Value.GetValue()) if opt.Name == cloudStorageURIOption { val = ast.RedactURL(val) } diff --git a/pkg/executor/testdata/prepare_suite_out.json b/pkg/executor/testdata/prepare_suite_out.json index 2f84f976d41ce..69271a0bc6195 100644 --- a/pkg/executor/testdata/prepare_suite_out.json +++ b/pkg/executor/testdata/prepare_suite_out.json @@ -222,7 +222,7 @@ } ], "Plan": [ - "Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(len:78)->Column#1", + "Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(10,0) BINARY)->Column#1", "└─TableDual_4 1.00 root rows:1" ], "LastPlanUseCache": "0", @@ -290,7 +290,7 @@ } ], "Plan": [ - "Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decima(len:79)->Column#1", + "Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decimal(10,0) BINARY)->Column#1", "└─TableDual_4 1.00 root rows:1" ], "LastPlanUseCache": "0", @@ -363,7 +363,7 @@ } ], "Plan": [ - "Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(len:78)->Column#1", + "Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(10,0) BINARY)->Column#1", "└─TableDual_4 1.00 root rows:1" ], "LastPlanUseCache": "0", @@ -431,7 +431,7 @@ } ], "Plan": [ - "Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decima(len:79)->Column#1", + "Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decimal(10,0) BINARY)->Column#1", "└─TableDual_4 1.00 root rows:1" ], "LastPlanUseCache": "0", @@ -504,7 +504,7 @@ } ], "Plan": [ - "Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(len:77)->Column#1", + "Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(5,4) BINARY)->Column#1", "└─TableDual_4 1.00 root 
rows:1" ], "LastPlanUseCache": "0", @@ -572,7 +572,7 @@ } ], "Plan": [ - "Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decima(len:78)->Column#1", + "Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decimal(5,4) BINARY)->Column#1", "└─TableDual_4 1.00 root rows:1" ], "LastPlanUseCache": "0", @@ -645,7 +645,7 @@ } ], "Plan": [ - "Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(len:79)->Column#1", + "Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(64,30) BINARY)->Column#1", "└─TableDual_4 1.00 root rows:1" ], "LastPlanUseCache": "0", @@ -713,7 +713,7 @@ } ], "Plan": [ - "Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decima(len:80)->Column#1", + "Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decimal(64,30) BINARY)->Column#1", "└─TableDual_4 1.00 root rows:1" ], "LastPlanUseCache": "0", @@ -786,7 +786,7 @@ } ], "Plan": [ - "Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(len:78)->Column#1", + "Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(15,5) BINARY)->Column#1", "└─TableDual_4 1.00 root rows:1" ], "LastPlanUseCache": "0", @@ -854,7 +854,7 @@ } ], "Plan": [ - "Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decima(len:79)->Column#1", + "Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decimal(15,5) BINARY)->Column#1", "└─TableDual_4 1.00 root rows:1" ], "LastPlanUseCache": "0", @@ -927,7 +927,7 @@ } ], "Plan": [ - "Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(len:77)->Column#1", + "Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(5,5) BINARY)->Column#1", "└─TableDual_4 1.00 root rows:1" ], "LastPlanUseCache": "0", @@ -995,7 +995,7 
@@ } ], "Plan": [ - "Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decima(len:78)->Column#1", + "Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decimal(5,5) BINARY)->Column#1", "└─TableDual_4 1.00 root rows:1" ], "LastPlanUseCache": "0", diff --git a/pkg/expression/constant.go b/pkg/expression/constant.go index 7b2599247f9c2..c39684f3332c3 100644 --- a/pkg/expression/constant.go +++ b/pkg/expression/constant.go @@ -156,9 +156,9 @@ func (c *Constant) StringWithCtx(ctx ParamValues, redact string) string { return c.DeferredExpr.StringWithCtx(ctx, redact) } if redact == perrors.RedactLogDisable { - return fmt.Sprintf("%v", c.Value.GetValue()) + return c.Value.TruncatedStringify() } else if redact == perrors.RedactLogMarker { - return fmt.Sprintf("‹%v›", c.Value.GetValue()) + return fmt.Sprintf("‹%s›", c.Value.TruncatedStringify()) } return "?" } diff --git a/pkg/expression/explain.go b/pkg/expression/explain.go index 6ff136115abd1..763a2b051259a 100644 --- a/pkg/expression/explain.go +++ b/pkg/expression/explain.go @@ -173,9 +173,9 @@ func (expr *Constant) format(dt types.Datum) string { return "NULL" case types.KindString, types.KindBytes, types.KindMysqlEnum, types.KindMysqlSet, types.KindMysqlJSON, types.KindBinaryLiteral, types.KindMysqlBit: - return fmt.Sprintf("\"%v\"", dt.GetValue()) + return fmt.Sprintf("\"%s\"", dt.TruncatedStringify()) } - return fmt.Sprintf("%v", dt.GetValue()) + return dt.TruncatedStringify() } // ExplainExpressionList generates explain information for a list of expressions. 
@@ -192,13 +192,7 @@ func ExplainExpressionList(ctx EvalContext, exprs []Expression, schema *Schema, } case *Constant: v := expr.StringWithCtx(ctx, errors.RedactLogDisable) - length := 64 - if len(v) < length { - redact.WriteRedact(builder, v, redactMode) - } else { - redact.WriteRedact(builder, v[:length], redactMode) - fmt.Fprintf(builder, "(len:%d)", len(v)) - } + redact.WriteRedact(builder, v, redactMode) builder.WriteString("->") builder.WriteString(schema.Columns[i].StringWithCtx(ctx, redactMode)) default: diff --git a/pkg/expression/integration_test/BUILD.bazel b/pkg/expression/integration_test/BUILD.bazel index 1865edc9f1838..fa5d0d093b599 100644 --- a/pkg/expression/integration_test/BUILD.bazel +++ b/pkg/expression/integration_test/BUILD.bazel @@ -8,7 +8,7 @@ go_test( "main_test.go", ], flaky = True, - shard_count = 45, + shard_count = 46, deps = [ "//pkg/config", "//pkg/domain", @@ -34,6 +34,7 @@ go_test( "//pkg/types", "//pkg/util/codec", "//pkg/util/collate", + "//pkg/util/plancodec", "//pkg/util/sem", "//pkg/util/timeutil", "//pkg/util/versioninfo", diff --git a/pkg/expression/integration_test/integration_test.go b/pkg/expression/integration_test/integration_test.go index e6b69fc735d0a..64dd807171fbd 100644 --- a/pkg/expression/integration_test/integration_test.go +++ b/pkg/expression/integration_test/integration_test.go @@ -53,6 +53,7 @@ import ( "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/codec" "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/plancodec" "github.com/pingcap/tidb/pkg/util/sem" "github.com/pingcap/tidb/pkg/util/versioninfo" "github.com/stretchr/testify/assert" @@ -321,6 +322,67 @@ func TestVectorColumnInfo(t *testing.T) { tk.MustGetErrMsg("create table t(embedding VECTOR(16384))", "vector cannot have more than 16383 dimensions") } +func TestVectorConstantExplain(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + 
tk.MustExec("CREATE TABLE t(c VECTOR);") + tk.MustQuery(`EXPLAIN SELECT VEC_COSINE_DISTANCE(c, '[1,2,3,4,5,6,7,8,9,10,11]') FROM t;`).Check(testkit.Rows( + "Projection_3 10000.00 root vec_cosine_distance(test.t.c, [1,2,3,4,5,(6 more)...])->Column#3", + "└─TableReader_5 10000.00 root data:TableFullScan_4", + " └─TableFullScan_4 10000.00 cop[tikv] table:t keep order:false, stats:pseudo", + )) + tk.MustQuery(`EXPLAIN SELECT VEC_COSINE_DISTANCE(c, VEC_FROM_TEXT('[1,2,3,4,5,6,7,8,9,10,11]')) FROM t;`).Check(testkit.Rows( + "Projection_3 10000.00 root vec_cosine_distance(test.t.c, [1,2,3,4,5,(6 more)...])->Column#3", + "└─TableReader_5 10000.00 root data:TableFullScan_4", + " └─TableFullScan_4 10000.00 cop[tikv] table:t keep order:false, stats:pseudo", + )) + tk.MustQuery(`EXPLAIN SELECT VEC_COSINE_DISTANCE(c, '[1,2,3,4,5,6,7,8,9,10,11]') AS d FROM t ORDER BY d LIMIT 10;`).Check(testkit.Rows( + "Projection_6 10.00 root vec_cosine_distance(test.t.c, [1,2,3,4,5,(6 more)...])->Column#3", + "└─Projection_13 10.00 root test.t.c", + " └─TopN_7 10.00 root Column#4, offset:0, count:10", + " └─Projection_14 10.00 root test.t.c, vec_cosine_distance(test.t.c, [1,2,3,4,5,(6 more)...])->Column#4", + " └─TableReader_12 10.00 root data:TopN_11", + " └─TopN_11 10.00 cop[tikv] vec_cosine_distance(test.t.c, [1,2,3,4,5,(6 more)...]), offset:0, count:10", + " └─TableFullScan_10 10000.00 cop[tikv] table:t keep order:false, stats:pseudo", + )) + + // Prepare a large Vector string + vb := strings.Builder{} + vb.WriteString("[") + for i := 0; i < 100; i++ { + if i > 0 { + vb.WriteString(",") + } + vb.WriteString("100") + } + vb.WriteString("]") + + stmtID, _, _, err := tk.Session().PrepareStmt("SELECT VEC_COSINE_DISTANCE(c, ?) 
FROM t") + require.Nil(t, err) + rs, err := tk.Session().ExecutePreparedStmt(context.Background(), stmtID, expression.Args2Expressions4Test(vb.String())) + require.NoError(t, err) + + p, ok := tk.Session().GetSessionVars().StmtCtx.GetPlan().(base.Plan) + require.True(t, ok) + + flat := plannercore.FlattenPhysicalPlan(p, true) + encodedPlanTree := plannercore.EncodeFlatPlan(flat) + planTree, err := plancodec.DecodePlan(encodedPlanTree) + require.NoError(t, err) + fmt.Println(planTree) + fmt.Println("++++") + require.Equal(t, strings.Join([]string{ + ` id task estRows operator info actRows execution info memory disk`, + ` Projection_3 root 10000 vec_cosine_distance(test.t.c, cast([100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100...(len:401), vector))->Column#3 0 time:0s, loops:0 0 Bytes N/A`, + ` └─TableReader_5 root 10000 data:TableFullScan_4 0 time:0s, loops:0 0 Bytes N/A`, + ` └─TableFullScan_4 cop[tikv] 10000 table:t, keep order:false, stats:pseudo 0 N/A N/A`, + }, "\n"), planTree) + + // No need to check result at all. 
+ tk.ResultSetToResult(rs, fmt.Sprintf("%v", rs)) +} + func TestFixedVector(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) diff --git a/pkg/expression/util_test.go b/pkg/expression/util_test.go index 512bcc4207cac..31ac9b26148e0 100644 --- a/pkg/expression/util_test.go +++ b/pkg/expression/util_test.go @@ -563,6 +563,7 @@ func (m *MockExpr) VecEvalJSON(ctx EvalContext, input *chunk.Chunk, result *chun } func (m *MockExpr) StringWithCtx(ParamValues, string) string { return "" } + func (m *MockExpr) Eval(ctx EvalContext, row chunk.Row) (types.Datum, error) { return types.NewDatum(m.i), m.err } diff --git a/pkg/types/datum.go b/pkg/types/datum.go index 2bb9ee740b772..5eb5bb873aea1 100644 --- a/pkg/types/datum.go +++ b/pkg/types/datum.go @@ -552,6 +552,37 @@ func (d *Datum) GetValue() any { } } +// TruncatedStringify returns the %v representation of the datum +// but truncated (for example, for strings, only the first 64 bytes are printed). +// This function is useful in contexts like EXPLAIN. +func (d *Datum) TruncatedStringify() string { + const maxLen = 64 + + switch d.k { + case KindString, KindBytes: + str := d.GetString() + if len(str) > maxLen { + // This returns the truncated string with as few + // allocations as possible. + return fmt.Sprintf("%s...(len:%d)", str[:maxLen], len(str)) + } + return str + case KindMysqlJSON: + // For now we can only stringify then truncate. + str := d.GetMysqlJSON().String() + if len(str) > maxLen { + return fmt.Sprintf("%s...(len:%d)", str[:maxLen], len(str)) + } + return str + case KindVectorFloat32: + // Vector supports native efficient truncation. + return d.GetVectorFloat32().TruncatedString() + default: + // For other types, no truncation is needed. + return fmt.Sprintf("%v", d.GetValue()) + } +} + // SetValueWithDefaultCollation sets any kind of value. 
func (d *Datum) SetValueWithDefaultCollation(val any) { switch x := val.(type) { diff --git a/pkg/types/vector.go b/pkg/types/vector.go index 80560631c962c..127b237ba3abf 100644 --- a/pkg/types/vector.go +++ b/pkg/types/vector.go @@ -16,6 +16,7 @@ package types import ( "encoding/binary" + "fmt" "math" "strconv" "unsafe" @@ -93,6 +94,38 @@ func (v VectorFloat32) Elements() []float32 { return unsafe.Slice((*float32)(unsafe.Pointer(&v.data[4])), l) } +// TruncatedString returns the vector in a truncated string form, which is useful for +// outputting in logs or EXPLAIN statements. +func (v VectorFloat32) TruncatedString() string { + const ( + maxDisplayElements = 5 + ) + + truncatedElements := 0 + elements := v.Elements() + + if len(elements) > maxDisplayElements { + truncatedElements = len(elements) - maxDisplayElements + elements = elements[:maxDisplayElements] + } + + buf := make([]byte, 0, 2+v.Len()*2) + buf = append(buf, '[') + for i, v := range elements { + if i > 0 { + buf = append(buf, ","...) + } + buf = strconv.AppendFloat(buf, float64(v), 'g', 2, 32) + } + if truncatedElements > 0 { + buf = append(buf, fmt.Sprintf(",(%d more)...", truncatedElements)...) + } + buf = append(buf, ']') + + // buf is not used elsewhere, so it's safe to just cast to string + return unsafe.String(unsafe.SliceData(buf), len(buf)) +} + // String returns a string representation of the vector, which can be parsed later. func (v VectorFloat32) String() string { elements := v.Elements()