diff --git a/go/test/endtoend/vtgate/queries/random/query_gen.go b/go/test/endtoend/vtgate/queries/random/query_gen.go index 9415560b132..c3a1b91ae36 100644 --- a/go/test/endtoend/vtgate/queries/random/query_gen.go +++ b/go/test/endtoend/vtgate/queries/random/query_gen.go @@ -19,254 +19,375 @@ package random import ( "fmt" "math/rand" - "time" + + "vitess.io/vitess/go/slice" "golang.org/x/exp/slices" - "vitess.io/vitess/go/slice" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/sqlparser" ) // this file contains the structs and functions to generate random queries +// to test only a particular type of query, delete the corresponding testFailingQueries clause +// there should be a comment indicating the type of query being disabled // if true then known failing query types are still generated by randomQuery() const testFailingQueries = false type ( + // selectGenerator generates select statements + selectGenerator struct { + r *rand.Rand + genConfig sqlparser.ExprGeneratorConfig + maxTables int + maxAggrs int + maxGBs int + schemaTables []tableT + sel *sqlparser.Select + } + + // queryGenerator generates queries, which can either be unions or select statements + queryGenerator struct { + stmt sqlparser.SelectStatement + selGen *selectGenerator + } + column struct { + name string + // TODO: perhaps remove tableName and always pass columns through a tableT tableName string - name string typ string } + tableT struct { // the tableT struct can be used to represent the schema of a table or a derived table - // in the former case name will be a sqlparser.TableName, in the latter a sqlparser.DerivedTable - // in order to create a query with a derived table, its AST form is retrieved from name - // once the derived table is aliased, name is replaced by a sqlparser.TableName with that alias - name sqlparser.SimpleTableExpr - cols []column + // in the former case tableExpr will be a sqlparser.TableName, in the latter a sqlparser.DerivedTable + // in order to create a query with a derived table, its AST form is retrieved from tableExpr + // once the derived table is aliased, alias is updated + tableExpr sqlparser.SimpleTableExpr + alias string + cols []column } ) var _ sqlparser.ExprGenerator = (*tableT)(nil) +var _ sqlparser.ExprGenerator = (*column)(nil) +var _ sqlparser.QueryGenerator = (*selectGenerator)(nil) +var _ sqlparser.QueryGenerator = (*queryGenerator)(nil) + +func newQueryGenerator(r *rand.Rand, genConfig sqlparser.ExprGeneratorConfig, maxTables, maxAggrs, maxGBs int, schemaTables []tableT) *queryGenerator { + return &queryGenerator{ + selGen: newSelectGenerator(r, genConfig, maxTables, maxAggrs, maxGBs, schemaTables), + } +} -func (t *tableT) typeExpr(typ string) sqlparser.Expr { - tableCopy := t.clone() +func newSelectGenerator(r *rand.Rand, genConfig sqlparser.ExprGeneratorConfig, maxTables, maxAggrs, maxGBs int, schemaTables []tableT) *selectGenerator { + if maxTables <= 0 { + log.Fatalf("maxTables must be at least 1, currently %d\n", maxTables) + } - for len(tableCopy.cols) > 0 { - idx := rand.Intn(len(tableCopy.cols)) - randCol := tableCopy.cols[idx] - if randCol.typ == typ { - newTableName := "" - if tName, ok := tableCopy.name.(sqlparser.TableName); ok { - newTableName = sqlparser.String(tName.Name) - } - return sqlparser.NewColNameWithQualifier(randCol.name, sqlparser.NewTableName(newTableName)) - } + return &selectGenerator{ + r: r, + genConfig: genConfig, + maxTables: maxTables, + maxAggrs: maxAggrs, + maxGBs: maxGBs, + schemaTables: schemaTables, + sel: 
&sqlparser.Select{},
+	}
+}

-		// delete randCol from table.columns
-		tableCopy.cols[idx] = tableCopy.cols[len(tableCopy.cols)-1]
-		tableCopy.cols = tableCopy.cols[:len(tableCopy.cols)-1]
+// getColumnName returns tableName.name (if tableName is nonempty), otherwise name
+func (c *column) getColumnName() string {
+	var columnName string
+	if c.tableName != "" {
+		columnName += c.tableName + "."
 	}
-	return nil
+	return columnName + c.name
 }

-func (t *tableT) IntExpr() sqlparser.Expr {
-	// TODO: better way to check if int type?
-	return t.typeExpr("bigint")
+// getASTExpr returns the AST representation of a column
+func (c *column) getASTExpr() sqlparser.Expr {
+	return sqlparser.NewColNameWithQualifier(c.name, sqlparser.NewTableName(c.tableName))
 }

-func (t *tableT) StringExpr() sqlparser.Expr {
-	return t.typeExpr("varchar")
+// getName returns the alias if it is nonempty
+// if the alias is empty and tableExpr is of type sqlparser.TableName,
+// then getName returns Name from tableExpr
+// otherwise getName returns an empty string
+func (t *tableT) getName() string {
+	if t.alias != "" {
+		return t.alias
+	} else if tName, ok := t.tableExpr.(sqlparser.TableName); ok {
+		return sqlparser.String(tName.Name)
+	}
+
+	return ""
 }

-// setName sets the alias for t, as well as setting the tableName for all columns in cols
-func (t *tableT) setName(newName string) {
-	t.name = sqlparser.NewTableName(newName)
+// setAlias sets the alias for t, as well as setting the tableName for all columns in cols
+func (t *tableT) setAlias(newName string) {
+	t.alias = newName
 	for i := range t.cols {
 		t.cols[i].tableName = newName
 	}
 }

-// setColumns sets the columns of t, and automatically assigns tableName
-// this makes it unnatural (but still possible as cols is exportable) to modify tableName
-func (t *tableT) setColumns(col ...column) {
-	t.cols = nil
-	t.addColumns(col...)
-} - -// addColumns adds columns to t, and automatically assigns tableName -// this makes it unnatural (but still possible as cols is exportable) to modify tableName +// addColumns adds columns to t, and automatically assigns each column.tableName +// this makes it unnatural to modify tableName func (t *tableT) addColumns(col ...column) { for i := range col { - // only change the Col's tableName if t is of type tableName - if tName, ok := t.name.(sqlparser.TableName); ok { - col[i].tableName = sqlparser.String(tName.Name) - } - + col[i].tableName = t.getName() t.cols = append(t.cols, col[i]) } } -// clone returns a deep copy of t func (t *tableT) clone() *tableT { return &tableT{ - name: t.name, - cols: slices.Clone(t.cols), + tableExpr: sqlparser.CloneSimpleTableExpr(t.tableExpr), + alias: t.alias, + cols: slices.Clone(t.cols), } } -// getColumnName returns tableName.name -func (c *column) getColumnName() string { - return fmt.Sprintf("%s.%s", c.tableName, c.name) +func (c *column) Generate(_ *rand.Rand, genConfig sqlparser.ExprGeneratorConfig) sqlparser.Expr { + if c.typ == genConfig.Type || genConfig.Type == "" { + return c.getASTExpr() + } + + return nil +} + +func (t *tableT) Generate(r *rand.Rand, genConfig sqlparser.ExprGeneratorConfig) sqlparser.Expr { + colsCopy := slices.Clone(t.cols) + + for len(colsCopy) > 0 { + idx := r.Intn(len(colsCopy)) + randCol := colsCopy[idx] + if randCol.typ == genConfig.Type || genConfig.Type == "" { + return randCol.getASTExpr() + } + + // delete randCol from colsCopy + colsCopy[idx] = colsCopy[len(colsCopy)-1] + colsCopy = colsCopy[:len(colsCopy)-1] + } + + return nil +} + +// Generate generates a subquery based on sg +// TODO: currently unused; generate random expressions with union +func (sg *selectGenerator) Generate(r *rand.Rand, genConfig sqlparser.ExprGeneratorConfig) sqlparser.Expr { + var schemaTablesCopy []tableT + for _, tbl := range sg.schemaTables { + schemaTablesCopy = append(schemaTablesCopy, *tbl.clone()) + } + + newSG := newQueryGenerator(r, genConfig, sg.maxTables, sg.maxAggrs, sg.maxGBs, schemaTablesCopy) + newSG.randomQuery() + + return &sqlparser.Subquery{Select: newSG.selGen.sel} +} + +// Generate generates a subquery based on qg +func (qg *queryGenerator) Generate(r *rand.Rand, genConfig sqlparser.ExprGeneratorConfig) sqlparser.Expr { + var schemaTablesCopy []tableT + for _, tbl := range qg.selGen.schemaTables { + schemaTablesCopy = append(schemaTablesCopy, *tbl.clone()) + } + + newQG := newQueryGenerator(r, genConfig, qg.selGen.maxTables, qg.selGen.maxAggrs, qg.selGen.maxGBs, schemaTablesCopy) + newQG.randomQuery() + + return &sqlparser.Subquery{Select: newQG.stmt} +} + +func (sg *selectGenerator) IsQueryGenerator() {} +func (qg *queryGenerator) IsQueryGenerator() {} + +func (qg *queryGenerator) randomQuery() { + if qg.selGen.r.Intn(10) < 1 && testFailingQueries { + qg.createUnion() + } else { + qg.selGen.randomSelect() + qg.stmt = qg.selGen.sel + } +} + +// createUnion creates a simple UNION or UNION ALL; no LIMIT or ORDER BY +func (qg *queryGenerator) createUnion() { + union := &sqlparser.Union{} + + if qg.selGen.r.Intn(2) < 1 { + union.Distinct = true + } + + // specify between 1-4 columns + qg.selGen.genConfig.NumCols = qg.selGen.r.Intn(4) + 1 + + qg.randomQuery() + union.Left = qg.stmt + qg.randomQuery() + union.Right = qg.stmt + + qg.stmt = union } -func randomQuery(schemaTables []tableT, maxAggrs, maxGroupBy int) *sqlparser.Select { - sel := &sqlparser.Select{} - sel.SetComments(sqlparser.Comments{"/*vt+ PLANNER=Gen4 
*/"}) +func (sg *selectGenerator) randomSelect() { + // make sure the random expressions can generally not contain aggregates; change appropriately + sg.genConfig = sg.genConfig.CannotAggregateConfig() + + sg.sel = &sqlparser.Select{} + sg.sel.SetComments(sqlparser.Comments{"/*vt+ PLANNER=Gen4 */"}) // select distinct (fails with group by bigint) - isDistinct := rand.Intn(2) < 1 + isDistinct := sg.r.Intn(2) < 1 if isDistinct { - sel.MakeDistinct() + sg.sel.MakeDistinct() } // create both tables and join at the same time since both occupy the from clause - tables, isJoin := createTablesAndJoin(schemaTables, sel) + tables, isJoin := sg.createTablesAndJoin() + + // canAggregate determines if the query will have + // aggregate columns, group by, and having + canAggregate := sg.r.Intn(4) < 3 var ( - groupBy sqlparser.GroupBy - groupSelectExprs sqlparser.SelectExprs - grouping []column + grouping, aggregates []column + newTable tableT ) // TODO: distinct makes vitess think there is grouping on aggregation columns - if testFailingQueries || !isDistinct { - groupBy, groupSelectExprs, grouping = createGroupBy(tables, maxGroupBy) - sel.AddSelectExprs(groupSelectExprs) - sel.GroupBy = groupBy - - } + if canAggregate { + if testFailingQueries || !isDistinct { + // group by + if !sg.genConfig.SingleRow { + grouping = sg.createGroupBy(tables) + } + } - aggrExprs, aggregates := createAggregations(tables, maxAggrs) - sel.AddSelectExprs(aggrExprs) + // having + isHaving := sg.r.Intn(2) < 1 + // TODO: having creates a lot of results mismatched + if isHaving && testFailingQueries { + sg.createHavingPredicates(grouping) + } - // can add both aggregate and grouping columns to order by - // TODO: order fails with distinct and outer joins - isOrdered := rand.Intn(2) < 1 && (!isDistinct || testFailingQueries) && (!isJoin || testFailingQueries) && testFailingQueries - // TODO: order by fails a lot; probably related to the previously passing query - // TODO: should be fixed soon - if isOrdered { - sel.OrderBy = createOrderBy(groupBy, aggrExprs) - } + // alias the grouping columns + grouping = sg.aliasGroupingColumns(grouping) - // where - sel.AddWhere(sqlparser.AndExpressions(createWherePredicates(tables, false)...)) + // aggregation columns + aggregates = sg.createAggregations(tables) - // random predicate expression - // TODO: random expressions cause a lot of failures - if rand.Intn(2) < 1 && testFailingQueries { - predRandomExpr := getRandomExpr(tables) - sel.AddWhere(predRandomExpr) - } - - // having - isHaving := rand.Intn(2) < 1 - if isHaving { - sel.AddHaving(sqlparser.AndExpressions(createHavingPredicates(tables)...)) - if rand.Intn(2) < 1 && testFailingQueries { - // TODO: having can only contain aggregate or grouping columns in mysql, works fine in vitess - // TODO: Can fix this by putting only the table with the grouping and aggregates column (newTable) - sel.AddHaving(sqlparser.AndExpressions(createWherePredicates(tables, false)...)) - } + // add the grouping and aggregation to newTable + newTable.addColumns(grouping...) + newTable.addColumns(aggregates...) 
} - // TODO: use sqlparser.ExprGenerator to generate a random expression with aggregation functions - // only add a limit if the grouping columns are ordered - // TODO: limit fails a lot - if rand.Intn(2) < 1 && (isOrdered || len(groupBy) == 0) && testFailingQueries { - sel.Limit = createLimit() - } + // where + sg.createWherePredicates(tables) - var newTable tableT // add random expression to select // TODO: random expressions cause a lot of failures - isRandomExpr := rand.Intn(2) < 1 && testFailingQueries - var ( - randomExpr sqlparser.Expr - typ string - ) + isRandomExpr := sg.r.Intn(2) < 1 && testFailingQueries + // TODO: selecting a random expression potentially with columns creates // TODO: only_full_group_by related errors in Vitess - if testFailingQueries { - randomExpr = getRandomExpr(tables) - } else { - randomExpr = getRandomExpr(nil) + var exprGenerators []sqlparser.ExprGenerator + if canAggregate && testFailingQueries { + exprGenerators = slice.Map(tables, func(t tableT) sqlparser.ExprGenerator { return &t }) + // add scalar subqueries + if sg.r.Intn(10) < 1 { + exprGenerators = append(exprGenerators, sg) + } } // make sure we have at least one select expression - if isRandomExpr || len(sel.SelectExprs) == 0 { + for isRandomExpr || len(sg.sel.SelectExprs) == 0 { + // TODO: if the random expression is an int literal, + // TODO: and if the query is (potentially) an aggregate query, + // TODO: then we must group by the random expression, + // TODO: but we cannot do this for int literals, + // TODO: so we loop until we get a non-int-literal random expression + // TODO: this is necessary because grouping by the alias (crandom0) currently fails on vitess + randomExpr := sg.getRandomExpr(exprGenerators...) + literal, ok := randomExpr.(*sqlparser.Literal) + isIntLiteral := ok && literal.Type == sqlparser.IntVal + if isIntLiteral && canAggregate { + continue + } + // TODO: select distinct [literal] fails - sel.Distinct = false + sg.sel.Distinct = false - sel.SelectExprs = append(sel.SelectExprs, sqlparser.NewAliasedExpr(randomExpr, "crandom0")) - newTable.addColumns(column{ - name: "crandom0", - typ: typ, - }) + // alias randomly + col := sg.randomlyAlias(randomExpr, "crandom0") + newTable.addColumns(col) // make sure to add the random expression to group by for only_full_group_by - sel.AddGroupBy(randomExpr) + if canAggregate { + sg.sel.AddGroupBy(randomExpr) + } + + break + } + + // can add both aggregate and grouping columns to order by + // TODO: order fails with distinct and outer joins + isOrdered := sg.r.Intn(2) < 1 && (!isDistinct || testFailingQueries) && (!isJoin || testFailingQueries) + if isOrdered || (!canAggregate && sg.genConfig.SingleRow) /* TODO: might be redundant */ { + sg.createOrderBy() } - // add them to newTable - newTable.addColumns(grouping...) - newTable.addColumns(aggregates...) 
+	// only add a limit if there is an ordering
+	// TODO: limit fails a lot
+	isLimit := sg.r.Intn(2) < 1 && len(sg.sel.OrderBy) > 0 && testFailingQueries
+	if isLimit || (!canAggregate && sg.genConfig.SingleRow) /* TODO: might be redundant */ {
+		sg.createLimit()
+	}
+
+	// this makes sure the generated query has the correct number of columns (sg.genConfig.NumCols)
+	newTable = sg.matchNumCols(tables, newTable, canAggregate)

 	// add new table to schemaTables
-	newTable.name = sqlparser.NewDerivedTable(false, sel)
-	schemaTables = append(schemaTables, newTable)
+	newTable.tableExpr = sqlparser.NewDerivedTable(false, sg.sel)
+	sg.schemaTables = append(sg.schemaTables, newTable)

 	// derived tables (partially unsupported)
-	// TODO: derived tables fails a lot
-	if rand.Intn(10) < 1 && testFailingQueries {
-		sel = randomQuery(schemaTables, 3, 3)
+	if sg.r.Intn(10) < 1 {
+		sg.randomSelect()
 	}
-
-	return sel
 }

-func createTablesAndJoin(schemaTables []tableT, sel *sqlparser.Select) ([]tableT, bool) {
+func (sg *selectGenerator) createTablesAndJoin() ([]tableT, bool) {
 	var tables []tableT
-	// add at least one of original emp/dept tables for now because derived tables have nil columns
-	tables = append(tables, schemaTables[rand.Intn(2)])
+	// add at least one of original emp/dept tables
+	tables = append(tables, sg.schemaTables[sg.r.Intn(2)])

-	sel.From = append(sel.From, newAliasedTable(tables[0], "tbl0"))
-	tables[0].setName("tbl0")
+	tables[0].setAlias("tbl0")
+	sg.sel.From = append(sg.sel.From, newAliasedTable(tables[0], "tbl0"))

-	numTables := rand.Intn(len(schemaTables))
+	numTables := sg.r.Intn(sg.maxTables)
 	for i := 0; i < numTables; i++ {
-		tables = append(tables, randomEl(schemaTables))
-		sel.From = append(sel.From, newAliasedTable(tables[i+1], fmt.Sprintf("tbl%d", i+1)))
-		tables[i+1].setName(fmt.Sprintf("tbl%d", i+1))
+		tables = append(tables, randomEl(sg.r, sg.schemaTables))
+		alias := fmt.Sprintf("tbl%d", i+1)
+		sg.sel.From = append(sg.sel.From, newAliasedTable(tables[i+1], alias))
+		tables[i+1].setAlias(alias)
 	}

-	// TODO: outer joins produce mismatched results
-	isJoin := rand.Intn(2) < 1 && testFailingQueries
+	// TODO: outer joins produce results mismatched
+	isJoin := sg.r.Intn(2) < 1 && testFailingQueries
 	if isJoin {
-		newTable := randomEl(schemaTables)
+		// TODO: do nested joins
+		newTable := randomEl(sg.r, sg.schemaTables)
+		alias := fmt.Sprintf("tbl%d", numTables+1)
+		newTable.setAlias(alias)
 		tables = append(tables, newTable)

-		// create the join before aliasing
-		newJoinTableExpr := createJoin(tables, sel)
-
-		// alias
-		tables[numTables+1].setName(fmt.Sprintf("tbl%d", numTables+1))
-
-		// create the condition after aliasing
-		newJoinTableExpr.Condition = sqlparser.NewJoinCondition(sqlparser.AndExpressions(createWherePredicates(tables, true)...), nil)
-		sel.From[numTables] = newJoinTableExpr
+		sg.createJoin(tables)
 	}

 	return tables, isJoin
@@ -274,32 +395,53 @@ func createTablesAndJoin(schemaTables []tableT, sel *sqlparser.Select) ([]tableT
 // creates a left join (without the condition) between the last table in sel and newTable
 // tables should have one more table than sel
-func createJoin(tables []tableT, sel *sqlparser.Select) *sqlparser.JoinTableExpr {
-	n := len(sel.From)
+func (sg *selectGenerator) createJoin(tables []tableT) {
+	n := len(sg.sel.From)
 	if len(tables) != n+1 {
-		log.Fatalf("sel has %d tables and tables has %d tables", len(sel.From), n)
+		log.Fatalf("sel has %d tables and tables has %d tables", len(sg.sel.From), len(tables))
 	}

-	return
sqlparser.NewJoinTableExpr(sel.From[n-1], sqlparser.LeftJoinType, newAliasedTable(tables[n], fmt.Sprintf("tbl%d", n)), nil) + joinPredicate := sqlparser.AndExpressions(sg.createJoinPredicates(tables)...) + joinCondition := sqlparser.NewJoinCondition(joinPredicate, nil) + newTable := newAliasedTable(tables[n], fmt.Sprintf("tbl%d", n)) + sg.sel.From[n-1] = sqlparser.NewJoinTableExpr(sg.sel.From[n-1], getRandomJoinType(sg.r), newTable, joinCondition) } -// returns the grouping columns as three types: sqlparser.GroupBy, sqlparser.SelectExprs, []column -func createGroupBy(tables []tableT, maxGB int) (groupBy sqlparser.GroupBy, groupSelectExprs sqlparser.SelectExprs, grouping []column) { - numGBs := rand.Intn(maxGB) +// returns 1-3 random expressions based on the last two elements of tables +// tables should have at least two elements +func (sg *selectGenerator) createJoinPredicates(tables []tableT) sqlparser.Exprs { + if len(tables) < 2 { + log.Fatalf("tables has %d elements, needs at least 2", len(tables)) + } + + exprGenerators := []sqlparser.ExprGenerator{&tables[len(tables)-2], &tables[len(tables)-1]} + // add scalar subqueries + // TODO: subqueries fail + if sg.r.Intn(10) < 1 && testFailingQueries { + exprGenerators = append(exprGenerators, sg) + } + + return sg.createRandomExprs(1, 3, exprGenerators...) +} + +// returns the grouping columns as []column +func (sg *selectGenerator) createGroupBy(tables []tableT) (grouping []column) { + if sg.maxGBs <= 0 { + return + } + numGBs := sg.r.Intn(sg.maxGBs + 1) for i := 0; i < numGBs; i++ { - tblIdx := rand.Intn(len(tables)) - col := randomEl(tables[tblIdx].cols) + tblIdx := sg.r.Intn(len(tables)) + col := randomEl(sg.r, tables[tblIdx].cols) // TODO: grouping by a date column sometimes errors if col.typ == "date" && !testFailingQueries { continue } - groupBy = append(groupBy, newColumn(col)) + sg.sel.GroupBy = append(sg.sel.GroupBy, col.getASTExpr()) // add to select - if rand.Intn(2) < 1 { - groupSelectExprs = append(groupSelectExprs, newAliasedColumn(col, fmt.Sprintf("cgroup%d", i))) - // TODO: alias in a separate function to properly generate the having clause - col.name = fmt.Sprintf("cgroup%d", i) + if sg.r.Intn(2) < 1 { + sg.sel.SelectExprs = append(sg.sel.SelectExprs, newAliasedColumn(col, "")) grouping = append(grouping, col) } } @@ -307,184 +449,203 @@ func createGroupBy(tables []tableT, maxGB int) (groupBy sqlparser.GroupBy, group return } -// returns the aggregation columns as three types: sqlparser.SelectExprs, []column -func createAggregations(tables []tableT, maxAggrs int) (aggrExprs sqlparser.SelectExprs, aggregates []column) { - aggregations := []func(col column) sqlparser.Expr{ - func(_ column) sqlparser.Expr { return &sqlparser.CountStar{} }, - func(col column) sqlparser.Expr { return &sqlparser.Count{Args: sqlparser.Exprs{newColumn(col)}} }, - func(col column) sqlparser.Expr { return &sqlparser.Sum{Arg: newColumn(col)} }, - // func(col column) sqlparser.Expr { return &sqlparser.Avg{Arg: newAggregateExpr(col)} }, - func(col column) sqlparser.Expr { return &sqlparser.Min{Arg: newColumn(col)} }, - func(col column) sqlparser.Expr { return &sqlparser.Max{Arg: newColumn(col)} }, - } - - numAggrs := rand.Intn(maxAggrs) - for i := 0; i < numAggrs; i++ { - tblIdx, aggrIdx := rand.Intn(len(tables)), rand.Intn(len(aggregations)) - col := randomEl(tables[tblIdx].cols) - // TODO: aggregating on a date column sometimes errors - if col.typ == "date" && !testFailingQueries { - i-- - continue - } +// aliasGroupingColumns randomly aliases the 
grouping columns in the SelectExprs
+func (sg *selectGenerator) aliasGroupingColumns(grouping []column) []column {
+	if len(grouping) != len(sg.sel.SelectExprs) {
+		log.Fatalf("grouping (length: %d) and sg.sel.SelectExprs (length: %d) should have the same length at this point", len(grouping), len(sg.sel.SelectExprs))
+	}

-		newAggregate := aggregations[aggrIdx](col)
-		// TODO: collating on strings sometimes errors
-		if col.typ == "varchar" && !testFailingQueries {
-			switch newAggregate.(type) {
-			case *sqlparser.Min, *sqlparser.Max:
-				i--
-				continue
+	for i := range grouping {
+		if sg.r.Intn(2) < 1 {
+			if aliasedExpr, ok := sg.sel.SelectExprs[i].(*sqlparser.AliasedExpr); ok {
+				alias := fmt.Sprintf("cgroup%d", i)
+				aliasedExpr.SetAlias(alias)
+				grouping[i].name = alias
 			}
 		}
+	}

-		// TODO: type of sum() is incorrect (int64 vs decimal) in certain queries
-		if _, ok := newAggregate.(*sqlparser.Sum); ok && !testFailingQueries {
-			i--
-			continue
-		}
+	return grouping
+}

-		aggrExprs = append(aggrExprs, sqlparser.NewAliasedExpr(newAggregate, fmt.Sprintf("caggr%d", i)))
+// returns the aggregation columns as []column
+func (sg *selectGenerator) createAggregations(tables []tableT) (aggregates []column) {
+	exprGenerators := slice.Map(tables, func(t tableT) sqlparser.ExprGenerator { return &t })
+	// add scalar subqueries
+	// TODO: subqueries fail
+	if sg.r.Intn(10) < 1 && testFailingQueries {
+		exprGenerators = append(exprGenerators, sg)
+	}

-		if aggrIdx <= 1 /* CountStar and Count */ {
-			col.typ = "bigint"
-		} else if _, ok := newAggregate.(*sqlparser.Avg); ok && col.getColumnName() == "bigint" {
-			col.typ = "decimal"
-		}
+	sg.genConfig = sg.genConfig.IsAggregateConfig()
+	aggrExprs := sg.createRandomExprs(0, sg.maxAggrs, exprGenerators...)
+ sg.genConfig = sg.genConfig.CannotAggregateConfig() - col.name = fmt.Sprintf("caggr%d", i) + for i, expr := range aggrExprs { + col := sg.randomlyAlias(expr, fmt.Sprintf("caggr%d", i)) aggregates = append(aggregates, col) } + return } -// orders on all non-aggregate SelectExprs and independently at random on all aggregate SelectExprs of sel -func createOrderBy(groupBy sqlparser.GroupBy, aggrExprs sqlparser.SelectExprs) (orderBy sqlparser.OrderBy) { - // always order on grouping columns - for i := range groupBy { - orderBy = append(orderBy, sqlparser.NewOrder(groupBy[i], getRandomOrderDirection())) +// orders on all grouping expressions and on random SelectExprs +func (sg *selectGenerator) createOrderBy() { + // always order on grouping expressions + for _, expr := range sg.sel.GroupBy { + sg.sel.OrderBy = append(sg.sel.OrderBy, sqlparser.NewOrder(expr, getRandomOrderDirection(sg.r))) } - // randomly order on aggregation columns - for i := range aggrExprs { - if aliasedExpr, ok := aggrExprs[i].(*sqlparser.AliasedExpr); ok && rand.Intn(2) < 1 { - orderBy = append(orderBy, sqlparser.NewOrder(aliasedExpr.Expr, getRandomOrderDirection())) + // randomly order on SelectExprs + for _, selExpr := range sg.sel.SelectExprs { + if aliasedExpr, ok := selExpr.(*sqlparser.AliasedExpr); ok && sg.r.Intn(2) < 1 { + literal, ok := aliasedExpr.Expr.(*sqlparser.Literal) + isIntLiteral := ok && literal.Type == sqlparser.IntVal + if isIntLiteral { + continue + } + sg.sel.OrderBy = append(sg.sel.OrderBy, sqlparser.NewOrder(aliasedExpr.Expr, getRandomOrderDirection(sg.r))) } } - - return } -// compares two random columns (usually of the same type) -// returns a random expression if there are no other predicates and isJoin is true -// returns the predicates as a sqlparser.Exprs (slice of sqlparser.Expr's) -func createWherePredicates(tables []tableT, isJoin bool) (predicates sqlparser.Exprs) { - // if creating predicates for a join, - // then make sure predicates are created for the last two tables (which are being joined) - incr := 0 - if isJoin && len(tables) > 2 { - incr += len(tables) - 2 +// returns 0-2 random expressions based on tables +func (sg *selectGenerator) createWherePredicates(tables []tableT) { + exprGenerators := slice.Map(tables, func(t tableT) sqlparser.ExprGenerator { return &t }) + // add scalar subqueries + // TODO: subqueries fail + if sg.r.Intn(10) < 1 && testFailingQueries { + exprGenerators = append(exprGenerators, sg) } - for idx1 := range tables { - for idx2 := range tables { - // fmt.Printf("predicate tables:\n%v\n idx1: %d idx2: %d, incr: %d", tables, idx1, idx2, incr) - if idx1 >= idx2 || idx1 < incr || idx2 < incr { - continue - } - noOfPredicates := rand.Intn(2) - if isJoin { - noOfPredicates++ - } + predicates := sg.createRandomExprs(0, 2, exprGenerators...) 
+ sg.sel.AddWhere(sqlparser.AndExpressions(predicates...)) +} - for i := 0; noOfPredicates > 0; i++ { - col1 := randomEl(tables[idx1].cols) - col2 := randomEl(tables[idx2].cols) +// creates predicates for the having clause comparing a column to a random expression +func (sg *selectGenerator) createHavingPredicates(grouping []column) { + exprGenerators := slice.Map(grouping, func(c column) sqlparser.ExprGenerator { return &c }) + // add scalar subqueries + // TODO: subqueries fail + if sg.r.Intn(10) < 1 && testFailingQueries { + exprGenerators = append(exprGenerators, sg) + } - // prevent infinite loops - if i > 50 { - predicates = append(predicates, sqlparser.NewComparisonExpr(getRandomComparisonExprOperator(), newColumn(col1), newColumn(col2), nil)) - break - } + sg.genConfig = sg.genConfig.CanAggregateConfig() + predicates := sg.createRandomExprs(0, 2, exprGenerators...) + sg.genConfig = sg.genConfig.CannotAggregateConfig() - if col1.typ != col2.typ { - continue - } + sg.sel.AddHaving(sqlparser.AndExpressions(predicates...)) +} - predicates = append(predicates, sqlparser.NewComparisonExpr(getRandomComparisonExprOperator(), newColumn(col1), newColumn(col2), nil)) - noOfPredicates-- - } - } +// returns between minExprs and maxExprs random expressions using generators +func (sg *selectGenerator) createRandomExprs(minExprs, maxExprs int, generators ...sqlparser.ExprGenerator) (predicates sqlparser.Exprs) { + if minExprs > maxExprs { + log.Fatalf("minExprs is greater than maxExprs; minExprs: %d, maxExprs: %d\n", minExprs, maxExprs) + } else if maxExprs <= 0 { + return } - - // make sure the join predicate is never empty - if len(predicates) == 0 && isJoin { - predRandomExpr := getRandomExpr(tables) - predicates = append(predicates, predRandomExpr) + numPredicates := sg.r.Intn(maxExprs-minExprs+1) + minExprs + for i := 0; i < numPredicates; i++ { + predicates = append(predicates, sg.getRandomExpr(generators...)) } return } -// creates predicates for the having clause comparing a column to a random expression -func createHavingPredicates(tables []tableT) (havingPredicates sqlparser.Exprs) { - aggrSelectExprs, _ := createAggregations(tables, 2) - for i := range aggrSelectExprs { - if lhs, ok := aggrSelectExprs[i].(*sqlparser.AliasedExpr); ok { - // TODO: HAVING can only contain aggregate or grouping columns in mysql, works fine in vitess - // TODO: Can fix this by putting only the table with the grouping and aggregates column (newTable) - // TODO: but random expressions without the columns also fails - if testFailingQueries { - predRandomExpr := getRandomExpr(tables) - havingPredicates = append(havingPredicates, sqlparser.NewComparisonExpr(getRandomComparisonExprOperator(), lhs.Expr, predRandomExpr, nil)) - } else if rhs, ok1 := randomEl(aggrSelectExprs).(*sqlparser.AliasedExpr); ok1 { - havingPredicates = append(havingPredicates, sqlparser.NewComparisonExpr(getRandomComparisonExprOperator(), lhs.Expr, rhs.Expr, nil)) - } - } +// getRandomExpr returns a random expression +func (sg *selectGenerator) getRandomExpr(generators ...sqlparser.ExprGenerator) sqlparser.Expr { + var g *sqlparser.Generator + if generators == nil { + g = sqlparser.NewGenerator(sg.r, 2) + } else { + g = sqlparser.NewGenerator(sg.r, 2, generators...) 
} - return + + return g.Expression(sg.genConfig.SingleRowConfig().SetNumCols(1)) } // creates sel.Limit -func createLimit() *sqlparser.Limit { - limitNum := rand.Intn(10) - if rand.Intn(2) < 1 { - offset := rand.Intn(10) - return sqlparser.NewLimit(offset, limitNum) +func (sg *selectGenerator) createLimit() { + if sg.genConfig.SingleRow { + sg.sel.Limit = sqlparser.NewLimitWithoutOffset(1) + return } - return sqlparser.NewLimitWithoutOffset(limitNum) + limitNum := sg.r.Intn(10) + if sg.r.Intn(2) < 1 { + offset := sg.r.Intn(10) + sg.sel.Limit = sqlparser.NewLimit(offset, limitNum) + } else { + sg.sel.Limit = sqlparser.NewLimitWithoutOffset(limitNum) + } } -// returns a random expression and its type -func getRandomExpr(tables []tableT) sqlparser.Expr { - seed := time.Now().UnixNano() - g := sqlparser.NewGenerator(seed, 2, slice.Map(tables, func(t tableT) sqlparser.ExprGenerator { return &t })...) - return g.Expression() +// randomlyAlias randomly aliases expr with alias alias, adds it to sel.SelectExprs, and returns the column created +func (sg *selectGenerator) randomlyAlias(expr sqlparser.Expr, alias string) column { + var col column + if sg.r.Intn(2) < 1 { + alias = "" + col.name = sqlparser.String(expr) + } else { + col.name = alias + } + sg.sel.SelectExprs = append(sg.sel.SelectExprs, sqlparser.NewAliasedExpr(expr, alias)) + + return col } -func newAliasedTable(tbl tableT, alias string) *sqlparser.AliasedTableExpr { - return sqlparser.NewAliasedTableExpr(tbl.name, alias) +// matchNumCols makes sure sg.sel.SelectExprs and newTable both have the same number of cols: sg.genConfig.NumCols +func (sg *selectGenerator) matchNumCols(tables []tableT, newTable tableT, canAggregate bool) tableT { + // remove SelectExprs and newTable.cols randomly until there are sg.genConfig.NumCols amount + for len(sg.sel.SelectExprs) > sg.genConfig.NumCols && sg.genConfig.NumCols > 0 { + // select a random index and remove it from SelectExprs and newTable + idx := sg.r.Intn(len(sg.sel.SelectExprs)) + + sg.sel.SelectExprs[idx] = sg.sel.SelectExprs[len(sg.sel.SelectExprs)-1] + sg.sel.SelectExprs = sg.sel.SelectExprs[:len(sg.sel.SelectExprs)-1] + + newTable.cols[idx] = newTable.cols[len(newTable.cols)-1] + newTable.cols = newTable.cols[:len(newTable.cols)-1] + } + + // alternatively, add random expressions until there are sg.genConfig.NumCols amount + if sg.genConfig.NumCols > len(sg.sel.SelectExprs) { + diff := sg.genConfig.NumCols - len(sg.sel.SelectExprs) + exprs := sg.createRandomExprs(diff, diff, + slice.Map(tables, func(t tableT) sqlparser.ExprGenerator { return &t })...) 
+ + for i, expr := range exprs { + col := sg.randomlyAlias(expr, fmt.Sprintf("crandom%d", i+1)) + newTable.addColumns(col) + + if canAggregate { + sg.sel.AddGroupBy(expr) + } + } + } + + return newTable } -func newAliasedColumn(col column, alias string) *sqlparser.AliasedExpr { - return sqlparser.NewAliasedExpr(newColumn(col), alias) +func getRandomOrderDirection(r *rand.Rand) sqlparser.OrderDirection { + // asc, desc + return randomEl(r, []sqlparser.OrderDirection{0, 1}) } -func newColumn(col column) *sqlparser.ColName { - return sqlparser.NewColNameWithQualifier(col.name, sqlparser.NewTableName(col.tableName)) +func getRandomJoinType(r *rand.Rand) sqlparser.JoinType { + // normal, straight, left, right, natural, natural left, natural right + return randomEl(r, []sqlparser.JoinType{0, 1, 2, 3, 4, 5, 6}) } -func getRandomComparisonExprOperator() sqlparser.ComparisonExprOperator { - // =, <, >, <=, >=, !=, <=> - return randomEl([]sqlparser.ComparisonExprOperator{0, 1, 2, 3, 4, 5, 6}) +func randomEl[K any](r *rand.Rand, in []K) K { + return in[r.Intn(len(in))] } -func getRandomOrderDirection() sqlparser.OrderDirection { - // asc, desc - return randomEl([]sqlparser.OrderDirection{0, 1}) +func newAliasedTable(tbl tableT, alias string) *sqlparser.AliasedTableExpr { + return sqlparser.NewAliasedTableExpr(tbl.tableExpr, alias) } -func randomEl[K any](in []K) K { - return in[rand.Intn(len(in))] +func newAliasedColumn(col column, alias string) *sqlparser.AliasedExpr { + return sqlparser.NewAliasedExpr(col.getASTExpr(), alias) } diff --git a/go/test/endtoend/vtgate/queries/random/query_gen_test.go b/go/test/endtoend/vtgate/queries/random/query_gen_test.go new file mode 100644 index 00000000000..fe8aa6f6492 --- /dev/null +++ b/go/test/endtoend/vtgate/queries/random/query_gen_test.go @@ -0,0 +1,62 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package random + +import ( + "fmt" + "math/rand" + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/sqlparser" +) + +// TestSeed makes sure that the seed is deterministic +func TestSeed(t *testing.T) { + // specify the schema (that is defined in schema.sql) + schemaTables := []tableT{ + {tableExpr: sqlparser.NewTableName("emp")}, + {tableExpr: sqlparser.NewTableName("dept")}, + } + schemaTables[0].addColumns([]column{ + {name: "empno", typ: "bigint"}, + {name: "ename", typ: "varchar"}, + {name: "job", typ: "varchar"}, + {name: "mgr", typ: "bigint"}, + {name: "hiredate", typ: "date"}, + {name: "sal", typ: "bigint"}, + {name: "comm", typ: "bigint"}, + {name: "deptno", typ: "bigint"}, + }...) + schemaTables[1].addColumns([]column{ + {name: "deptno", typ: "bigint"}, + {name: "dname", typ: "varchar"}, + {name: "loc", typ: "varchar"}, + }...) 
+ + seed := int64(1689757943775102000) + genConfig := sqlparser.NewExprGeneratorConfig(sqlparser.CannotAggregate, "", 0, false) + qg := newQueryGenerator(rand.New(rand.NewSource(seed)), genConfig, 2, 2, 2, schemaTables) + qg.randomQuery() + query1 := sqlparser.String(qg.stmt) + qg = newQueryGenerator(rand.New(rand.NewSource(seed)), genConfig, 2, 2, 2, schemaTables) + qg.randomQuery() + query2 := sqlparser.String(qg.stmt) + fmt.Println(query1) + require.Equal(t, query1, query2) +} diff --git a/go/test/endtoend/vtgate/queries/random/random_expr_test.go b/go/test/endtoend/vtgate/queries/random/random_expr_test.go index 1bf4fb7025c..450169a8d9f 100644 --- a/go/test/endtoend/vtgate/queries/random/random_expr_test.go +++ b/go/test/endtoend/vtgate/queries/random/random_expr_test.go @@ -17,6 +17,7 @@ limitations under the License. package random import ( + "math/rand" "testing" "time" @@ -24,11 +25,12 @@ import ( "vitess.io/vitess/go/vt/sqlparser" ) -// This test tests that generating a random expression with a schema does not panic +// This test tests that generating random expressions with a schema does not panic func TestRandomExprWithTables(t *testing.T) { + // specify the schema (that is defined in schema.sql) schemaTables := []tableT{ - {name: sqlparser.NewTableName("emp")}, - {name: sqlparser.NewTableName("dept")}, + {tableExpr: sqlparser.NewTableName("emp")}, + {tableExpr: sqlparser.NewTableName("dept")}, } schemaTables[0].addColumns([]column{ {name: "empno", typ: "bigint"}, @@ -46,7 +48,12 @@ func TestRandomExprWithTables(t *testing.T) { {name: "loc", typ: "varchar"}, }...) - seed := time.Now().UnixNano() - g := sqlparser.NewGenerator(seed, 3, slice.Map(schemaTables, func(t tableT) sqlparser.ExprGenerator { return &t })...) - g.Expression() + for i := 0; i < 100; i++ { + + seed := time.Now().UnixNano() + r := rand.New(rand.NewSource(seed)) + genConfig := sqlparser.NewExprGeneratorConfig(sqlparser.CanAggregate, "", 0, false) + g := sqlparser.NewGenerator(r, 3, slice.Map(schemaTables, func(t tableT) sqlparser.ExprGenerator { return &t })...) 
+ g.Expression(genConfig) + } } diff --git a/go/test/endtoend/vtgate/queries/random/random_test.go b/go/test/endtoend/vtgate/queries/random/random_test.go index f2d9fcc0050..7b0ab93c165 100644 --- a/go/test/endtoend/vtgate/queries/random/random_test.go +++ b/go/test/endtoend/vtgate/queries/random/random_test.go @@ -18,6 +18,7 @@ package random import ( "fmt" + "math/rand" "strings" "testing" "time" @@ -33,8 +34,8 @@ import ( // this test uses the AST defined in the sqlparser package to randomly generate queries -// if true then execution will always stop on a "must fix" error: a mismatched results or EOF -const stopOnMustFixError = true +// if true then execution will always stop on a "must fix" error: a results mismatched or EOF +const stopOnMustFixError = false func start(t *testing.T) (utils.MySQLCompare, func()) { mcmp, err := utils.NewMySQLCompare(t, vtParams, mysqlParams) @@ -43,7 +44,7 @@ func start(t *testing.T) (utils.MySQLCompare, func()) { deleteAll := func() { _, _ = utils.ExecAllowError(t, mcmp.VtConn, "set workload = oltp") - tables := []string{"dept", "emp"} + tables := []string{"emp", "dept"} for _, table := range tables { _, _ = mcmp.ExecAndIgnore("delete from " + table) } @@ -83,52 +84,73 @@ func TestMustFix(t *testing.T) { require.NoError(t, utils.WaitForAuthoritative(t, keyspaceName, "emp", clusterInstance.VtgateProcess.ReadVSchema)) require.NoError(t, utils.WaitForAuthoritative(t, keyspaceName, "dept", clusterInstance.VtgateProcess.ReadVSchema)) - // mismatched results - // sum values returned as int64 instead of decimal - helperTest(t, "select /*vt+ PLANNER=Gen4 */ sum(tbl1.sal) as caggr1 from emp as tbl0, emp as tbl1 group by tbl1.ename order by tbl1.ename asc") + // results mismatched + helperTest(t, "select /*vt+ PLANNER=Gen4 */ distinct case count(*) when 0 then -0 end from emp as tbl0, emp as tbl1 where 0") - // mismatched results + // results mismatched (maybe derived tables) + helperTest(t, "select /*vt+ PLANNER=Gen4 */ 0 as crandom0 from dept as tbl0, (select /*vt+ PLANNER=Gen4 */ distinct count(*) from emp as tbl1 where 0) as tbl1") + + // results mismatched + helperTest(t, "select /*vt+ PLANNER=Gen4 */ distinct case count(distinct true) when 'b' then 't' end from emp as tbl1 where 's'") + + // results mismatched + helperTest(t, "select /*vt+ PLANNER=Gen4 */ distinct sum(distinct tbl1.deptno) from dept as tbl0, emp as tbl1") + + // mismatched number of columns + helperTest(t, "select /*vt+ PLANNER=Gen4 */ count(*) + 0 from emp as tbl0 order by count(*) desc") + + // results mismatched (mismatched types) + helperTest(t, "select /*vt+ PLANNER=Gen4 */ count(0 >> 0), sum(distinct tbl2.empno) from emp as tbl0 left join emp as tbl2 on -32") + + // results mismatched (decimals off by a little; evalengine problem) + helperTest(t, "select /*vt+ PLANNER=Gen4 */ sum(case false when true then tbl1.deptno else -154 / 132 end) as caggr1 from emp as tbl0, dept as tbl1") + + // EOF + helperTest(t, "select /*vt+ PLANNER=Gen4 */ tbl1.dname as cgroup0, tbl1.dname as cgroup1, tbl1.deptno as crandom0 from dept as tbl0, dept as tbl1 group by tbl1.dname, tbl1.deptno order by tbl1.deptno desc") + + // results mismatched // limit >= 9 works - helperTest(t, "select /*vt+ PLANNER=Gen4 */ tbl0.ename as cgroup1 from emp as tbl0 group by tbl0.job, tbl0.ename having sum(tbl0.mgr) = sum(tbl0.mgr) order by tbl0.job desc, tbl0.ename asc limit 8") + helperTest(t, "select /*vt+ PLANNER=Gen4 */ tbl0.ename as cgroup1 from emp as tbl0 group by tbl0.job, tbl0.ename having sum(tbl0.mgr) order by 
tbl0.job desc, tbl0.ename asc limit 8") - // mismatched results - helperTest(t, "select /*vt+ PLANNER=Gen4 */ distinct count(*) as caggr1 from dept as tbl0, emp as tbl1 group by tbl1.sal having max(tbl1.comm) != true") + // results mismatched + helperTest(t, "select /*vt+ PLANNER=Gen4 */ distinct count(*) as caggr1 from emp as tbl1 group by tbl1.sal having max(0) != true") - // mismatched results - helperTest(t, "select /*vt+ PLANNER=Gen4 */ distinct sum(tbl1.loc) as caggr0 from dept as tbl0, dept as tbl1 group by tbl1.deptno having max(tbl1.dname) <= 1") + // results mismatched + helperTest(t, "select /*vt+ PLANNER=Gen4 */ distinct 0 as caggr0 from dept as tbl0, dept as tbl1 group by tbl1.deptno having max(0) <= 0") - // mismatched results - helperTest(t, "select /*vt+ PLANNER=Gen4 */ distinct max(tbl0.dname) as caggr0, 'cattle' as crandom0 from dept as tbl0, emp as tbl1 where tbl0.deptno != tbl1.sal group by tbl1.comm") + // results mismatched + helperTest(t, "select /*vt+ PLANNER=Gen4 */ min(0) as caggr0 from dept as tbl0, emp as tbl1 where case when false then tbl0.dname end group by tbl1.comm") - // mismatched results - helperTest(t, "select /*vt+ PLANNER=Gen4 */ count(*) as caggr0, 1 as crandom0 from dept as tbl0, emp as tbl1 where 1 = 0") + // results mismatched + helperTest(t, "select /*vt+ PLANNER=Gen4 */ count(*) as caggr0, 0 as crandom0 from dept as tbl0, emp as tbl1 where 0") - // mismatched results - helperTest(t, "select /*vt+ PLANNER=Gen4 */ count(*) as caggr0, 1 as crandom0 from dept as tbl0, emp as tbl1 where 'octopus'") + // results mismatched + helperTest(t, "select /*vt+ PLANNER=Gen4 */ count(*) as caggr0, 0 as crandom0 from dept as tbl0, emp as tbl1 where 'o'") // similar to previous two - // mismatched results - helperTest(t, "select /*vt+ PLANNER=Gen4 */ distinct 'octopus' as crandom0 from dept as tbl0, emp as tbl1 where tbl0.deptno = tbl1.empno having count(*) = count(*)") - - // mismatched results - // previously failing, then succeeding query, now failing again - helperTest(t, "select /*vt+ PLANNER=Gen4 */ count(tbl0.deptno) from dept as tbl0, emp as tbl1 group by tbl1.job order by tbl1.job limit 3") + // results mismatched + helperTest(t, "select /*vt+ PLANNER=Gen4 */ distinct 'o' as crandom0 from dept as tbl0, emp as tbl1 where 0 having count(*) = count(*)") - // mismatched results (group by + right join) + // results mismatched (group by + right join) // left instead of right works // swapping tables and predicates and changing to left fails - helperTest(t, "select /*vt+ PLANNER=Gen4 */ max(tbl0.deptno) from dept as tbl0 right join emp as tbl1 on tbl0.deptno = tbl1.empno and tbl0.deptno = tbl1.deptno group by tbl0.deptno") + helperTest(t, "select /*vt+ PLANNER=Gen4 */ 0 from dept as tbl0 right join emp as tbl1 on tbl0.deptno = tbl1.empno and tbl0.deptno = tbl1.deptno group by tbl0.deptno") - // mismatched results (count + right join) + // results mismatched (count + right join) // left instead of right works // swapping tables and predicates and changing to left fails helperTest(t, "select /*vt+ PLANNER=Gen4 */ count(tbl1.comm) from emp as tbl1 right join emp as tbl2 on tbl1.mgr = tbl2.sal") + // Passes with different errors + // vitess error: EOF + // mysql error: Operand should contain 1 column(s) + helperTest(t, "select /*vt+ PLANNER=Gen4 */ 8 < -31 xor (-29, sum((tbl0.deptno, 'wren', 'ostrich')), max(distinct (tbl0.dname, -15, -8))) in ((sum(distinct (tbl0.dname, 'bengal', -10)), 'ant', true)) as caggr0 from dept as tbl0 where tbl0.deptno * (77 - 61)") + // 
EOF - helperTest(t, "select /*vt+ PLANNER=Gen4 */ count(*) from dept as tbl0, (select count(*) from emp as tbl0, emp as tbl1 limit 18) as tbl1") + helperTest(t, "select /*vt+ PLANNER=Gen4 */ tbl1.deptno as cgroup0, tbl1.loc as cgroup1, count(distinct tbl1.loc) as caggr1, tbl1.loc as crandom0 from dept as tbl0, dept as tbl1 group by tbl1.deptno, tbl1.loc") // EOF - helperTest(t, "select /*vt+ PLANNER=Gen4 */ count(*), count(*) from (select count(*) from dept as tbl0 group by tbl0.deptno) as tbl0") + helperTest(t, "select /*vt+ PLANNER=Gen4 */ count(*) from dept as tbl0, (select count(*) from emp as tbl0, emp as tbl1 limit 18) as tbl1") } func TestKnownFailures(t *testing.T) { @@ -140,7 +162,32 @@ func TestKnownFailures(t *testing.T) { // logs more stuff //clusterInstance.EnableGeneralLog() - // cannot compare strings, collation is unknown or unsupported (collation ID: 0) + // column 'tbl1.`not exists (select 1 from dual)`' not found + helperTest(t, "select /*vt+ PLANNER=Gen4 */ tbl1.`not exists (select 1 from dual)`, count(*) from dept as tbl0, (select /*vt+ PLANNER=Gen4 */ not exists (select 1 from dual) from dept as tbl0 where tbl0.dname) as tbl1 group by tbl0.deptno, tbl1.`not exists (select 1 from dual)`") + + // VT13001: [BUG] failed to find the corresponding column + helperTest(t, "select /*vt+ PLANNER=Gen4 */ tbl1.dname as cgroup0, tbl1.dname as cgroup1 from dept as tbl0, dept as tbl1 group by tbl1.dname, tbl1.deptno order by tbl1.deptno desc") + + // vitess error: + // mysql error: Operand should contain 1 column(s) + helperTest(t, "select (count('sheepdog') ^ (-71 % sum(emp.mgr) ^ count('koi')) and count(*), 'fly') from emp, dept") + + // rhs of an In operation should be a tuple + helperTest(t, "select /*vt+ PLANNER=Gen4 */ (case when true then min(distinct tbl1.job) else 'bee' end, 'molly') not in (('dane', 0)) as caggr1 from emp as tbl0, emp as tbl1") + + // VT13001: [BUG] in scatter query: complex ORDER BY expression: :vtg1 /* VARCHAR */ + helperTest(t, "select /*vt+ PLANNER=Gen4 */ tbl1.job as cgroup0, sum(distinct 'mudfish'), tbl1.job as crandom0 from emp as tbl0, emp as tbl1 group by tbl1.job order by tbl1.job asc limit 8, 1") + + // VT13001: [BUG] column should not be pushed to projection while doing a column lookup + helperTest(t, "select /*vt+ PLANNER=Gen4 */ -26 in (tbl2.mgr, -8, tbl0.deptno) as crandom0 from dept as tbl0, emp as tbl1 left join emp as tbl2 on tbl2.ename") + + // unsupported: min/max on types that are not comparable is not supported + helperTest(t, "select /*vt+ PLANNER=Gen4 */ max(case true when false then 'gnu' when true then 'meerkat' end) as caggr0 from dept as tbl0") + + // vttablet: rpc error: code = InvalidArgument desc = BIGINT UNSIGNED value is out of range in '(-(273) + (-(15) & 124))' + helperTest(t, "select /*vt+ PLANNER=Gen4 */ -273 + (-15 & 124) as crandom0 from emp as tbl0, emp as tbl1 where tbl1.sal >= tbl1.mgr") + + // vitess error: cannot compare strings, collation is unknown or unsupported (collation ID: 0) helperTest(t, "select /*vt+ PLANNER=Gen4 */ max(tbl1.dname) as caggr1 from dept as tbl0, dept as tbl1 group by tbl1.dname order by tbl1.dname asc") // vitess error: @@ -154,13 +201,17 @@ func TestKnownFailures(t *testing.T) { // coercion should not try to coerce this value: DATE("1980-12-17") helperTest(t, "select /*vt+ PLANNER=Gen4 */ distinct tbl1.hiredate as cgroup0, count(tbl1.mgr) as caggr0 from emp as tbl1 group by tbl1.hiredate, tbl1.ename") - // only_full_group_by enabled (vitess sometimes (?) 
produces the correct result assuming only_full_group_by is disabled) - // vitess error: nil + // only_full_group_by enabled + // vitess error: In aggregated query without GROUP BY, expression #1 of SELECT list contains nonaggregated column 'ks_random.tbl0.EMPNO'; this is incompatible with sql_mode=only_full_group_by + helperTest(t, "select /*vt+ PLANNER=Gen4 */ distinct tbl0.empno as cgroup0, count(distinct 56) as caggr0, min('flounder' = 'penguin') as caggr1 from emp as tbl0, (select /*vt+ PLANNER=Gen4 */ 'manatee' as crandom0 from dept as tbl0 where -26 limit 2) as tbl2 where 'anteater' like 'catfish' is null and -11 group by tbl0.empno order by tbl0.empno asc, count(distinct 56) asc, min('flounder' = 'penguin') desc") + + // only_full_group_by enabled + // vitess error: // mysql error: In aggregated query without GROUP BY, expression #1 of SELECT list contains nonaggregated column 'ks_random.tbl0.ENAME'; this is incompatible with sql_mode=only_full_group_by helperTest(t, "select /*vt+ PLANNER=Gen4 */ tbl0.ename, min(tbl0.comm) from emp as tbl0 left join emp as tbl1 on tbl0.empno = tbl1.comm and tbl0.empno = tbl1.empno") // only_full_group_by enabled - // vitess error: nil + // vitess error: // mysql error: Expression #1 of ORDER BY clause is not in SELECT list, references column 'ks_random.tbl2.DNAME' which is not in SELECT list; this is incompatible with DISTINCT helperTest(t, "select /*vt+ PLANNER=Gen4 */ distinct count(*) as caggr0 from dept as tbl2 group by tbl2.dname order by tbl2.dname asc") @@ -173,31 +224,46 @@ func TestKnownFailures(t *testing.T) { // vttablet: rpc error: code = InvalidArgument desc = Can't group on 'count(*)' (errno 1056) (sqlstate 42000) (CallerID: userData1) helperTest(t, "select /*vt+ PLANNER=Gen4 */ distinct count(*) from dept as tbl0 group by tbl0.deptno") - // [BUG] push projection does not yet support: *planbuilder.memorySort (errno 1815) (sqlstate HY000) + // unsupported + // VT12001: unsupported: only one DISTINCT aggregation is allowed in a SELECT: sum(distinct 1) as caggr1 + helperTest(t, "select /*vt+ PLANNER=Gen4 */ sum(distinct tbl0.comm) as caggr0, sum(distinct 1) as caggr1 from emp as tbl0 having 'redfish' < 'blowfish'") + + // unsupported + // VT12001: unsupported: aggregation on top of aggregation not supported helperTest(t, "select /*vt+ PLANNER=Gen4 */ count(*) from dept as tbl1 join (select count(*) from emp as tbl0, dept as tbl1 group by tbl1.loc) as tbl2") // unsupported - // unsupported: in scatter query: complex aggregate expression (errno 1235) (sqlstate 42000) + // VT12001: unsupported: in scatter query: complex aggregate expression helperTest(t, "select /*vt+ PLANNER=Gen4 */ (select count(*) from emp as tbl0) from emp as tbl0") // unsupported - // unsupported: using aggregation on top of a *planbuilder.filter plan + // VT12001: unsupported: using aggregation on top of a *planbuilder.filter plan helperTest(t, "select /*vt+ PLANNER=Gen4 */ count(tbl1.dname) as caggr1 from dept as tbl0 left join dept as tbl1 on tbl1.dname > tbl1.loc where tbl1.loc <=> tbl1.dname group by tbl1.dname order by tbl1.dname asc") // unsupported - // unsupported: using aggregation on top of a *planbuilder.orderedAggregate plan + // VT12001: unsupported: aggregation on top of aggregation not supported helperTest(t, "select /*vt+ PLANNER=Gen4 */ count(*) from (select count(*) from dept as tbl0) as tbl0") // unsupported - // unsupported: using aggregation on top of a *planbuilder.orderedAggregate plan + // VT12001: unsupported: aggregation on top of aggregation 
not supported helperTest(t, "select /*vt+ PLANNER=Gen4 */ count(*), count(*) from (select count(*) from dept as tbl0) as tbl0, dept as tbl1") // unsupported - // unsupported: in scatter query: aggregation function 'avg(tbl0.deptno)' + // VT12001: unsupported: in scatter query: aggregation function 'avg(tbl0.deptno)' helperTest(t, "select /*vt+ PLANNER=Gen4 */ avg(tbl0.deptno) from dept as tbl0") + + // unsupported + // VT12001: unsupported: LEFT JOIN with derived tables + helperTest(t, "select /*vt+ PLANNER=Gen4 */ -1 as crandom0 from emp as tbl2 left join (select count(*) from dept as tbl1) as tbl3 on 6 != tbl2.deptno") + + // unsupported + // VT12001: unsupported: subqueries in GROUP BY + helperTest(t, "select /*vt+ PLANNER=Gen4 */ exists (select 1) as crandom0 from dept as tbl0 group by exists (select 1)") } func TestRandom(t *testing.T) { + t.Skip("Skip CI; random expressions generate too many failures to properly limit") + mcmp, closer := start(t) defer closer() @@ -206,8 +272,8 @@ func TestRandom(t *testing.T) { // specify the schema (that is defined in schema.sql) schemaTables := []tableT{ - {name: sqlparser.NewTableName("emp")}, - {name: sqlparser.NewTableName("dept")}, + {tableExpr: sqlparser.NewTableName("emp")}, + {tableExpr: sqlparser.NewTableName("dept")}, } schemaTables[0].addColumns([]column{ {name: "empno", typ: "bigint"}, @@ -227,23 +293,31 @@ func TestRandom(t *testing.T) { endBy := time.Now().Add(1 * time.Second) - var queryCount int - for time.Now().Before(endBy) && (!t.Failed() || testFailingQueries) { - query := sqlparser.String(randomQuery(schemaTables, 3, 3)) + var queryCount, queryFailCount int + // continue testing after an error if and only if testFailingQueries is true + for time.Now().Before(endBy) && (!t.Failed() || !testFailingQueries) { + seed := time.Now().UnixNano() + genConfig := sqlparser.NewExprGeneratorConfig(sqlparser.CannotAggregate, "", 0, false) + qg := newQueryGenerator(rand.New(rand.NewSource(seed)), genConfig, 2, 2, 2, schemaTables) + qg.randomQuery() + query := sqlparser.String(qg.stmt) _, vtErr := mcmp.ExecAllowAndCompareError(query) // this assumes all queries are valid mysql queries if vtErr != nil { + fmt.Printf("seed: %d\n", seed) fmt.Println(query) fmt.Println(vtErr) if stopOnMustFixError { - // EOF - if sqlError, ok := vtErr.(*sqlerror.SQLError); ok && strings.Contains(sqlError.Message, "EOF") { + // results mismatched + if strings.Contains(vtErr.Error(), "results mismatched") { + simplified := simplifyResultsMismatchedQuery(t, query) + fmt.Printf("final simplified query: %s\n", simplified) break } - // mismatched results - if strings.Contains(vtErr.Error(), "results mismatched") { + // EOF + if sqlError, ok := vtErr.(*sqlerror.SQLError); ok && strings.Contains(sqlError.Message, "EOF") { break } } @@ -251,12 +325,17 @@ func TestRandom(t *testing.T) { // restart the mysql and vitess connections in case something bad happened closer() mcmp, closer = start(t) + + fmt.Printf("\n\n\n") + queryFailCount++ } queryCount++ } fmt.Printf("Queries successfully executed: %d\n", queryCount) + fmt.Printf("Queries failed: %d\n", queryFailCount) } +// these queries were previously failing and have now been fixed func TestBuggyQueries(t *testing.T) { mcmp, closer := start(t) defer closer() @@ -264,47 +343,29 @@ func TestBuggyQueries(t *testing.T) { require.NoError(t, utils.WaitForAuthoritative(t, keyspaceName, "emp", clusterInstance.VtgateProcess.ReadVSchema)) require.NoError(t, utils.WaitForAuthoritative(t, keyspaceName, "dept", 
clusterInstance.VtgateProcess.ReadVSchema)) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ count(*), count(*), count(*) from dept as tbl0, emp as tbl1 where tbl0.deptno = tbl1.deptno group by tbl1.empno order by tbl1.empno", - `[[INT64(1) INT64(1) INT64(1)] [INT64(1) INT64(1) INT64(1)] [INT64(1) INT64(1) INT64(1)] [INT64(1) INT64(1) INT64(1)] [INT64(1) INT64(1) INT64(1)] [INT64(1) INT64(1) INT64(1)] [INT64(1) INT64(1) INT64(1)] [INT64(1) INT64(1) INT64(1)] [INT64(1) INT64(1) INT64(1)] [INT64(1) INT64(1) INT64(1)] [INT64(1) INT64(1) INT64(1)] [INT64(1) INT64(1) INT64(1)] [INT64(1) INT64(1) INT64(1)] [INT64(1) INT64(1) INT64(1)]]`) - //mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ count(tbl0.deptno) from dept as tbl0, emp as tbl1 group by tbl1.job order by tbl1.job limit 3", - // `[[INT64(8)] [INT64(16)] [INT64(12)]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ count(*), count(*) from emp as tbl0 group by tbl0.empno order by tbl0.empno", - `[[INT64(1) INT64(1)] [INT64(1) INT64(1)] [INT64(1) INT64(1)] [INT64(1) INT64(1)] [INT64(1) INT64(1)] [INT64(1) INT64(1)] [INT64(1) INT64(1)] [INT64(1) INT64(1)] [INT64(1) INT64(1)] [INT64(1) INT64(1)] [INT64(1) INT64(1)] [INT64(1) INT64(1)] [INT64(1) INT64(1)] [INT64(1) INT64(1)]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ distinct count(*), tbl0.loc from dept as tbl0 group by tbl0.loc", - `[[INT64(1) VARCHAR("BOSTON")] [INT64(1) VARCHAR("CHICAGO")] [INT64(1) VARCHAR("DALLAS")] [INT64(1) VARCHAR("NEW YORK")]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ distinct count(*) from dept as tbl0 group by tbl0.loc", - `[[INT64(1)]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ sum(tbl1.comm) from emp as tbl0, emp as tbl1", - `[[DECIMAL(30800)]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ tbl1.mgr, tbl1.mgr, count(*) from emp as tbl1 group by tbl1.mgr", - `[[NULL NULL INT64(1)] [INT64(7566) INT64(7566) INT64(2)] [INT64(7698) INT64(7698) INT64(5)] [INT64(7782) INT64(7782) INT64(1)] [INT64(7788) INT64(7788) INT64(1)] [INT64(7839) INT64(7839) INT64(3)] [INT64(7902) INT64(7902) INT64(1)]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ tbl1.mgr, tbl1.mgr, count(*) from emp as tbl0, emp as tbl1 group by tbl1.mgr", - `[[NULL NULL INT64(14)] [INT64(7566) INT64(7566) INT64(28)] [INT64(7698) INT64(7698) INT64(70)] [INT64(7782) INT64(7782) INT64(14)] [INT64(7788) INT64(7788) INT64(14)] [INT64(7839) INT64(7839) INT64(42)] [INT64(7902) INT64(7902) INT64(14)]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ count(*), count(*), count(tbl0.comm) from emp as tbl0, emp as tbl1 join dept as tbl2", - `[[INT64(784) INT64(784) INT64(224)]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ count(*), count(*) from (select count(*) from dept as tbl0 group by tbl0.deptno) as tbl0, dept as tbl1", - `[[INT64(16) INT64(16)]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ count(*) from (select count(*) from dept as tbl0 group by tbl0.deptno) as tbl0", - `[[INT64(4)]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ min(tbl0.loc) from dept as tbl0", - `[[VARCHAR("BOSTON")]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ tbl1.empno, max(tbl1.job) from dept as tbl0, emp as tbl1 group by tbl1.empno", - `[[INT64(7369) VARCHAR("CLERK")] [INT64(7499) VARCHAR("SALESMAN")] [INT64(7521) VARCHAR("SALESMAN")] [INT64(7566) VARCHAR("MANAGER")] [INT64(7654) VARCHAR("SALESMAN")] [INT64(7698) VARCHAR("MANAGER")] [INT64(7782) VARCHAR("MANAGER")] [INT64(7788) VARCHAR("ANALYST")] [INT64(7839) VARCHAR("PRESIDENT")] [INT64(7844) 
VARCHAR("SALESMAN")] [INT64(7876) VARCHAR("CLERK")] [INT64(7900) VARCHAR("CLERK")] [INT64(7902) VARCHAR("ANALYST")] [INT64(7934) VARCHAR("CLERK")]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ tbl1.ename, max(tbl0.comm) from emp as tbl0, emp as tbl1 group by tbl1.ename", - `[[VARCHAR("ADAMS") INT64(1400)] [VARCHAR("ALLEN") INT64(1400)] [VARCHAR("BLAKE") INT64(1400)] [VARCHAR("CLARK") INT64(1400)] [VARCHAR("FORD") INT64(1400)] [VARCHAR("JAMES") INT64(1400)] [VARCHAR("JONES") INT64(1400)] [VARCHAR("KING") INT64(1400)] [VARCHAR("MARTIN") INT64(1400)] [VARCHAR("MILLER") INT64(1400)] [VARCHAR("SCOTT") INT64(1400)] [VARCHAR("SMITH") INT64(1400)] [VARCHAR("TURNER") INT64(1400)] [VARCHAR("WARD") INT64(1400)]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ tbl0.dname, tbl0.dname, min(tbl0.deptno) from dept as tbl0, dept as tbl1 group by tbl0.dname, tbl0.dname", - `[[VARCHAR("ACCOUNTING") VARCHAR("ACCOUNTING") INT64(10)] [VARCHAR("OPERATIONS") VARCHAR("OPERATIONS") INT64(40)] [VARCHAR("RESEARCH") VARCHAR("RESEARCH") INT64(20)] [VARCHAR("SALES") VARCHAR("SALES") INT64(30)]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ tbl0.dname, min(tbl1.deptno) from dept as tbl0, dept as tbl1 group by tbl0.dname, tbl1.dname", - `[[VARCHAR("ACCOUNTING") INT64(10)] [VARCHAR("ACCOUNTING") INT64(40)] [VARCHAR("ACCOUNTING") INT64(20)] [VARCHAR("ACCOUNTING") INT64(30)] [VARCHAR("OPERATIONS") INT64(10)] [VARCHAR("OPERATIONS") INT64(40)] [VARCHAR("OPERATIONS") INT64(20)] [VARCHAR("OPERATIONS") INT64(30)] [VARCHAR("RESEARCH") INT64(10)] [VARCHAR("RESEARCH") INT64(40)] [VARCHAR("RESEARCH") INT64(20)] [VARCHAR("RESEARCH") INT64(30)] [VARCHAR("SALES") INT64(10)] [VARCHAR("SALES") INT64(40)] [VARCHAR("SALES") INT64(20)] [VARCHAR("SALES") INT64(30)]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ max(tbl0.hiredate) from emp as tbl0", - `[[DATE("1983-01-12")]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ min(tbl0.deptno) as caggr0, count(*) as caggr1 from dept as tbl0 left join dept as tbl1 on tbl1.loc = tbl1.dname", - `[[INT64(10) INT64(4)]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ count(tbl1.loc) as caggr0 from dept as tbl1 left join dept as tbl2 on tbl1.loc = tbl2.loc where (tbl2.deptno)", - `[[INT64(4)]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ sum(tbl1.ename), min(tbl0.empno) from emp as tbl0, emp as tbl1 left join dept as tbl2 on tbl1.job = tbl2.loc and tbl1.comm = tbl2.deptno where ('trout') and tbl0.deptno = tbl1.comm", - `[[NULL NULL]]`) - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ distinct max(tbl0.deptno), count(tbl0.job) from emp as tbl0, dept as tbl1 left join dept as tbl2 on tbl1.dname = tbl2.loc and tbl1.dname = tbl2.loc where (tbl2.loc) and tbl0.deptno = tbl1.deptno", - `[[NULL INT64(0)]]`) + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ sum(tbl1.sal) as caggr1 from emp as tbl0, emp as tbl1 group by tbl1.ename order by tbl1.ename asc") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ count(*), count(*), count(*) from dept as tbl0, emp as tbl1 where tbl0.deptno = tbl1.deptno group by tbl1.empno order by tbl1.empno") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ count(tbl0.deptno) from dept as tbl0, emp as tbl1 group by tbl1.job order by tbl1.job limit 3") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ count(*), count(*) from emp as tbl0 group by tbl0.empno order by tbl0.empno") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ distinct count(*), tbl0.loc from dept as tbl0 group by tbl0.loc") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ distinct count(*) from dept as tbl0 group by 
tbl0.loc") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ sum(tbl1.comm) from emp as tbl0, emp as tbl1") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ tbl1.mgr, tbl1.mgr, count(*) from emp as tbl1 group by tbl1.mgr") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ tbl1.mgr, tbl1.mgr, count(*) from emp as tbl0, emp as tbl1 group by tbl1.mgr") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ count(*), count(*), count(tbl0.comm) from emp as tbl0, emp as tbl1 join dept as tbl2") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ count(*), count(*) from (select count(*) from dept as tbl0 group by tbl0.deptno) as tbl0, dept as tbl1") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ count(*) from (select count(*) from dept as tbl0 group by tbl0.deptno) as tbl0") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ min(tbl0.loc) from dept as tbl0") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ tbl1.empno, max(tbl1.job) from dept as tbl0, emp as tbl1 group by tbl1.empno") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ tbl1.ename, max(tbl0.comm) from emp as tbl0, emp as tbl1 group by tbl1.ename") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ tbl0.dname, tbl0.dname, min(tbl0.deptno) from dept as tbl0, dept as tbl1 group by tbl0.dname, tbl0.dname") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ tbl0.dname, min(tbl1.deptno) from dept as tbl0, dept as tbl1 group by tbl0.dname, tbl1.dname") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ max(tbl0.hiredate) from emp as tbl0") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ min(tbl0.deptno) as caggr0, count(*) as caggr1 from dept as tbl0 left join dept as tbl1 on tbl1.loc = tbl1.dname") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ count(tbl1.loc) as caggr0 from dept as tbl1 left join dept as tbl2 on tbl1.loc = tbl2.loc where (tbl2.deptno)") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ sum(tbl1.ename), min(tbl0.empno) from emp as tbl0, emp as tbl1 left join dept as tbl2 on tbl1.job = tbl2.loc and tbl1.comm = tbl2.deptno where ('trout') and tbl0.deptno = tbl1.comm") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ distinct max(tbl0.deptno), count(tbl0.job) from emp as tbl0, dept as tbl1 left join dept as tbl2 on tbl1.dname = tbl2.loc and tbl1.dname = tbl2.loc where (tbl2.loc) and tbl0.deptno = tbl1.deptno") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ count(*), count(*) from (select count(*) from dept as tbl0 group by tbl0.deptno) as tbl0") + mcmp.Exec("select /*vt+ PLANNER=Gen4 */ distinct max(tbl0.dname) as caggr0, 'cattle' as crandom0 from dept as tbl0, emp as tbl1 where tbl0.deptno != tbl1.sal group by tbl1.comm") } diff --git a/go/test/endtoend/vtgate/queries/random/simplifier_test.go b/go/test/endtoend/vtgate/queries/random/simplifier_test.go new file mode 100644 index 00000000000..478ee355d34 --- /dev/null +++ b/go/test/endtoend/vtgate/queries/random/simplifier_test.go @@ -0,0 +1,116 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package random + +import ( + "fmt" + "strings" + "testing" + + "vitess.io/vitess/go/test/vschemawrapper" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/test/endtoend/utils" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/planbuilder" + "vitess.io/vitess/go/vt/vtgate/simplifier" + "vitess.io/vitess/go/vt/vtgate/vindexes" +) + +func TestSimplifyResultsMismatchedQuery(t *testing.T) { + t.Skip("Skip CI") + + var queries []string + queries = append(queries, "select /*vt+ PLANNER=Gen4 */ (68 - -16) / case false when -45 then 3 when 28 then -43 else -62 end as crandom0 from dept as tbl0, (select /*vt+ PLANNER=Gen4 */ distinct not not false and count(*) from emp as tbl0, emp as tbl1 where tbl1.ename) as tbl1 limit 1", + "select /*vt+ PLANNER=Gen4 */ distinct case true when 'burro' then 'trout' else 'elf' end < case count(distinct true) when 'bobcat' then 'turkey' else 'penguin' end from dept as tbl0, emp as tbl1 where 'spider'", + "select /*vt+ PLANNER=Gen4 */ distinct sum(distinct tbl1.deptno) from dept as tbl0, emp as tbl1 where tbl0.deptno and tbl1.comm in (12, tbl0.deptno, case false when 67 then -17 when -78 then -35 end, -76 >> -68)", + "select /*vt+ PLANNER=Gen4 */ count(*) + 1 from emp as tbl0 order by count(*) desc", + "select /*vt+ PLANNER=Gen4 */ count(2 >> tbl2.mgr), sum(distinct tbl2.empno <=> 15) from emp as tbl0 left join emp as tbl2 on -32", + "select /*vt+ PLANNER=Gen4 */ sum(case false when true then tbl1.deptno else -154 / 132 end) as caggr1 from emp as tbl0, dept as tbl1", + "select /*vt+ PLANNER=Gen4 */ tbl1.dname as cgroup0, tbl1.dname as cgroup1 from dept as tbl0, dept as tbl1 group by tbl1.dname, tbl1.deptno order by tbl1.deptno desc", + "select /*vt+ PLANNER=Gen4 */ tbl0.ename as cgroup1 from emp as tbl0 group by tbl0.job, tbl0.ename having sum(tbl0.mgr) = sum(tbl0.mgr) order by tbl0.job desc, tbl0.ename asc limit 8", + "select /*vt+ PLANNER=Gen4 */ distinct count(*) as caggr1 from dept as tbl0, emp as tbl1 group by tbl1.sal having max(tbl1.comm) != true", + "select /*vt+ PLANNER=Gen4 */ distinct sum(tbl1.loc) as caggr0 from dept as tbl0, dept as tbl1 group by tbl1.deptno having max(tbl1.dname) <= 1", + "select /*vt+ PLANNER=Gen4 */ min(tbl0.deptno) as caggr0 from dept as tbl0, emp as tbl1 where case when false then tbl0.dname end group by tbl1.comm", + "select /*vt+ PLANNER=Gen4 */ count(*) as caggr0, 1 as crandom0 from dept as tbl0, emp as tbl1 where 1 = 0", + "select /*vt+ PLANNER=Gen4 */ count(*) as caggr0, 1 as crandom0 from dept as tbl0, emp as tbl1 where 'octopus'", + "select /*vt+ PLANNER=Gen4 */ distinct 'octopus' as crandom0 from dept as tbl0, emp as tbl1 where tbl0.deptno = tbl1.empno having count(*) = count(*)", + "select /*vt+ PLANNER=Gen4 */ max(tbl0.deptno) from dept as tbl0 right join emp as tbl1 on tbl0.deptno = tbl1.empno and tbl0.deptno = tbl1.deptno group by tbl0.deptno", + "select /*vt+ PLANNER=Gen4 */ count(tbl1.comm) from emp as tbl1 right join emp as tbl2 on tbl1.mgr = tbl2.sal") + + for _, query := range queries { + var simplified string + t.Run("simplification "+query, func(t *testing.T) { + simplified = simplifyResultsMismatchedQuery(t, query) + }) + + t.Run("simplified "+query, func(t *testing.T) { + mcmp, closer := start(t) + defer closer() + + mcmp.ExecAllowAndCompareError(simplified) + }) + + fmt.Printf("final simplified query: %s\n", simplified) + } +} + +// given a query that errors with results mismatched, simplifyResultsMismatchedQuery returns a simpler version with the same error +func 
simplifyResultsMismatchedQuery(t *testing.T, query string) string { + t.Helper() + mcmp, closer := start(t) + defer closer() + + _, err := mcmp.ExecAllowAndCompareError(query) + if err == nil { + t.Fatalf("query (%s) does not error", query) + } else if !strings.Contains(err.Error(), "mismatched") { + t.Fatalf("query (%s) does not error with results mismatched\nError: %v", query, err) + } + + require.NoError(t, utils.WaitForAuthoritative(t, keyspaceName, "emp", clusterInstance.VtgateProcess.ReadVSchema)) + require.NoError(t, utils.WaitForAuthoritative(t, keyspaceName, "dept", clusterInstance.VtgateProcess.ReadVSchema)) + + formal, err := vindexes.LoadFormal("svschema.json") + require.NoError(t, err) + vSchema := vindexes.BuildVSchema(formal) + vSchemaWrapper := &vschemawrapper.VSchemaWrapper{ + V: vSchema, + Version: planbuilder.Gen4, + } + + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + simplified := simplifier.SimplifyStatement( + stmt.(sqlparser.SelectStatement), + vSchemaWrapper.CurrentDb(), + vSchemaWrapper, + func(statement sqlparser.SelectStatement) bool { + q := sqlparser.String(statement) + _, newErr := mcmp.ExecAllowAndCompareError(q) + if newErr == nil { + return false + } else { + return strings.Contains(newErr.Error(), "mismatched") + } + }, + ) + + return sqlparser.String(simplified) +} diff --git a/go/test/endtoend/vtgate/queries/random/svschema.json b/go/test/endtoend/vtgate/queries/random/svschema.json new file mode 100644 index 00000000000..ccbbc6ed3a6 --- /dev/null +++ b/go/test/endtoend/vtgate/queries/random/svschema.json @@ -0,0 +1,6 @@ +{ + "keyspaces": { + "ks_random": { + } + } +} \ No newline at end of file diff --git a/go/test/vschemawrapper/vschema_wrapper.go b/go/test/vschemawrapper/vschema_wrapper.go new file mode 100644 index 00000000000..e85b18ce36d --- /dev/null +++ b/go/test/vschemawrapper/vschema_wrapper.go @@ -0,0 +1,320 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package vschemawrapper + +import ( + "context" + "fmt" + "strings" + + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/key" + querypb "vitess.io/vitess/go/vt/proto/query" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vschemapb "vitess.io/vitess/go/vt/proto/vschema" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/topo/topoproto" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" + "vitess.io/vitess/go/vt/vtgate/vindexes" +) + +var _ plancontext.VSchema = (*VSchemaWrapper)(nil) + +type VSchemaWrapper struct { + V *vindexes.VSchema + Keyspace *vindexes.Keyspace + TabletType_ topodatapb.TabletType + Dest key.Destination + SysVarEnabled bool + Version plancontext.PlannerVersion + EnableViews bool + TestBuilder func(query string, vschema plancontext.VSchema, keyspace string) (*engine.Plan, error) +} + +func (vw *VSchemaWrapper) GetPrepareData(stmtName string) *vtgatepb.PrepareData { + switch stmtName { + case "prep_one_param": + return &vtgatepb.PrepareData{ + PrepareStatement: "select 1 from user where id = :v1", + ParamsCount: 1, + } + case "prep_in_param": + return &vtgatepb.PrepareData{ + PrepareStatement: "select 1 from user where id in (:v1, :v2)", + ParamsCount: 2, + } + case "prep_no_param": + return &vtgatepb.PrepareData{ + PrepareStatement: "select 1 from user", + ParamsCount: 0, + } + } + return nil +} + +func (vw *VSchemaWrapper) PlanPrepareStatement(ctx context.Context, query string) (*engine.Plan, sqlparser.Statement, error) { + plan, err := vw.TestBuilder(query, vw, vw.CurrentDb()) + if err != nil { + return nil, nil, err + } + stmt, _, err := sqlparser.Parse2(query) + if err != nil { + return nil, nil, err + } + return plan, stmt, nil +} + +func (vw *VSchemaWrapper) ClearPrepareData(string) {} + +func (vw *VSchemaWrapper) StorePrepareData(string, *vtgatepb.PrepareData) {} + +func (vw *VSchemaWrapper) GetUDV(name string) *querypb.BindVariable { + if strings.EqualFold(name, "prep_stmt") { + return sqltypes.StringBindVariable("select * from user where id in (?, ?, ?)") + } + return nil +} + +func (vw *VSchemaWrapper) IsShardRoutingEnabled() bool { + return false +} + +func (vw *VSchemaWrapper) GetVSchema() *vindexes.VSchema { + return vw.V +} + +func (vw *VSchemaWrapper) GetSrvVschema() *vschemapb.SrvVSchema { + return &vschemapb.SrvVSchema{ + Keyspaces: map[string]*vschemapb.Keyspace{ + "user": { + Sharded: true, + Vindexes: map[string]*vschemapb.Vindex{}, + Tables: map[string]*vschemapb.Table{ + "user": {}, + }, + }, + }, + } +} + +func (vw *VSchemaWrapper) ConnCollation() collations.ID { + return collations.CollationUtf8mb3ID +} + +func (vw *VSchemaWrapper) PlannerWarning(_ string) { +} + +func (vw *VSchemaWrapper) ForeignKeyMode(keyspace string) (vschemapb.Keyspace_ForeignKeyMode, error) { + defaultFkMode := vschemapb.Keyspace_FK_UNMANAGED + if vw.V.Keyspaces[keyspace] != nil && vw.V.Keyspaces[keyspace].ForeignKeyMode != vschemapb.Keyspace_FK_DEFAULT { + return vw.V.Keyspaces[keyspace].ForeignKeyMode, nil + } + return defaultFkMode, nil +} + +func (vw *VSchemaWrapper) AllKeyspace() ([]*vindexes.Keyspace, error) { + if vw.Keyspace == nil { + return nil, vterrors.VT13001("keyspace not available") + } + return []*vindexes.Keyspace{vw.Keyspace}, nil +} + +// FindKeyspace implements the VSchema interface +func (vw *VSchemaWrapper) 
FindKeyspace(keyspace string) (*vindexes.Keyspace, error) { + if vw.Keyspace == nil { + return nil, vterrors.VT13001("keyspace not available") + } + if vw.Keyspace.Name == keyspace { + return vw.Keyspace, nil + } + return nil, nil +} + +func (vw *VSchemaWrapper) Planner() plancontext.PlannerVersion { + return vw.Version +} + +// SetPlannerVersion implements the ContextVSchema interface +func (vw *VSchemaWrapper) SetPlannerVersion(v plancontext.PlannerVersion) { + vw.Version = v +} + +func (vw *VSchemaWrapper) GetSemTable() *semantics.SemTable { + return nil +} + +func (vw *VSchemaWrapper) KeyspaceExists(keyspace string) bool { + if vw.Keyspace != nil { + return vw.Keyspace.Name == keyspace + } + return false +} + +func (vw *VSchemaWrapper) SysVarSetEnabled() bool { + return vw.SysVarEnabled +} + +func (vw *VSchemaWrapper) TargetDestination(qualifier string) (key.Destination, *vindexes.Keyspace, topodatapb.TabletType, error) { + var keyspaceName string + if vw.Keyspace != nil { + keyspaceName = vw.Keyspace.Name + } + if vw.Dest == nil && qualifier != "" { + keyspaceName = qualifier + } + if keyspaceName == "" { + return nil, nil, 0, vterrors.VT03007() + } + keyspace := vw.V.Keyspaces[keyspaceName] + if keyspace == nil { + return nil, nil, 0, vterrors.VT05003(keyspaceName) + } + return vw.Dest, keyspace.Keyspace, vw.TabletType_, nil + +} + +func (vw *VSchemaWrapper) TabletType() topodatapb.TabletType { + return vw.TabletType_ +} + +func (vw *VSchemaWrapper) Destination() key.Destination { + return vw.Dest +} + +func (vw *VSchemaWrapper) FindTable(tab sqlparser.TableName) (*vindexes.Table, string, topodatapb.TabletType, key.Destination, error) { + destKeyspace, destTabletType, destTarget, err := topoproto.ParseDestination(tab.Qualifier.String(), topodatapb.TabletType_PRIMARY) + if err != nil { + return nil, destKeyspace, destTabletType, destTarget, err + } + table, err := vw.V.FindTable(destKeyspace, tab.Name.String()) + if err != nil { + return nil, destKeyspace, destTabletType, destTarget, err + } + return table, destKeyspace, destTabletType, destTarget, nil +} + +func (vw *VSchemaWrapper) FindView(tab sqlparser.TableName) sqlparser.SelectStatement { + destKeyspace, _, _, err := topoproto.ParseDestination(tab.Qualifier.String(), topodatapb.TabletType_PRIMARY) + if err != nil { + return nil + } + return vw.V.FindView(destKeyspace, tab.Name.String()) +} + +func (vw *VSchemaWrapper) FindTableOrVindex(tab sqlparser.TableName) (*vindexes.Table, vindexes.Vindex, string, topodatapb.TabletType, key.Destination, error) { + if tab.Qualifier.IsEmpty() && tab.Name.String() == "dual" { + ksName := vw.getActualKeyspace() + var ks *vindexes.Keyspace + if ksName == "" { + ks = vw.getfirstKeyspace() + ksName = ks.Name + } else { + ks = vw.V.Keyspaces[ksName].Keyspace + } + tbl := &vindexes.Table{ + Name: sqlparser.NewIdentifierCS("dual"), + Keyspace: ks, + Type: vindexes.TypeReference, + } + return tbl, nil, ksName, topodatapb.TabletType_PRIMARY, nil, nil + } + destKeyspace, destTabletType, destTarget, err := topoproto.ParseDestination(tab.Qualifier.String(), topodatapb.TabletType_PRIMARY) + if err != nil { + return nil, nil, destKeyspace, destTabletType, destTarget, err + } + if destKeyspace == "" { + destKeyspace = vw.getActualKeyspace() + } + table, vindex, err := vw.V.FindTableOrVindex(destKeyspace, tab.Name.String(), topodatapb.TabletType_PRIMARY) + if err != nil { + return nil, nil, destKeyspace, destTabletType, destTarget, err + } + return table, vindex, destKeyspace, destTabletType, destTarget, nil 
+} + +func (vw *VSchemaWrapper) getfirstKeyspace() (ks *vindexes.Keyspace) { + var f string + for name, schema := range vw.V.Keyspaces { + if f == "" || f > name { + f = name + ks = schema.Keyspace + } + } + return +} + +func (vw *VSchemaWrapper) getActualKeyspace() string { + if vw.Keyspace == nil { + return "" + } + if !sqlparser.SystemSchema(vw.Keyspace.Name) { + return vw.Keyspace.Name + } + ks, err := vw.AnyKeyspace() + if err != nil { + return "" + } + return ks.Name +} + +func (vw *VSchemaWrapper) DefaultKeyspace() (*vindexes.Keyspace, error) { + return vw.V.Keyspaces["main"].Keyspace, nil +} + +func (vw *VSchemaWrapper) AnyKeyspace() (*vindexes.Keyspace, error) { + return vw.DefaultKeyspace() +} + +func (vw *VSchemaWrapper) FirstSortedKeyspace() (*vindexes.Keyspace, error) { + return vw.V.Keyspaces["main"].Keyspace, nil +} + +func (vw *VSchemaWrapper) TargetString() string { + return "targetString" +} + +func (vw *VSchemaWrapper) WarnUnshardedOnly(_ string, _ ...any) { + +} + +func (vw *VSchemaWrapper) ErrorIfShardedF(keyspace *vindexes.Keyspace, _, errFmt string, params ...any) error { + if keyspace.Sharded { + return fmt.Errorf(errFmt, params...) + } + return nil +} + +func (vw *VSchemaWrapper) CurrentDb() string { + ksName := "" + if vw.Keyspace != nil { + ksName = vw.Keyspace.Name + } + return ksName +} + +func (vw *VSchemaWrapper) FindRoutedShard(keyspace, shard string) (string, error) { + return "", nil +} + +func (vw *VSchemaWrapper) IsViewsEnabled() bool { + return vw.EnableViews +} diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 6a3a65ea583..7ca1b7e92e3 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -601,31 +601,6 @@ func (node *ColName) Equal(c *ColName) bool { return node.Name.Equal(c.Name) && node.Qualifier == c.Qualifier } -// Aggregates is a map of all aggregate functions. -var Aggregates = map[string]bool{ - "avg": true, - "bit_and": true, - "bit_or": true, - "bit_xor": true, - "count": true, - "group_concat": true, - "max": true, - "min": true, - "std": true, - "stddev_pop": true, - "stddev_samp": true, - "stddev": true, - "sum": true, - "var_pop": true, - "var_samp": true, - "variance": true, -} - -// IsAggregate returns true if the function is an aggregate. -func (node *FuncExpr) IsAggregate() bool { - return Aggregates[node.Name.Lowered()] -} - // NewIdentifierCI makes a new IdentifierCI. 
func NewIdentifierCI(str string) IdentifierCI { return IdentifierCI{ @@ -666,6 +641,19 @@ func NewTableNameWithQualifier(name, qualifier string) TableName { } } +// NewSubquery makes a new Subquery +func NewSubquery(selectStatement SelectStatement) *Subquery { + return &Subquery{Select: selectStatement} +} + +// NewDerivedTable makes a new DerivedTable +func NewDerivedTable(lateral bool, selectStatement SelectStatement) *DerivedTable { + return &DerivedTable{ + Lateral: lateral, + Select: selectStatement, + } +} + // NewAliasedTableExpr makes a new AliasedTableExpr with an alias func NewAliasedTableExpr(simpleTableExpr SimpleTableExpr, alias string) *AliasedTableExpr { return &AliasedTableExpr{ @@ -700,6 +688,10 @@ func NewAliasedExpr(expr Expr, alias string) *AliasedExpr { } } +func (ae *AliasedExpr) SetAlias(alias string) { + ae.As = NewIdentifierCI(alias) +} + // NewOrder makes a new Order func NewOrder(expr Expr, direction OrderDirection) *Order { return &Order{ @@ -708,6 +700,11 @@ func NewOrder(expr Expr, direction OrderDirection) *Order { } } +// NewNotExpr makes a new NotExpr +func NewNotExpr(expr Expr) *NotExpr { + return &NotExpr{Expr: expr} +} + // NewComparisonExpr makes a new ComparisonExpr func NewComparisonExpr(operator ComparisonExprOperator, left, right, escape Expr) *ComparisonExpr { return &ComparisonExpr{ @@ -718,6 +715,11 @@ func NewComparisonExpr(operator ComparisonExprOperator, left, right, escape Expr } } +// NewExistsExpr makes a new ExistsExpr +func NewExistsExpr(subquery *Subquery) *ExistsExpr { + return &ExistsExpr{Subquery: subquery} +} + // NewCaseExpr makes a new CaseExpr func NewCaseExpr(expr Expr, whens []*When, elseExpr Expr) *CaseExpr { return &CaseExpr{ @@ -752,14 +754,6 @@ func NewLimitWithoutOffset(rowCount int) *Limit { } } -// NewDerivedTable makes a new DerivedTable -func NewDerivedTable(lateral bool, selectStatement SelectStatement) *DerivedTable { - return &DerivedTable{ - Lateral: lateral, - Select: selectStatement, - } -} - // NewSelect is used to create a select statement func NewSelect(comments Comments, exprs SelectExprs, selectOptions []string, into *SelectInto, from TableExprs, where *Where, groupBy GroupBy, having *Where, windows NamedWindows) *Select { var cache *bool diff --git a/go/vt/sqlparser/ast_test.go b/go/vt/sqlparser/ast_test.go index 3f0fb850857..fe350bedf5e 100644 --- a/go/vt/sqlparser/ast_test.go +++ b/go/vt/sqlparser/ast_test.go @@ -362,23 +362,6 @@ func TestWhere(t *testing.T) { } } -func TestIsAggregate(t *testing.T) { - f := FuncExpr{Name: NewIdentifierCI("avg")} - if !f.IsAggregate() { - t.Error("IsAggregate: false, want true") - } - - f = FuncExpr{Name: NewIdentifierCI("Avg")} - if !f.IsAggregate() { - t.Error("IsAggregate: false, want true") - } - - f = FuncExpr{Name: NewIdentifierCI("foo")} - if f.IsAggregate() { - t.Error("IsAggregate: true, want false") - } -} - func TestIsImpossible(t *testing.T) { f := ComparisonExpr{ Operator: NotEqualOp, diff --git a/go/vt/sqlparser/parse_test.go b/go/vt/sqlparser/parse_test.go index ae39e4a7722..bd9e3c8b057 100644 --- a/go/vt/sqlparser/parse_test.go +++ b/go/vt/sqlparser/parse_test.go @@ -44,6 +44,8 @@ var ( partialDDL bool ignoreNormalizerTest bool }{{ + input: "select * from foo limit 5 + 5", + }, { input: "create table x(location GEOMETRYCOLLECTION DEFAULT (POINT(7.0, 3.0)))", output: "create table x (\n\tlocation GEOMETRYCOLLECTION default (point(7.0, 3.0))\n)", }, { diff --git a/go/vt/sqlparser/precedence_test.go b/go/vt/sqlparser/precedence_test.go index 
cb8c1f23805..ebab6bbd698 100644 --- a/go/vt/sqlparser/precedence_test.go +++ b/go/vt/sqlparser/precedence_test.go @@ -18,6 +18,7 @@ package sqlparser import ( "fmt" + "math/rand" "testing" "time" @@ -215,8 +216,9 @@ func TestRandom(t *testing.T) { // The purpose of this test is to find discrepancies between Format and parsing. If for example our precedence rules are not consistent between the two, this test should find it. // The idea is to generate random queries, and pass them through the parser and then the unparser, and one more time. The result of the first unparse should be the same as the second result. seed := time.Now().UnixNano() + r := rand.New(rand.NewSource(seed)) fmt.Println(fmt.Sprintf("seed is %d", seed)) // nolint - g := NewGenerator(seed, 5) + g := NewGenerator(r, 5) endBy := time.Now().Add(1 * time.Second) for { @@ -224,7 +226,7 @@ func TestRandom(t *testing.T) { break } // Given a random expression - randomExpr := g.Expression() + randomExpr := g.Expression(ExprGeneratorConfig{}) inputQ := "select " + String(randomExpr) + " from t" // When it's parsed and unparsed diff --git a/go/vt/sqlparser/random_expr.go b/go/vt/sqlparser/random_expr.go index 9b3b711c87f..6eed8145ed2 100644 --- a/go/vt/sqlparser/random_expr.go +++ b/go/vt/sqlparser/random_expr.go @@ -23,29 +23,110 @@ import ( // This file is used to generate random expressions to be used for testing +// Constants for Enum Type - AggregateRule +const ( + CannotAggregate AggregateRule = iota + CanAggregate + IsAggregate +) + type ( ExprGenerator interface { - IntExpr() Expr - StringExpr() Expr + Generate(r *rand.Rand, config ExprGeneratorConfig) Expr + } + + QueryGenerator interface { + IsQueryGenerator() + ExprGenerator + } + + AggregateRule int8 + + ExprGeneratorConfig struct { + // AggrRule determines if the random expression can, cannot, or must be an aggregation expression + AggrRule AggregateRule + Type string + // MaxCols = 0 indicates no limit + NumCols int + // SingleRow indicates that the query must have at most one row + SingleRow bool + } + + Generator struct { + r *rand.Rand + depth int + maxDepth int + isAggregate bool + exprGenerators []ExprGenerator } ) -func NewGenerator(seed int64, maxDepth int, exprGenerators ...ExprGenerator) *Generator { - g := Generator{ - seed: seed, - r: rand.New(rand.NewSource(seed)), - maxDepth: maxDepth, - exprGenerator: exprGenerators, +func NewExprGeneratorConfig(aggrRule AggregateRule, typ string, numCols int, singleRow bool) ExprGeneratorConfig { + return ExprGeneratorConfig{ + AggrRule: aggrRule, + Type: typ, + NumCols: numCols, + SingleRow: singleRow, } - return &g } -type Generator struct { - seed int64 - r *rand.Rand - depth int - maxDepth int - exprGenerator []ExprGenerator +func (egc ExprGeneratorConfig) SingleRowConfig() ExprGeneratorConfig { + egc.SingleRow = true + return egc +} + +func (egc ExprGeneratorConfig) MultiRowConfig() ExprGeneratorConfig { + egc.SingleRow = false + return egc +} + +func (egc ExprGeneratorConfig) SetNumCols(numCols int) ExprGeneratorConfig { + egc.NumCols = numCols + return egc +} + +func (egc ExprGeneratorConfig) boolTypeConfig() ExprGeneratorConfig { + egc.Type = "tinyint" + return egc +} + +func (egc ExprGeneratorConfig) intTypeConfig() ExprGeneratorConfig { + egc.Type = "bigint" + return egc +} + +func (egc ExprGeneratorConfig) stringTypeConfig() ExprGeneratorConfig { + egc.Type = "varchar" + return egc +} + +func (egc ExprGeneratorConfig) anyTypeConfig() ExprGeneratorConfig { + egc.Type = "" + return egc +} + +func (egc 
ExprGeneratorConfig) CannotAggregateConfig() ExprGeneratorConfig { + egc.AggrRule = CannotAggregate + return egc +} + +func (egc ExprGeneratorConfig) CanAggregateConfig() ExprGeneratorConfig { + egc.AggrRule = CanAggregate + return egc +} + +func (egc ExprGeneratorConfig) IsAggregateConfig() ExprGeneratorConfig { + egc.AggrRule = IsAggregate + return egc +} + +func NewGenerator(r *rand.Rand, maxDepth int, exprGenerators ...ExprGenerator) *Generator { + g := Generator{ + r: r, + maxDepth: maxDepth, + exprGenerators: exprGenerators, + } + return &g } // enter should be called whenever we are producing an intermediate node. it should be followed by a `defer g.exit()` @@ -69,6 +150,7 @@ func (g *Generator) atMaxDepth() bool { - AND/OR/NOT - string literals, numeric literals (-/+ 1000) - columns of types bigint and varchar + - scalar and tuple subqueries - =, >, <, >=, <=, <=>, != - &, |, ^, +, -, *, /, div, %, <<, >> - IN, BETWEEN and CASE @@ -78,76 +160,223 @@ func (g *Generator) atMaxDepth() bool { Note: It's important to update this method so that it produces all expressions that need precedence checking. It's currently missing function calls and string operators */ -func (g *Generator) Expression() Expr { - if g.randomBool() { - return g.booleanExpr() +func (g *Generator) Expression(genConfig ExprGeneratorConfig) Expr { + var options []exprF + // this will only be used for tuple expressions, everything else will need genConfig.NumCols = 1 + numCols := genConfig.NumCols + genConfig = genConfig.SetNumCols(1) + + switch genConfig.Type { + case "bigint": + options = append(options, func() Expr { return g.intExpr(genConfig) }) + case "varchar": + options = append(options, func() Expr { return g.stringExpr(genConfig) }) + case "tinyint": + options = append(options, func() Expr { return g.booleanExpr(genConfig) }) + case "": + options = append(options, []exprF{ + func() Expr { return g.intExpr(genConfig) }, + func() Expr { return g.stringExpr(genConfig) }, + func() Expr { return g.booleanExpr(genConfig) }, + }...) 
+ } + + for i := range g.exprGenerators { + generator := g.exprGenerators[i] + if generator == nil { + continue + } + + // don't create expressions from the expression exprGenerators if we haven't created an aggregation yet + if _, ok := generator.(QueryGenerator); ok || genConfig.AggrRule != IsAggregate { + options = append(options, func() Expr { + expr := generator.Generate(g.r, genConfig) + if expr == nil { + return g.randomLiteral() + } + return expr + }) + } + } + + if genConfig.AggrRule != CannotAggregate { + options = append(options, func() Expr { + g.isAggregate = true + return g.randomAggregate(genConfig.CannotAggregateConfig()) + }) + } + + // if an arbitrary number of columns may be generated, randomly choose 1-3 columns + if numCols == 0 { + numCols = g.r.Intn(3) + 1 + } + + if numCols == 1 { + return g.makeAggregateIfNecessary(genConfig, g.randomOf(options)) + } + + // with 1/5 probability choose a tuple subquery + if g.randomBool(0.2) { + return g.subqueryExpr(genConfig.SetNumCols(numCols)) + } + + tuple := ValTuple{} + for i := 0; i < numCols; i++ { + tuple = append(tuple, g.makeAggregateIfNecessary(genConfig, g.randomOf(options))) } + return tuple +} + +// makeAggregateIfNecessary is a failsafe to make sure an IsAggregate expression is in fact an aggregation +func (g *Generator) makeAggregateIfNecessary(genConfig ExprGeneratorConfig, expr Expr) Expr { + // if the generated expression must be an aggregate, and it is not, + // tack on an extra "and count(*)" to make it aggregate + if genConfig.AggrRule == IsAggregate && !g.isAggregate && g.depth == 0 { + expr = &AndExpr{ + Left: expr, + Right: &CountStar{}, + } + g.isAggregate = true + } + + return expr +} + +func (g *Generator) randomAggregate(genConfig ExprGeneratorConfig) Expr { + isDistinct := g.r.Intn(10) < 1 + options := []exprF{ - func() Expr { return g.intExpr() }, - func() Expr { return g.stringExpr() }, - func() Expr { return g.booleanExpr() }, + func() Expr { return &CountStar{} }, + func() Expr { return &Count{Args: Exprs{g.Expression(genConfig.anyTypeConfig())}, Distinct: isDistinct} }, + func() Expr { return &Sum{Arg: g.Expression(genConfig), Distinct: isDistinct} }, + func() Expr { return &Min{Arg: g.Expression(genConfig), Distinct: isDistinct} }, + func() Expr { return &Max{Arg: g.Expression(genConfig), Distinct: isDistinct} }, } + g.isAggregate = true return g.randomOf(options) } -func (g *Generator) booleanExpr() Expr { +func (g *Generator) booleanExpr(genConfig ExprGeneratorConfig) Expr { if g.atMaxDepth() { return g.booleanLiteral() } + genConfig = genConfig.boolTypeConfig() + options := []exprF{ - func() Expr { return g.andExpr() }, - func() Expr { return g.xorExpr() }, - func() Expr { return g.orExpr() }, - func() Expr { return g.comparison(g.intExpr) }, - func() Expr { return g.comparison(g.stringExpr) }, - //func() Expr { return g.comparison(g.booleanExpr) }, // this is not accepted by the parser - func() Expr { return g.inExpr() }, - func() Expr { return g.between() }, - func() Expr { return g.isExpr() }, - func() Expr { return g.notExpr() }, - func() Expr { return g.likeExpr() }, + func() Expr { return g.andExpr(genConfig) }, + func() Expr { return g.xorExpr(genConfig) }, + func() Expr { return g.orExpr(genConfig) }, + func() Expr { return g.comparison(genConfig.intTypeConfig()) }, + func() Expr { return g.comparison(genConfig.stringTypeConfig()) }, + //func() Expr { return g.comparison(genConfig) }, // this is not accepted by the parser + func() Expr { return g.inExpr(genConfig) }, + func() Expr { 
return g.existsExpr(genConfig) }, + func() Expr { return g.between(genConfig.intTypeConfig()) }, + func() Expr { return g.isExpr(genConfig) }, + func() Expr { return g.notExpr(genConfig) }, + func() Expr { return g.likeExpr(genConfig.stringTypeConfig()) }, } return g.randomOf(options) } -func (g *Generator) intExpr() Expr { +func (g *Generator) intExpr(genConfig ExprGeneratorConfig) Expr { if g.atMaxDepth() { return g.intLiteral() } + genConfig = genConfig.intTypeConfig() + options := []exprF{ - func() Expr { return g.arithmetic() }, - func() Expr { return g.intLiteral() }, - func() Expr { return g.caseExpr(g.intExpr) }, + g.intLiteral, + func() Expr { return g.arithmetic(genConfig) }, + func() Expr { return g.caseExpr(genConfig) }, } - for _, generator := range g.exprGenerator { - options = append(options, func() Expr { - expr := generator.IntExpr() - if expr == nil { - return g.intLiteral() - } - return expr - }) + return g.randomOf(options) +} + +func (g *Generator) stringExpr(genConfig ExprGeneratorConfig) Expr { + if g.atMaxDepth() { + return g.stringLiteral() + } + + genConfig = genConfig.stringTypeConfig() + + options := []exprF{ + g.stringLiteral, + func() Expr { return g.caseExpr(genConfig) }, + } + + return g.randomOf(options) +} + +func (g *Generator) subqueryExpr(genConfig ExprGeneratorConfig) Expr { + if g.atMaxDepth() { + return g.makeAggregateIfNecessary(genConfig, g.randomTupleLiteral(genConfig)) + } + + var options []exprF + + for _, generator := range g.exprGenerators { + if qg, ok := generator.(QueryGenerator); ok { + options = append(options, func() Expr { + expr := qg.Generate(g.r, genConfig) + if expr == nil { + return g.randomTupleLiteral(genConfig) + } + return expr + }) + } + } + + if len(options) == 0 { + return g.Expression(genConfig) + } + + return g.randomOf(options) +} + +func (g *Generator) randomTupleLiteral(genConfig ExprGeneratorConfig) Expr { + if genConfig.NumCols == 0 { + genConfig.NumCols = g.r.Intn(3) + 1 + } + + tuple := ValTuple{} + for i := 0; i < genConfig.NumCols; i++ { + tuple = append(tuple, g.randomLiteral()) + } + + return tuple +} + +func (g *Generator) randomLiteral() Expr { + options := []exprF{ + g.intLiteral, + g.stringLiteral, + g.booleanLiteral, } return g.randomOf(options) } func (g *Generator) booleanLiteral() Expr { - return BoolVal(g.randomBool()) + return BoolVal(g.randomBool(0.5)) } -func (g *Generator) randomBool() bool { - return g.r.Float32() < 0.5 +// randomBool returns true with probability prob +func (g *Generator) randomBool(prob float32) bool { + if prob < 0 || prob > 1 { + prob = 0.5 + } + return g.r.Float32() < prob } func (g *Generator) intLiteral() Expr { - t := fmt.Sprintf("%d", g.r.Intn(1000)-g.r.Intn(1000)) + t := fmt.Sprintf("%d", g.r.Intn(100)-g.r.Intn(100)) return NewIntLiteral(t) } @@ -158,77 +387,57 @@ func (g *Generator) stringLiteral() Expr { return NewStrLiteral(g.randomOfS(words)) } -func (g *Generator) stringExpr() Expr { - if g.atMaxDepth() { - return g.stringLiteral() - } - - options := []exprF{ - func() Expr { return g.stringLiteral() }, - func() Expr { return g.caseExpr(g.stringExpr) }, - } - - for _, generator := range g.exprGenerator { - options = append(options, func() Expr { - expr := generator.StringExpr() - if expr == nil { - return g.stringLiteral() - } - return expr - }) - } - - return g.randomOf(options) -} - -func (g *Generator) likeExpr() Expr { +func (g *Generator) likeExpr(genConfig ExprGeneratorConfig) Expr { g.enter() defer g.exit() return &ComparisonExpr{ Operator: LikeOp, - Left: 
g.stringExpr(), - Right: g.stringExpr(), + Left: g.Expression(genConfig), + Right: g.Expression(genConfig), } } var comparisonOps = []ComparisonExprOperator{EqualOp, LessThanOp, GreaterThanOp, LessEqualOp, GreaterEqualOp, NotEqualOp, NullSafeEqualOp} -func (g *Generator) comparison(f func() Expr) Expr { +func (g *Generator) comparison(genConfig ExprGeneratorConfig) Expr { g.enter() defer g.exit() + // specifc 1-3 columns + numCols := g.r.Intn(3) + 1 + cmp := &ComparisonExpr{ Operator: comparisonOps[g.r.Intn(len(comparisonOps))], - Left: f(), - Right: f(), + Left: g.Expression(genConfig.SetNumCols(numCols)), + Right: g.Expression(genConfig.SetNumCols(numCols)), } return cmp } -func (g *Generator) caseExpr(valueF func() Expr) Expr { +func (g *Generator) caseExpr(genConfig ExprGeneratorConfig) Expr { g.enter() defer g.exit() var exp Expr var elseExpr Expr - if g.randomBool() { - exp = valueF() + if g.randomBool(0.5) { + exp = g.Expression(genConfig.anyTypeConfig()) } - if g.randomBool() { - elseExpr = valueF() + if g.randomBool(0.5) { + elseExpr = g.Expression(genConfig) } - size := g.r.Intn(5) + 2 + size := g.r.Intn(2) + 1 var whens []*When for i := 0; i < size; i++ { var cond Expr if exp == nil { - cond = g.booleanExpr() + cond = g.Expression(genConfig.boolTypeConfig()) } else { - cond = g.Expression() + cond = g.Expression(genConfig) } - val := g.Expression() + val := g.Expression(genConfig) whens = append(whens, &When{ Cond: cond, Val: val, @@ -244,7 +453,7 @@ func (g *Generator) caseExpr(valueF func() Expr) Expr { var arithmeticOps = []BinaryExprOperator{BitAndOp, BitOrOp, BitXorOp, PlusOp, MinusOp, MultOp, DivOp, IntDivOp, ModOp, ShiftRightOp, ShiftLeftOp} -func (g *Generator) arithmetic() Expr { +func (g *Generator) arithmetic(genConfig ExprGeneratorConfig) Expr { g.enter() defer g.exit() @@ -252,8 +461,8 @@ func (g *Generator) arithmetic() Expr { return &BinaryExpr{ Operator: op, - Left: g.intExpr(), - Right: g.intExpr(), + Left: g.Expression(genConfig), + Right: g.Expression(genConfig), } } @@ -267,67 +476,66 @@ func (g *Generator) randomOfS(options []string) string { return options[g.r.Intn(len(options))] } -func (g *Generator) andExpr() Expr { +func (g *Generator) andExpr(genConfig ExprGeneratorConfig) Expr { g.enter() defer g.exit() return &AndExpr{ - Left: g.booleanExpr(), - Right: g.booleanExpr(), + Left: g.Expression(genConfig), + Right: g.Expression(genConfig), } } -func (g *Generator) orExpr() Expr { +func (g *Generator) orExpr(genConfig ExprGeneratorConfig) Expr { g.enter() defer g.exit() return &OrExpr{ - Left: g.booleanExpr(), - Right: g.booleanExpr(), + Left: g.Expression(genConfig), + Right: g.Expression(genConfig), } } -func (g *Generator) xorExpr() Expr { +func (g *Generator) xorExpr(genConfig ExprGeneratorConfig) Expr { g.enter() defer g.exit() return &XorExpr{ - Left: g.booleanExpr(), - Right: g.booleanExpr(), + Left: g.Expression(genConfig), + Right: g.Expression(genConfig), } } -func (g *Generator) notExpr() Expr { +func (g *Generator) notExpr(genConfig ExprGeneratorConfig) Expr { g.enter() defer g.exit() - return &NotExpr{g.booleanExpr()} + return &NotExpr{g.Expression(genConfig)} } -func (g *Generator) inExpr() Expr { +func (g *Generator) inExpr(genConfig ExprGeneratorConfig) Expr { g.enter() defer g.exit() - expr := g.intExpr() - size := g.r.Intn(5) + 2 - tuples := ValTuple{} - for i := 0; i < size; i++ { - tuples = append(tuples, g.intExpr()) - } + size := g.r.Intn(3) + 2 + inExprGenConfig := NewExprGeneratorConfig(genConfig.AggrRule, "", size, true) + tuple1 
:= g.Expression(inExprGenConfig) + tuple2 := ValTuple{g.Expression(inExprGenConfig)} + op := InOp - if g.randomBool() { + if g.randomBool(0.5) { op = NotInOp } return &ComparisonExpr{ Operator: op, - Left: expr, - Right: tuples, + Left: tuple1, + Right: tuple2, } } -func (g *Generator) between() Expr { +func (g *Generator) between(genConfig ExprGeneratorConfig) Expr { g.enter() defer g.exit() var IsBetween bool - if g.randomBool() { + if g.randomBool(0.5) { IsBetween = true } else { IsBetween = false @@ -335,13 +543,13 @@ func (g *Generator) between() Expr { return &BetweenExpr{ IsBetween: IsBetween, - Left: g.intExpr(), - From: g.intExpr(), - To: g.intExpr(), + Left: g.Expression(genConfig), + From: g.Expression(genConfig), + To: g.Expression(genConfig), } } -func (g *Generator) isExpr() Expr { +func (g *Generator) isExpr(genConfig ExprGeneratorConfig) Expr { g.enter() defer g.exit() @@ -349,6 +557,26 @@ func (g *Generator) isExpr() Expr { return &IsExpr{ Right: ops[g.r.Intn(len(ops))], - Left: g.booleanExpr(), + Left: g.Expression(genConfig), } } + +func (g *Generator) existsExpr(genConfig ExprGeneratorConfig) Expr { + expr := g.subqueryExpr(genConfig.MultiRowConfig().SetNumCols(0)) + if subquery, ok := expr.(*Subquery); ok { + expr = NewExistsExpr(subquery) + } else { + // if g.subqueryExpr doesn't return a valid subquery, replace with + // select 1 + selectExprs := SelectExprs{NewAliasedExpr(NewIntLiteral("1"), "")} + from := TableExprs{NewAliasedTableExpr(NewTableName("dual"), "")} + expr = NewExistsExpr(NewSubquery(NewSelect(nil, selectExprs, nil, nil, from, nil, nil, nil, nil))) + } + + // not exists + if g.randomBool(0.5) { + expr = NewNotExpr(expr) + } + + return expr +} diff --git a/go/vt/sqlparser/rewriter_test.go b/go/vt/sqlparser/rewriter_test.go index 9adae1b4a81..3044e04f8b0 100644 --- a/go/vt/sqlparser/rewriter_test.go +++ b/go/vt/sqlparser/rewriter_test.go @@ -17,6 +17,7 @@ limitations under the License. 
package sqlparser import ( + "math/rand" "testing" "github.com/stretchr/testify/assert" @@ -25,8 +26,8 @@ import ( ) func BenchmarkVisitLargeExpression(b *testing.B) { - gen := NewGenerator(1, 5) - exp := gen.Expression() + gen := NewGenerator(rand.New(rand.NewSource(1)), 5) + exp := gen.Expression(ExprGeneratorConfig{}) depth := 0 for i := 0; i < b.N; i++ { diff --git a/go/vt/sqlparser/walker_test.go b/go/vt/sqlparser/walker_test.go index 5359235afa5..560ed2ff470 100644 --- a/go/vt/sqlparser/walker_test.go +++ b/go/vt/sqlparser/walker_test.go @@ -18,6 +18,7 @@ package sqlparser import ( "fmt" + "math/rand" "testing" "github.com/stretchr/testify/require" @@ -26,7 +27,7 @@ import ( func BenchmarkWalkLargeExpression(b *testing.B) { for i := 0; i < 10; i++ { b.Run(fmt.Sprintf("%d", i), func(b *testing.B) { - exp := NewGenerator(int64(i*100), 5).Expression() + exp := NewGenerator(rand.New(rand.NewSource(int64(i*100))), 5).Expression(ExprGeneratorConfig{}) count := 0 for i := 0; i < b.N; i++ { err := Walk(func(node SQLNode) (kontinue bool, err error) { @@ -42,7 +43,7 @@ func BenchmarkWalkLargeExpression(b *testing.B) { func BenchmarkRewriteLargeExpression(b *testing.B) { for i := 1; i < 7; i++ { b.Run(fmt.Sprintf("%d", i), func(b *testing.B) { - exp := NewGenerator(int64(i*100), i).Expression() + exp := NewGenerator(rand.New(rand.NewSource(int64(i*100))), i).Expression(ExprGeneratorConfig{}) count := 0 for i := 0; i < b.N; i++ { _ = Rewrite(exp, func(_ *Cursor) bool { diff --git a/go/vt/vtgate/planbuilder/collations_test.go b/go/vt/vtgate/planbuilder/collations_test.go index 8919a720744..f597f45562d 100644 --- a/go/vt/vtgate/planbuilder/collations_test.go +++ b/go/vt/vtgate/planbuilder/collations_test.go @@ -20,6 +20,8 @@ import ( "fmt" "testing" + "vitess.io/vitess/go/test/vschemawrapper" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/collations" @@ -39,21 +41,21 @@ type collationTestCase struct { } func (tc *collationTestCase) run(t *testing.T) { - vschemaWrapper := &vschemaWrapper{ - v: loadSchema(t, "vschemas/schema.json", false), - sysVarEnabled: true, - version: Gen4, + vschemaWrapper := &vschemawrapper.VSchemaWrapper{ + V: loadSchema(t, "vschemas/schema.json", false), + SysVarEnabled: true, + Version: Gen4, } tc.addCollationsToSchema(vschemaWrapper) - plan, err := TestBuilder(tc.query, vschemaWrapper, vschemaWrapper.currentDb()) + plan, err := TestBuilder(tc.query, vschemaWrapper, vschemaWrapper.CurrentDb()) require.NoError(t, err) tc.check(t, tc.collations, plan.Instructions) } -func (tc *collationTestCase) addCollationsToSchema(vschema *vschemaWrapper) { +func (tc *collationTestCase) addCollationsToSchema(vschema *vschemawrapper.VSchemaWrapper) { for _, collation := range tc.collations { - tbl := vschema.v.Keyspaces[collation.ks].Tables[collation.table] + tbl := vschema.V.Keyspaces[collation.ks].Tables[collation.table] for i, c := range tbl.Columns { if c.Name.EqualString(collation.colName) { tbl.Columns[i].CollationName = collation.collationName diff --git a/go/vt/vtgate/planbuilder/plan_test.go b/go/vt/vtgate/planbuilder/plan_test.go index fe8bb11b72c..0265a65619e 100644 --- a/go/vt/vtgate/planbuilder/plan_test.go +++ b/go/vt/vtgate/planbuilder/plan_test.go @@ -28,23 +28,19 @@ import ( "strings" "testing" + "vitess.io/vitess/go/test/vschemawrapper" + "github.com/nsf/jsondiff" "github.com/stretchr/testify/require" - "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/test/utils" "vitess.io/vitess/go/vt/key" - querypb 
"vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" - vschemapb "vitess.io/vitess/go/vt/proto/vschema" - vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/sidecardb" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/topo/memorytopo" - "vitess.io/vitess/go/vt/topo/topoproto" - "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" oprewriters "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/rewrite" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" @@ -59,10 +55,11 @@ func makeTestOutput(t *testing.T) string { } func TestPlan(t *testing.T) { - vschemaWrapper := &vschemaWrapper{ - v: loadSchema(t, "vschemas/schema.json", true), - tabletType: topodatapb.TabletType_PRIMARY, - sysVarEnabled: true, + vschemaWrapper := &vschemawrapper.VSchemaWrapper{ + V: loadSchema(t, "vschemas/schema.json", true), + TabletType_: topodatapb.TabletType_PRIMARY, + SysVarEnabled: true, + TestBuilder: TestBuilder, } testOutputTempDir := makeTestOutput(t) @@ -102,10 +99,10 @@ func TestPlan(t *testing.T) { // TestForeignKeyPlanning tests the planning of foreign keys in a managed mode by Vitess. func TestForeignKeyPlanning(t *testing.T) { - vschemaWrapper := &vschemaWrapper{ - v: loadSchema(t, "vschemas/schema.json", true), + vschemaWrapper := &vschemawrapper.VSchemaWrapper{ + V: loadSchema(t, "vschemas/schema.json", true), // Set the keyspace with foreign keys enabled as the default. - keyspace: &vindexes.Keyspace{ + Keyspace: &vindexes.Keyspace{ Name: "user_fk_allow", Sharded: true, }, @@ -122,24 +119,24 @@ func TestSystemTables57(t *testing.T) { defer func() { servenv.SetMySQLServerVersionForTest(oldVer) }() - vschemaWrapper := &vschemaWrapper{v: loadSchema(t, "vschemas/schema.json", true)} + vschemaWrapper := &vschemawrapper.VSchemaWrapper{V: loadSchema(t, "vschemas/schema.json", true)} testOutputTempDir := makeTestOutput(t) testFile(t, "info_schema57_cases.json", testOutputTempDir, vschemaWrapper, false) } func TestSysVarSetDisabled(t *testing.T) { - vschemaWrapper := &vschemaWrapper{ - v: loadSchema(t, "vschemas/schema.json", true), - sysVarEnabled: false, + vschemaWrapper := &vschemawrapper.VSchemaWrapper{ + V: loadSchema(t, "vschemas/schema.json", true), + SysVarEnabled: false, } testFile(t, "set_sysvar_disabled_cases.json", makeTestOutput(t), vschemaWrapper, false) } func TestViews(t *testing.T) { - vschemaWrapper := &vschemaWrapper{ - v: loadSchema(t, "vschemas/schema.json", true), - enableViews: true, + vschemaWrapper := &vschemawrapper.VSchemaWrapper{ + V: loadSchema(t, "vschemas/schema.json", true), + EnableViews: true, } testFile(t, "view_cases.json", makeTestOutput(t), vschemaWrapper, false) @@ -149,25 +146,25 @@ func TestOne(t *testing.T) { reset := oprewriters.EnableDebugPrinting() defer reset() - vschema := &vschemaWrapper{ - v: loadSchema(t, "vschemas/schema.json", true), + vschema := &vschemawrapper.VSchemaWrapper{ + V: loadSchema(t, "vschemas/schema.json", true), } testFile(t, "onecase.json", "", vschema, false) } func TestOneTPCC(t *testing.T) { - vschema := &vschemaWrapper{ - v: loadSchema(t, "vschemas/tpcc_schema.json", true), + vschema := &vschemawrapper.VSchemaWrapper{ + V: loadSchema(t, "vschemas/tpcc_schema.json", true), } testFile(t, "onecase.json", "", vschema, false) } func TestOneWithMainAsDefault(t *testing.T) { - vschema := &vschemaWrapper{ - v: loadSchema(t, "vschemas/schema.json", true), - keyspace: &vindexes.Keyspace{ + vschema := 
&vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/schema.json", true),
+		Keyspace: &vindexes.Keyspace{
 			Name: "main",
 			Sharded: false,
 		},
@@ -177,9 +174,9 @@ func TestOneWithMainAsDefault(t *testing.T) {
 }
 
 func TestOneWithSecondUserAsDefault(t *testing.T) {
-	vschema := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/schema.json", true),
-		keyspace: &vindexes.Keyspace{
+	vschema := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/schema.json", true),
+		Keyspace: &vindexes.Keyspace{
 			Name: "second_user",
 			Sharded: true,
 		},
@@ -189,9 +186,9 @@ func TestOneWithSecondUserAsDefault(t *testing.T) {
 }
 
 func TestOneWithUserAsDefault(t *testing.T) {
-	vschema := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/schema.json", true),
-		keyspace: &vindexes.Keyspace{
+	vschema := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/schema.json", true),
+		Keyspace: &vindexes.Keyspace{
 			Name: "user",
 			Sharded: true,
 		},
@@ -201,9 +198,9 @@ func TestOneWithUserAsDefault(t *testing.T) {
 }
 
 func TestOneWithTPCHVSchema(t *testing.T) {
-	vschema := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/tpch_schema.json", true),
-		sysVarEnabled: true,
+	vschema := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/tpch_schema.json", true),
+		SysVarEnabled: true,
 	}
 
 	testFile(t, "onecase.json", "", vschema, false)
@@ -216,42 +213,42 @@ func TestOneWith57Version(t *testing.T) {
 	defer func() {
 		servenv.SetMySQLServerVersionForTest(oldVer)
 	}()
-	vschema := &vschemaWrapper{v: loadSchema(t, "vschemas/schema.json", true)}
+	vschema := &vschemawrapper.VSchemaWrapper{V: loadSchema(t, "vschemas/schema.json", true)}
 
 	testFile(t, "onecase.json", "", vschema, false)
 }
 
 func TestRubyOnRailsQueries(t *testing.T) {
-	vschemaWrapper := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/rails_schema.json", true),
-		sysVarEnabled: true,
+	vschemaWrapper := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/rails_schema.json", true),
+		SysVarEnabled: true,
 	}
 
 	testFile(t, "rails_cases.json", makeTestOutput(t), vschemaWrapper, false)
 }
 
 func TestOLTP(t *testing.T) {
-	vschemaWrapper := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/oltp_schema.json", true),
-		sysVarEnabled: true,
+	vschemaWrapper := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/oltp_schema.json", true),
+		SysVarEnabled: true,
 	}
 
 	testFile(t, "oltp_cases.json", makeTestOutput(t), vschemaWrapper, false)
 }
 
 func TestTPCC(t *testing.T) {
-	vschemaWrapper := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/tpcc_schema.json", true),
-		sysVarEnabled: true,
+	vschemaWrapper := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/tpcc_schema.json", true),
+		SysVarEnabled: true,
 	}
 
 	testFile(t, "tpcc_cases.json", makeTestOutput(t), vschemaWrapper, false)
 }
 
 func TestTPCH(t *testing.T) {
-	vschemaWrapper := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/tpch_schema.json", true),
-		sysVarEnabled: true,
+	vschemaWrapper := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/tpch_schema.json", true),
+		SysVarEnabled: true,
 	}
 
 	testFile(t, "tpch_cases.json", makeTestOutput(t), vschemaWrapper, false)
@@ -270,9 +267,9 @@ func BenchmarkTPCH(b *testing.B) {
 }
 
 func benchmarkWorkload(b *testing.B, name string) {
-	vschemaWrapper := &vschemaWrapper{
-		v: loadSchema(b, "vschemas/"+name+"_schema.json", true),
-		sysVarEnabled: true,
+	vschemaWrapper := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(b, "vschemas/"+name+"_schema.json", true),
+		SysVarEnabled: true,
 	}
 
 	testCases := readJSONTests(name + "_cases.json")
@@ -285,14 +282,14 @@ func benchmarkWorkload(b *testing.B, name string) {
 }
 
 func TestBypassPlanningShardTargetFromFile(t *testing.T) {
-	vschema := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/schema.json", true),
-		keyspace: &vindexes.Keyspace{
+	vschema := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/schema.json", true),
+		Keyspace: &vindexes.Keyspace{
 			Name: "main",
 			Sharded: false,
 		},
-		tabletType: topodatapb.TabletType_PRIMARY,
-		dest: key.DestinationShard("-80")}
+		TabletType_: topodatapb.TabletType_PRIMARY,
+		Dest: key.DestinationShard("-80")}
 
 	testFile(t, "bypass_shard_cases.json", makeTestOutput(t), vschema, false)
 }
@@ -300,14 +297,14 @@ func TestBypassPlanningShardTargetFromFile(t *testing.T) {
 
 func TestBypassPlanningKeyrangeTargetFromFile(t *testing.T) {
 	keyRange, _ := key.ParseShardingSpec("-")
-	vschema := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/schema.json", true),
-		keyspace: &vindexes.Keyspace{
+	vschema := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/schema.json", true),
+		Keyspace: &vindexes.Keyspace{
 			Name: "main",
 			Sharded: false,
 		},
-		tabletType: topodatapb.TabletType_PRIMARY,
-		dest: key.DestinationExactKeyRange{KeyRange: keyRange[0]},
+		TabletType_: topodatapb.TabletType_PRIMARY,
+		Dest: key.DestinationExactKeyRange{KeyRange: keyRange[0]},
 	}
 
 	testFile(t, "bypass_keyrange_cases.json", makeTestOutput(t), vschema, false)
@@ -315,13 +312,13 @@ func TestBypassPlanningKeyrangeTargetFromFile(t *testing.T) {
 
 func TestWithDefaultKeyspaceFromFile(t *testing.T) {
 	// We are testing this separately so we can set a default keyspace
-	vschema := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/schema.json", true),
-		keyspace: &vindexes.Keyspace{
+	vschema := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/schema.json", true),
+		Keyspace: &vindexes.Keyspace{
 			Name: "main",
 			Sharded: false,
 		},
-		tabletType: topodatapb.TabletType_PRIMARY,
+		TabletType_: topodatapb.TabletType_PRIMARY,
 	}
 	ts := memorytopo.NewServer("cell1")
 	ts.CreateKeyspace(context.Background(), "main", &topodatapb.Keyspace{})
@@ -348,13 +345,13 @@ func TestWithDefaultKeyspaceFromFile(t *testing.T) {
 
 func TestWithDefaultKeyspaceFromFileSharded(t *testing.T) {
 	// We are testing this separately so we can set a default keyspace
-	vschema := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/schema.json", true),
-		keyspace: &vindexes.Keyspace{
+	vschema := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/schema.json", true),
+		Keyspace: &vindexes.Keyspace{
 			Name: "second_user",
 			Sharded: true,
 		},
-		tabletType: topodatapb.TabletType_PRIMARY,
+		TabletType_: topodatapb.TabletType_PRIMARY,
 	}
 
 	testOutputTempDir := makeTestOutput(t)
@@ -363,13 +360,13 @@ func TestWithDefaultKeyspaceFromFileSharded(t *testing.T) {
 
 func TestWithUserDefaultKeyspaceFromFileSharded(t *testing.T) {
 	// We are testing this separately so we can set a default keyspace
-	vschema := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/schema.json", true),
-		keyspace: &vindexes.Keyspace{
+	vschema := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/schema.json", true),
+		Keyspace: &vindexes.Keyspace{
 			Name: "user",
 			Sharded: true,
 		},
-		tabletType: topodatapb.TabletType_PRIMARY,
+		TabletType_: topodatapb.TabletType_PRIMARY,
 	}
 
 	testOutputTempDir := makeTestOutput(t)
@@ -378,10 +375,10 @@ func TestWithUserDefaultKeyspaceFromFileSharded(t *testing.T) {
 
 func TestWithSystemSchemaAsDefaultKeyspace(t *testing.T) {
 	// We are testing this separately so we can set a default keyspace
-	vschema := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/schema.json", true),
-		keyspace: &vindexes.Keyspace{Name: "information_schema"},
-		tabletType: topodatapb.TabletType_PRIMARY,
+	vschema := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/schema.json", true),
+		Keyspace: &vindexes.Keyspace{Name: "information_schema"},
+		TabletType_: topodatapb.TabletType_PRIMARY,
 	}
 
 	testFile(t, "sysschema_default.json", makeTestOutput(t), vschema, false)
@@ -389,13 +386,13 @@ func TestWithSystemSchemaAsDefaultKeyspace(t *testing.T) {
 
 func TestOtherPlanningFromFile(t *testing.T) {
 	// We are testing this separately so we can set a default keyspace
-	vschema := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/schema.json", true),
-		keyspace: &vindexes.Keyspace{
+	vschema := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/schema.json", true),
+		Keyspace: &vindexes.Keyspace{
 			Name: "main",
 			Sharded: false,
 		},
-		tabletType: topodatapb.TabletType_PRIMARY,
+		TabletType_: topodatapb.TabletType_PRIMARY,
 	}
 
 	testOutputTempDir := makeTestOutput(t)
@@ -472,288 +469,6 @@ func createFkDefinition(childCols []string, parentTableName string, parentCols [
 	}
 }
 
-var _ plancontext.VSchema = (*vschemaWrapper)(nil)
-
-type vschemaWrapper struct {
-	v *vindexes.VSchema
-	keyspace *vindexes.Keyspace
-	tabletType topodatapb.TabletType
-	dest key.Destination
-	sysVarEnabled bool
-	version plancontext.PlannerVersion
-	enableViews bool
-}
-
-func (vw *vschemaWrapper) GetPrepareData(stmtName string) *vtgatepb.PrepareData {
-	switch stmtName {
-	case "prep_one_param":
-		return &vtgatepb.PrepareData{
-			PrepareStatement: "select 1 from user where id = :v1",
-			ParamsCount: 1,
-		}
-	case "prep_in_param":
-		return &vtgatepb.PrepareData{
-			PrepareStatement: "select 1 from user where id in (:v1, :v2)",
-			ParamsCount: 2,
-		}
-	case "prep_no_param":
-		return &vtgatepb.PrepareData{
-			PrepareStatement: "select 1 from user",
-			ParamsCount: 0,
-		}
-	}
-	return nil
-}
-
-func (vw *vschemaWrapper) PlanPrepareStatement(ctx context.Context, query string) (*engine.Plan, sqlparser.Statement, error) {
-	plan, err := TestBuilder(query, vw, vw.currentDb())
-	if err != nil {
-		return nil, nil, err
-	}
-	stmt, _, err := sqlparser.Parse2(query)
-	if err != nil {
-		return nil, nil, err
-	}
-	return plan, stmt, nil
-}
-
-func (vw *vschemaWrapper) ClearPrepareData(lowered string) {
-}
-
-func (vw *vschemaWrapper) StorePrepareData(string, *vtgatepb.PrepareData) {}
-
-func (vw *vschemaWrapper) GetUDV(name string) *querypb.BindVariable {
-	if strings.EqualFold(name, "prep_stmt") {
-		return sqltypes.StringBindVariable("select * from user where id in (?, ?, ?)")
-	}
-	return nil
-}
-
-func (vw *vschemaWrapper) IsShardRoutingEnabled() bool {
-	return false
-}
-
-func (vw *vschemaWrapper) GetVSchema() *vindexes.VSchema {
-	return vw.v
-}
-
-func (vw *vschemaWrapper) GetSrvVschema() *vschemapb.SrvVSchema {
-	return &vschemapb.SrvVSchema{
-		Keyspaces: map[string]*vschemapb.Keyspace{
-			"user": {
-				Sharded: true,
-				Vindexes: map[string]*vschemapb.Vindex{},
-				Tables: map[string]*vschemapb.Table{
-					"user": {},
-				},
-			},
-		},
-	}
-}
-
-func (vw *vschemaWrapper) ConnCollation() collations.ID {
-	return collations.Default()
-}
-
-func (vw *vschemaWrapper) PlannerWarning(_ string) {
-}
-
-func (vw *vschemaWrapper) ForeignKeyMode(keyspace string) (vschemapb.Keyspace_ForeignKeyMode, error) {
-	defaultFkMode := vschemapb.Keyspace_FK_UNMANAGED
-	if vw.v.Keyspaces[keyspace] != nil && vw.v.Keyspaces[keyspace].ForeignKeyMode != vschemapb.Keyspace_FK_DEFAULT {
-		return vw.v.Keyspaces[keyspace].ForeignKeyMode, nil
-	}
-	return defaultFkMode, nil
-}
-
-func (vw *vschemaWrapper) AllKeyspace() ([]*vindexes.Keyspace, error) {
-	if vw.keyspace == nil {
-		return nil, vterrors.VT13001("keyspace not available")
-	}
-	return []*vindexes.Keyspace{vw.keyspace}, nil
-}
-
-// FindKeyspace implements the VSchema interface
-func (vw *vschemaWrapper) FindKeyspace(keyspace string) (*vindexes.Keyspace, error) {
-	if vw.keyspace == nil {
-		return nil, vterrors.VT13001("keyspace not available")
-	}
-	if vw.keyspace.Name == keyspace {
-		return vw.keyspace, nil
-	}
-	return nil, nil
-}
-
-func (vw *vschemaWrapper) Planner() plancontext.PlannerVersion {
-	return vw.version
-}
-
-// SetPlannerVersion implements the ContextVSchema interface
-func (vw *vschemaWrapper) SetPlannerVersion(v plancontext.PlannerVersion) {
-	vw.version = v
-}
-
-func (vw *vschemaWrapper) GetSemTable() *semantics.SemTable {
-	return nil
-}
-
-func (vw *vschemaWrapper) KeyspaceExists(keyspace string) bool {
-	if vw.keyspace != nil {
-		return vw.keyspace.Name == keyspace
-	}
-	return false
-}
-
-func (vw *vschemaWrapper) SysVarSetEnabled() bool {
-	return vw.sysVarEnabled
-}
-
-func (vw *vschemaWrapper) TargetDestination(qualifier string) (key.Destination, *vindexes.Keyspace, topodatapb.TabletType, error) {
-	var keyspaceName string
-	if vw.keyspace != nil {
-		keyspaceName = vw.keyspace.Name
-	}
-	if vw.dest == nil && qualifier != "" {
-		keyspaceName = qualifier
-	}
-	if keyspaceName == "" {
-		return nil, nil, 0, vterrors.VT03007()
-	}
-	keyspace := vw.v.Keyspaces[keyspaceName]
-	if keyspace == nil {
-		return nil, nil, 0, vterrors.VT05003(keyspaceName)
-	}
-	return vw.dest, keyspace.Keyspace, vw.tabletType, nil
-
-}
-
-func (vw *vschemaWrapper) TabletType() topodatapb.TabletType {
-	return vw.tabletType
-}
-
-func (vw *vschemaWrapper) Destination() key.Destination {
-	return vw.dest
-}
-
-func (vw *vschemaWrapper) FindTable(tab sqlparser.TableName) (*vindexes.Table, string, topodatapb.TabletType, key.Destination, error) {
-	destKeyspace, destTabletType, destTarget, err := topoproto.ParseDestination(tab.Qualifier.String(), topodatapb.TabletType_PRIMARY)
-	if err != nil {
-		return nil, destKeyspace, destTabletType, destTarget, err
-	}
-	table, err := vw.v.FindTable(destKeyspace, tab.Name.String())
-	if err != nil {
-		return nil, destKeyspace, destTabletType, destTarget, err
-	}
-	return table, destKeyspace, destTabletType, destTarget, nil
-}
-
-func (vw *vschemaWrapper) FindView(tab sqlparser.TableName) sqlparser.SelectStatement {
-	destKeyspace, _, _, err := topoproto.ParseDestination(tab.Qualifier.String(), topodatapb.TabletType_PRIMARY)
-	if err != nil {
-		return nil
-	}
-	return vw.v.FindView(destKeyspace, tab.Name.String())
-}
-
-func (vw *vschemaWrapper) FindTableOrVindex(tab sqlparser.TableName) (*vindexes.Table, vindexes.Vindex, string, topodatapb.TabletType, key.Destination, error) {
-	if tab.Qualifier.IsEmpty() && tab.Name.String() == "dual" {
-		ksName := vw.getActualKeyspace()
-		var ks *vindexes.Keyspace
-		if ksName == "" {
-			ks = vw.getfirstKeyspace()
-			ksName = ks.Name
-		} else {
-			ks = vw.v.Keyspaces[ksName].Keyspace
-		}
-		tbl := &vindexes.Table{
-			Name: sqlparser.NewIdentifierCS("dual"),
-			Keyspace: ks,
-			Type: vindexes.TypeReference,
-		}
-		return tbl, nil, ksName, topodatapb.TabletType_PRIMARY, nil, nil
-	}
-	destKeyspace, destTabletType, destTarget, err := topoproto.ParseDestination(tab.Qualifier.String(), topodatapb.TabletType_PRIMARY)
-	if err != nil {
-		return nil, nil, destKeyspace, destTabletType, destTarget, err
-	}
-	if destKeyspace == "" {
-		destKeyspace = vw.getActualKeyspace()
-	}
-	table, vindex, err := vw.v.FindTableOrVindex(destKeyspace, tab.Name.String(), topodatapb.TabletType_PRIMARY)
-	if err != nil {
-		return nil, nil, destKeyspace, destTabletType, destTarget, err
-	}
-	return table, vindex, destKeyspace, destTabletType, destTarget, nil
-}
-
-func (vw *vschemaWrapper) getfirstKeyspace() (ks *vindexes.Keyspace) {
-	var f string
-	for name, schema := range vw.v.Keyspaces {
-		if f == "" || f > name {
-			f = name
-			ks = schema.Keyspace
-		}
-	}
-	return
-}
-
-func (vw *vschemaWrapper) getActualKeyspace() string {
-	if vw.keyspace == nil {
-		return ""
-	}
-	if !sqlparser.SystemSchema(vw.keyspace.Name) {
-		return vw.keyspace.Name
-	}
-	ks, err := vw.AnyKeyspace()
-	if err != nil {
-		return ""
-	}
-	return ks.Name
-}
-
-func (vw *vschemaWrapper) DefaultKeyspace() (*vindexes.Keyspace, error) {
-	return vw.v.Keyspaces["main"].Keyspace, nil
-}
-
-func (vw *vschemaWrapper) AnyKeyspace() (*vindexes.Keyspace, error) {
-	return vw.DefaultKeyspace()
-}
-
-func (vw *vschemaWrapper) FirstSortedKeyspace() (*vindexes.Keyspace, error) {
-	return vw.v.Keyspaces["main"].Keyspace, nil
-}
-
-func (vw *vschemaWrapper) TargetString() string {
-	return "targetString"
-}
-
-func (vw *vschemaWrapper) WarnUnshardedOnly(_ string, _ ...any) {
-
-}
-
-func (vw *vschemaWrapper) ErrorIfShardedF(keyspace *vindexes.Keyspace, _, errFmt string, params ...any) error {
-	if keyspace.Sharded {
-		return fmt.Errorf(errFmt, params...)
-	}
-	return nil
-}
-
-func (vw *vschemaWrapper) currentDb() string {
-	ksName := ""
-	if vw.keyspace != nil {
-		ksName = vw.keyspace.Name
-	}
-	return ksName
-}
-
-func (vw *vschemaWrapper) FindRoutedShard(keyspace, shard string) (string, error) {
-	return "", nil
-}
-
-func (vw *vschemaWrapper) IsViewsEnabled() bool {
-	return vw.enableViews
-}
-
 type (
 	planTest struct {
 		Comment string `json:"comment,omitempty"`
@@ -762,7 +477,7 @@ type (
 	}
 )
 
-func testFile(t *testing.T, filename, tempDir string, vschema *vschemaWrapper, render bool) {
+func testFile(t *testing.T, filename, tempDir string, vschema *vschemawrapper.VSchemaWrapper, render bool) {
 	opts := jsondiff.DefaultConsoleOptions()
 
 	t.Run(filename, func(t *testing.T) {
@@ -779,7 +494,7 @@ func testFile(t *testing.T, filename, tempDir string, vschema *vschemaWrapper, r
 				Comment: testName,
 				Query: tcase.Query,
 			}
-			vschema.version = Gen4
+			vschema.Version = Gen4
 			out := getPlanOutput(tcase, vschema, render)
 
 			// our expectation for the planner on the query is one of three
@@ -825,13 +540,13 @@ func readJSONTests(filename string) []planTest {
 	return output
 }
 
-func getPlanOutput(tcase planTest, vschema *vschemaWrapper, render bool) (out string) {
+func getPlanOutput(tcase planTest, vschema *vschemawrapper.VSchemaWrapper, render bool) (out string) {
 	defer func() {
 		if r := recover(); r != nil {
 			out = fmt.Sprintf("panicked: %v\n%s", r, string(debug.Stack()))
 		}
 	}()
-	plan, err := TestBuilder(tcase.Query, vschema, vschema.currentDb())
+	plan, err := TestBuilder(tcase.Query, vschema, vschema.CurrentDb())
 	if render && plan != nil {
 		viz, err := engine.GraphViz(plan.Instructions)
 		if err == nil {
@@ -863,9 +578,9 @@ func locateFile(name string) string {
 var benchMarkFiles = []string{"from_cases.json", "filter_cases.json", "large_cases.json", "aggr_cases.json", "select_cases.json", "union_cases.json"}
 
 func BenchmarkPlanner(b *testing.B) {
-	vschema := &vschemaWrapper{
-		v: loadSchema(b, "vschemas/schema.json", true),
-		sysVarEnabled: true,
+	vschema := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(b, "vschemas/schema.json", true),
+		SysVarEnabled: true,
 	}
 	for _, filename := range benchMarkFiles {
 		testCases := readJSONTests(filename)
@@ -879,15 +594,15 @@ func BenchmarkPlanner(b *testing.B) {
 }
 
 func BenchmarkSemAnalysis(b *testing.B) {
-	vschema := &vschemaWrapper{
-		v: loadSchema(b, "vschemas/schema.json", true),
-		sysVarEnabled: true,
+	vschema := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(b, "vschemas/schema.json", true),
+		SysVarEnabled: true,
 	}
 
 	for i := 0; i < b.N; i++ {
 		for _, filename := range benchMarkFiles {
 			for _, tc := range readJSONTests(filename) {
-				exerciseAnalyzer(tc.Query, vschema.currentDb(), vschema)
+				exerciseAnalyzer(tc.Query, vschema.CurrentDb(), vschema)
 			}
 		}
 	}
@@ -912,10 +627,10 @@ func exerciseAnalyzer(query, database string, s semantics.SchemaInformation) {
 }
 
 func BenchmarkSelectVsDML(b *testing.B) {
-	vschema := &vschemaWrapper{
-		v: loadSchema(b, "vschemas/schema.json", true),
-		sysVarEnabled: true,
-		version: Gen4,
+	vschema := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(b, "vschemas/schema.json", true),
+		SysVarEnabled: true,
+		Version: Gen4,
 	}
 
 	dmlCases := readJSONTests("dml_cases.json")
@@ -938,13 +653,13 @@ func BenchmarkSelectVsDML(b *testing.B) {
 	})
 }
 
-func benchmarkPlanner(b *testing.B, version plancontext.PlannerVersion, testCases []planTest, vschema *vschemaWrapper) {
+func benchmarkPlanner(b *testing.B, version plancontext.PlannerVersion, testCases []planTest, vschema *vschemawrapper.VSchemaWrapper) {
 	b.ReportAllocs()
 	for n := 0; n < b.N; n++ {
 		for _, tcase := range testCases {
 			if len(tcase.Plan) > 0 {
-				vschema.version = version
-				_, _ = TestBuilder(tcase.Query, vschema, vschema.currentDb())
+				vschema.Version = version
+				_, _ = TestBuilder(tcase.Query, vschema, vschema.CurrentDb())
 			}
 		}
 	}
diff --git a/go/vt/vtgate/planbuilder/show_test.go b/go/vt/vtgate/planbuilder/show_test.go
index 3caae74bf27..04f8c571764 100644
--- a/go/vt/vtgate/planbuilder/show_test.go
+++ b/go/vt/vtgate/planbuilder/show_test.go
@@ -21,6 +21,8 @@ import (
 	"fmt"
 	"testing"
 
+	"vitess.io/vitess/go/test/vschemawrapper"
+
 	"github.com/stretchr/testify/require"
 
 	"vitess.io/vitess/go/mysql/collations"
@@ -31,8 +33,8 @@ import (
 )
 
 func TestBuildDBPlan(t *testing.T) {
-	vschema := &vschemaWrapper{
-		keyspace: &vindexes.Keyspace{Name: "main"},
+	vschema := &vschemawrapper.VSchemaWrapper{
+		Keyspace: &vindexes.Keyspace{Name: "main"},
 	}
 
 	testCases := []struct {
diff --git a/go/vt/vtgate/planbuilder/simplifier_test.go b/go/vt/vtgate/planbuilder/simplifier_test.go
index 0854934ee7f..1e106adacc0 100644
--- a/go/vt/vtgate/planbuilder/simplifier_test.go
+++ b/go/vt/vtgate/planbuilder/simplifier_test.go
@@ -21,6 +21,8 @@ import (
 	"fmt"
 	"testing"
 
+	"vitess.io/vitess/go/test/vschemawrapper"
+
 	"vitess.io/vitess/go/vt/vterrors"
 
 	"github.com/stretchr/testify/assert"
@@ -40,18 +42,18 @@ func TestSimplifyBuggyQuery(t *testing.T) {
 	query := "select distinct count(distinct a), count(distinct 4) from user left join unsharded on 0 limit 5"
 	// select 0 from unsharded union select 0 from `user` union select 0 from unsharded
 	// select 0 from unsharded union (select 0 from `user` union select 0 from unsharded)
-	vschema := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/schema.json", true),
-		version: Gen4,
+	vschema := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/schema.json", true),
+		Version: Gen4,
 	}
 	stmt, reserved, err := sqlparser.Parse2(query)
 	require.NoError(t, err)
-	rewritten, _ := sqlparser.RewriteAST(sqlparser.CloneStatement(stmt), vschema.currentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil)
+	rewritten, _ := sqlparser.RewriteAST(sqlparser.CloneStatement(stmt), vschema.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil)
 	reservedVars := sqlparser.NewReservedVars("vtg", reserved)
 
 	simplified := simplifier.SimplifyStatement(
 		stmt.(sqlparser.SelectStatement),
-		vschema.currentDb(),
+		vschema.CurrentDb(),
 		vschema,
 		keepSameError(query, reservedVars, vschema, rewritten.BindVarNeeds),
 	)
@@ -62,18 +64,18 @@ func TestSimplifyBuggyQuery(t *testing.T) {
 func TestSimplifyPanic(t *testing.T) {
 	t.Skip("not needed to run")
 	query := "(select id from unsharded union select id from unsharded_auto) union (select id from unsharded_auto union select name from unsharded)"
-	vschema := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/schema.json", true),
-		version: Gen4,
+	vschema := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/schema.json", true),
+		Version: Gen4,
 	}
 	stmt, reserved, err := sqlparser.Parse2(query)
 	require.NoError(t, err)
-	rewritten, _ := sqlparser.RewriteAST(sqlparser.CloneStatement(stmt), vschema.currentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil)
+	rewritten, _ := sqlparser.RewriteAST(sqlparser.CloneStatement(stmt), vschema.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil)
 	reservedVars := sqlparser.NewReservedVars("vtg", reserved)
 
 	simplified := simplifier.SimplifyStatement(
 		stmt.(sqlparser.SelectStatement),
-		vschema.currentDb(),
+		vschema.CurrentDb(),
 		vschema,
 		keepPanicking(query, reservedVars, vschema, rewritten.BindVarNeeds),
 	)
@@ -83,9 +85,9 @@ func TestSimplifyPanic(t *testing.T) {
 
 func TestUnsupportedFile(t *testing.T) {
 	t.Skip("run manually to see if any queries can be simplified")
-	vschema := &vschemaWrapper{
-		v: loadSchema(t, "vschemas/schema.json", true),
-		version: Gen4,
+	vschema := &vschemawrapper.VSchemaWrapper{
+		V: loadSchema(t, "vschemas/schema.json", true),
+		Version: Gen4,
 	}
 	fmt.Println(vschema)
 	for _, tcase := range readJSONTests("unsupported_cases.txt") {
@@ -98,11 +100,11 @@ func TestUnsupportedFile(t *testing.T) {
 				t.Skip()
 				return
 			}
-			rewritten, err := sqlparser.RewriteAST(stmt, vschema.currentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil)
+			rewritten, err := sqlparser.RewriteAST(stmt, vschema.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil)
 			if err != nil {
 				t.Skip()
 			}
-			vschema.currentDb()
+			vschema.CurrentDb()
 
 			reservedVars := sqlparser.NewReservedVars("vtg", reserved)
 			ast := rewritten.AST
@@ -110,7 +112,7 @@ func TestUnsupportedFile(t *testing.T) {
 			stmt, _, _ = sqlparser.Parse2(tcase.Query)
 			simplified := simplifier.SimplifyStatement(
 				stmt.(sqlparser.SelectStatement),
-				vschema.currentDb(),
+				vschema.CurrentDb(),
 				vschema,
 				keepSameError(tcase.Query, reservedVars, vschema, rewritten.BindVarNeeds),
 			)
@@ -127,12 +129,12 @@ func TestUnsupportedFile(t *testing.T) {
 	}
 }
 
-func keepSameError(query string, reservedVars *sqlparser.ReservedVars, vschema *vschemaWrapper, needs *sqlparser.BindVarNeeds) func(statement sqlparser.SelectStatement) bool {
+func keepSameError(query string, reservedVars *sqlparser.ReservedVars, vschema *vschemawrapper.VSchemaWrapper, needs *sqlparser.BindVarNeeds) func(statement sqlparser.SelectStatement) bool {
 	stmt, _, err := sqlparser.Parse2(query)
 	if err != nil {
 		panic(err)
 	}
-	rewritten, _ := sqlparser.RewriteAST(stmt, vschema.currentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil)
+	rewritten, _ := sqlparser.RewriteAST(stmt, vschema.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil)
 	ast := rewritten.AST
 	_, expected := BuildFromStmt(context.Background(), query, ast, reservedVars, vschema, rewritten.BindVarNeeds, true, true)
 	if expected == nil {
@@ -151,7 +153,7 @@ func keepSameError(query string, reservedVars *sqlparser.ReservedVars, vschema *
 	}
 }
 
-func keepPanicking(query string, reservedVars *sqlparser.ReservedVars, vschema *vschemaWrapper, needs *sqlparser.BindVarNeeds) func(statement sqlparser.SelectStatement) bool {
+func keepPanicking(query string, reservedVars *sqlparser.ReservedVars, vschema *vschemawrapper.VSchemaWrapper, needs *sqlparser.BindVarNeeds) func(statement sqlparser.SelectStatement) bool {
 	cmp := func(statement sqlparser.SelectStatement) (res bool) {
 		defer func() {
 			r := recover()
diff --git a/go/vt/vtgate/simplifier/simplifier.go b/go/vt/vtgate/simplifier/simplifier.go
index 971224112f8..0e19935caba 100644
--- a/go/vt/vtgate/simplifier/simplifier.go
+++ b/go/vt/vtgate/simplifier/simplifier.go
@@ -108,6 +108,7 @@ func trySimplifyExpressions(in sqlparser.SelectStatement, test func(sqlparser.Se
 		if test(in) {
 			log.Errorf("removed expression: %s", sqlparser.String(cursor.expr))
 			simplified = true
+			// initially return false, but that made the rewriter prematurely abort, if it was the last selectExpr
 			return true
 		}
 		cursor.restore()
diff --git a/go/vt/vtgate/simplifier/simplifier_test.go b/go/vt/vtgate/simplifier/simplifier_test.go
index 569f5adfab0..c9edbbab8d8 100644
--- a/go/vt/vtgate/simplifier/simplifier_test.go
+++ b/go/vt/vtgate/simplifier/simplifier_test.go
@@ -55,11 +55,11 @@ limit 123 offset 456
 	require.NoError(t, err)
 	visitAllExpressionsInAST(ast.(sqlparser.SelectStatement), func(cursor expressionCursor) bool {
 		fmt.Printf(">> found expression: %s\n", sqlparser.String(cursor.expr))
-		cursor.replace(sqlparser.NewIntLiteral("1"))
+		cursor.remove()
 		fmt.Printf("remove: %s\n", sqlparser.String(ast))
 		cursor.restore()
 		fmt.Printf("restore: %s\n", sqlparser.String(ast))
-		cursor.remove()
+		cursor.replace(sqlparser.NewIntLiteral("1"))
 		fmt.Printf("replace it with literal: %s\n", sqlparser.String(ast))
 		cursor.restore()
 		fmt.Printf("restore: %s\n", sqlparser.String(ast))