diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index e7424941680..c327947580e 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -1562,36 +1562,36 @@ func (op ComparisonExprOperator) ToString() string { } } -func ComparisonExprOperatorFromJson(s string) ComparisonExprOperator { +func ComparisonExprOperatorFromJson(s string) (ComparisonExprOperator, error) { switch s { case EqualStr: - return EqualOp + return EqualOp, nil case JsonLessThanStr: - return LessThanOp + return LessThanOp, nil case JsonGreaterThanStr: - return GreaterThanOp + return GreaterThanOp, nil case JsonLessThanOrEqualStr: - return LessEqualOp + return LessEqualOp, nil case JsonGreaterThanOrEqualStr: - return GreaterEqualOp + return GreaterEqualOp, nil case NotEqualStr: - return NotEqualOp + return NotEqualOp, nil case NullSafeEqualStr: - return NullSafeEqualOp + return NullSafeEqualOp, nil case InStr: - return InOp + return InOp, nil case NotInStr: - return NotInOp + return NotInOp, nil case LikeStr: - return LikeOp + return LikeOp, nil case NotLikeStr: - return NotLikeOp + return NotLikeOp, nil case RegexpStr: - return RegexpOp + return RegexpOp, nil case NotRegexpStr: - return NotRegexpOp + return NotRegexpOp, nil default: - return 0 + return 0, fmt.Errorf("unknown ComparisonExpOperator: %s", s) } } diff --git a/go/vt/vtgate/executor_vexplain_test.go b/go/vt/vtgate/executor_vexplain_test.go index 443370205a9..99eb77c7ed4 100644 --- a/go/vt/vtgate/executor_vexplain_test.go +++ b/go/vt/vtgate/executor_vexplain_test.go @@ -18,7 +18,10 @@ package vtgate import ( "context" + "encoding/json" "fmt" + "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -115,153 +118,53 @@ func TestSimpleVexplainTrace(t *testing.T) { } func TestVExplainKeys(t *testing.T) { - tests := []struct { - query string - expectedRowString string - }{ - { - query: "select count(*), col2 from music group by col2", - expectedRowString: `{ - "statementType": "SELECT", - "groupingColumns": [ - "music.col2" - ], - "selectColumns": [ - "music.col2" - ] -}`, - }, { - query: "select * from user u join user_extra ue on u.id = ue.user_id where u.col1 > 100 and ue.noLimit = 'foo'", - expectedRowString: `{ - "statementType": "SELECT", - "joinColumns": [ - "user.id =", - "user_extra.user_id =" - ], - "filterColumns": [ - "user.col1 gt", - "user_extra.noLimit =" - ] -}`, - }, { - // same as above, but written differently - query: "select * from user_extra ue, user u where ue.noLimit = 'foo' and u.col1 > 100 and ue.user_id = u.id", - expectedRowString: `{ - "statementType": "SELECT", - "joinColumns": [ - "user.id =", - "user_extra.user_id =" - ], - "filterColumns": [ - "user.col1 gt", - "user_extra.noLimit =" - ] -}`, - }, - { - query: "select u.foo, ue.bar, count(*) from user u join user_extra ue on u.id = ue.user_id where u.name = 'John Doe' group by 1, 2", - expectedRowString: `{ - "statementType": "SELECT", - "groupingColumns": [ - "user.foo", - "user_extra.bar" - ], - "joinColumns": [ - "user.id =", - "user_extra.user_id =" - ], - "filterColumns": [ - "user.name =" - ], - "selectColumns": [ - "user.foo", - "user_extra.bar" - ] -}`, - }, - { - query: "select * from (select * from user) as derived where derived.amount > 1000", - expectedRowString: `{ - "statementType": "SELECT" -}`, - }, - { - query: "select name, sum(amount) from user group by name", - expectedRowString: `{ - "statementType": "SELECT", - "groupingColumns": [ - "user.name" - ], - "selectColumns": [ - "user.amount", - "user.name" - ] -}`, - }, - { - query: "select name from user where age > 30", - expectedRowString: `{ - "statementType": "SELECT", - "filterColumns": [ - "user.age gt" - ], - "selectColumns": [ - "user.name" - ] -}`, - }, - { - query: "select * from user where name = 'apa' union select * from user_extra where name = 'monkey'", - expectedRowString: `{ - "statementType": "SELECT", - "filterColumns": [ - "user.name =", - "user_extra.name =" - ] -}`, - }, - { - query: "update user set name = 'Jane Doe' where id = 1", - expectedRowString: `{ - "statementType": "UPDATE", - "filterColumns": [ - "user.id =" - ] -}`, - }, - { - query: "delete from user where order_date < '2023-01-01'", - expectedRowString: `{ - "statementType": "DELETE", - "filterColumns": [ - "user.order_date lt" - ] -}`, - }, - { - query: "select * from user where name between 'A' and 'C'", - expectedRowString: `{ - "statementType": "SELECT", - "filterColumns": [ - "user.name ge", - "user.name le" - ] -}`, - }, + type testCase struct { + Query string `json:"query"` + Expected json.RawMessage `json:"expected"` } + var tests []testCase + data, err := os.ReadFile("testdata/executor_vexplain.json") + require.NoError(t, err) + + err = json.Unmarshal(data, &tests) + require.NoError(t, err) + + var updatedTests []testCase + for _, tt := range tests { - t.Run(tt.query, func(t *testing.T) { + t.Run(tt.Query, func(t *testing.T) { executor, _, _, _, _ := createExecutorEnv(t) session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) - gotResult, err := executor.Execute(context.Background(), nil, "Execute", session, "vexplain keys "+tt.query, nil) + gotResult, err := executor.Execute(context.Background(), nil, "Execute", session, "vexplain keys "+tt.Query, nil) require.NoError(t, err) gotRowString := gotResult.Rows[0][0].ToString() - assert.Equal(t, tt.expectedRowString, gotRowString) + assert.JSONEq(t, string(tt.Expected), gotRowString) + + updatedTests = append(updatedTests, testCase{ + Query: tt.Query, + Expected: json.RawMessage(gotRowString), + }) + if t.Failed() { - fmt.Println(gotRowString) + fmt.Println("Test failed for query:", tt.Query) + fmt.Println("Got result:", gotRowString) } }) } + + // If anything failed, write the updated test cases to a temp file + if t.Failed() { + tempFilePath := filepath.Join(os.TempDir(), "updated_vexplain_keys_tests.json") + fmt.Println("Writing updated tests to:", tempFilePath) + + updatedTestsData, err := json.MarshalIndent(updatedTests, "", "\t") + require.NoError(t, err) + + err = os.WriteFile(tempFilePath, updatedTestsData, 0644) + require.NoError(t, err) + + fmt.Println("Updated tests written to:", tempFilePath) + } } diff --git a/go/vt/vtgate/planbuilder/operators/keys.go b/go/vt/vtgate/planbuilder/operators/keys.go index c16d9b23b63..f5b592b2291 100644 --- a/go/vt/vtgate/planbuilder/operators/keys.go +++ b/go/vt/vtgate/planbuilder/operators/keys.go @@ -77,14 +77,37 @@ func (cu *ColumnUse) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &s); err != nil { return err } - parts := strings.Fields(s) - if len(parts) != 2 { + spaceIdx := strings.LastIndex(s, " ") + if spaceIdx == -1 { return fmt.Errorf("invalid ColumnUse format: %s", s) } - if err := cu.Column.UnmarshalJSON([]byte(`"` + parts[0] + `"`)); err != nil { - return err + + for i := spaceIdx - 1; i >= 0; i-- { + // table.column not like + // table.`tricky not` like + if s[i] == '`' || s[i] == '.' { + break + } + if s[i] == ' ' { + spaceIdx = i + break + } + if i == 0 { + return fmt.Errorf("invalid ColumnUse format: %s", s) + } + } + + colStr, opStr := s[:spaceIdx], s[spaceIdx+1:] + + err := cu.Column.UnmarshalJSON([]byte(`"` + colStr + `"`)) + if err != nil { + return fmt.Errorf("failed to unmarshal column: %w", err) + } + + cu.Uses, err = sqlparser.ComparisonExprOperatorFromJson(strings.ToLower(opStr)) + if err != nil { + return fmt.Errorf("failed to unmarshal operator: %w", err) } - cu.Uses = sqlparser.ComparisonExprOperatorFromJson(strings.ToLower(parts[1])) return nil } @@ -209,5 +232,9 @@ func createColumn(ctx *plancontext.PlanningContext, col *sqlparser.ColName) *Col if table == nil { return nil } - return &Column{Table: table.Name.String(), Name: col.Name.String()} + return &Column{ + // we want the escaped versions of the names + Table: sqlparser.String(table.Name), + Name: sqlparser.String(col.Name), + } } diff --git a/go/vt/vtgate/planbuilder/operators/keys_test.go b/go/vt/vtgate/planbuilder/operators/keys_test.go index 6f53a33da5c..5c60e62c70c 100644 --- a/go/vt/vtgate/planbuilder/operators/keys_test.go +++ b/go/vt/vtgate/planbuilder/operators/keys_test.go @@ -32,7 +32,7 @@ func TestMarshalUnmarshal(t *testing.T) { StatementType: "SELECT", TableName: []string{"users", "orders"}, GroupingColumns: []Column{ - {Table: "", Name: "category"}, + {Table: "orders", Name: "category"}, {Table: "users", Name: "department"}, }, JoinColumns: []ColumnUse{ @@ -40,12 +40,13 @@ func TestMarshalUnmarshal(t *testing.T) { {Column: Column{Table: "orders", Name: "user_id"}, Uses: sqlparser.EqualOp}, }, FilterColumns: []ColumnUse{ - {Column: Column{Table: "", Name: "age"}, Uses: sqlparser.GreaterThanOp}, + {Column: Column{Table: "users", Name: "age"}, Uses: sqlparser.GreaterThanOp}, {Column: Column{Table: "orders", Name: "total"}, Uses: sqlparser.LessThanOp}, + {Column: Column{Table: "orders", Name: "`tricky name not`"}, Uses: sqlparser.InOp}, }, SelectColumns: []Column{ {Table: "users", Name: "name"}, - {Table: "", Name: "email"}, + {Table: "users", Name: "email"}, {Table: "orders", Name: "amount"}, }, } diff --git a/go/vt/vtgate/testdata/executor_vexplain.json b/go/vt/vtgate/testdata/executor_vexplain.json new file mode 100644 index 00000000000..5b70354f158 --- /dev/null +++ b/go/vt/vtgate/testdata/executor_vexplain.json @@ -0,0 +1,132 @@ +[ + { + "query": "select count(*), col2 from music group by col2", + "expected": { + "statementType": "SELECT", + "groupingColumns": [ + "music.col2" + ], + "selectColumns": [ + "music.col2" + ] + } + }, + { + "query": "select * from user u join user_extra ue on u.id = ue.user_id where u.col1 \u003e 100 and ue.noLimit = 'foo'", + "expected": { + "statementType": "SELECT", + "joinColumns": [ + "`user`.id =", + "user_extra.user_id =" + ], + "filterColumns": [ + "`user`.col1 gt", + "user_extra.noLimit =" + ] + } + }, + { + "query": "select * from user_extra ue, user u where ue.noLimit = 'foo' and u.col1 \u003e 100 and ue.user_id = u.id", + "expected": { + "statementType": "SELECT", + "joinColumns": [ + "`user`.id =", + "user_extra.user_id =" + ], + "filterColumns": [ + "`user`.col1 gt", + "user_extra.noLimit =" + ] + } + }, + { + "query": "select u.foo, ue.bar, count(*) from user u join user_extra ue on u.id = ue.user_id where u.name = 'John Doe' group by 1, 2", + "expected": { + "statementType": "SELECT", + "groupingColumns": [ + "`user`.foo", + "user_extra.bar" + ], + "joinColumns": [ + "`user`.id =", + "user_extra.user_id =" + ], + "filterColumns": [ + "`user`.`name` =" + ], + "selectColumns": [ + "`user`.foo", + "user_extra.bar" + ] + } + }, + { + "query": "select * from (select * from user) as derived where derived.amount \u003e 1000", + "expected": { + "statementType": "SELECT" + } + }, + { + "query": "select name, sum(amount) from user group by name", + "expected": { + "statementType": "SELECT", + "groupingColumns": [ + "`user`.`name`" + ], + "selectColumns": [ + "`user`.`name`", + "`user`.amount" + ] + } + }, + { + "query": "select name from user where age \u003e 30", + "expected": { + "statementType": "SELECT", + "filterColumns": [ + "`user`.age gt" + ], + "selectColumns": [ + "`user`.`name`" + ] + } + }, + { + "query": "select * from user where name = 'apa' union select * from user_extra where name = 'monkey'", + "expected": { + "statementType": "SELECT", + "filterColumns": [ + "`user`.`name` =", + "user_extra.`name` =" + ] + } + }, + { + "query": "update user set name = 'Jane Doe' where id = 1", + "expected": { + "statementType": "UPDATE", + "filterColumns": [ + "`user`.id =" + ] + } + }, + { + "query": "delete from user where order_date \u003c '2023-01-01'", + "expected": { + "statementType": "DELETE", + "filterColumns": [ + "`user`.order_date lt" + ] + } + }, + { + "query": "select * from user where name between 'A' and 'C'", + "expected": { + "statementType": "SELECT", + "filterColumns": [ + "`user`.`name` ge", + "`user`.`name` le" + ] + } + } +] \ No newline at end of file