diff --git a/go/vt/vttablet/onlineddl/analysis.go b/go/vt/vttablet/onlineddl/analysis.go index 970104877f2..1d96ab232dd 100644 --- a/go/vt/vttablet/onlineddl/analysis.go +++ b/go/vt/vttablet/onlineddl/analysis.go @@ -68,23 +68,6 @@ func (p *SpecialAlterPlan) String() string { return string(b) } -// getCreateTableStatement gets a formal AlterTable representation of the given table -func (e *Executor) getCreateTableStatement(ctx context.Context, tableName string) (*sqlparser.CreateTable, error) { - showCreateTable, err := e.showCreateTable(ctx, tableName) - if err != nil { - return nil, vterrors.Wrapf(err, "in Executor.getCreateTableStatement()") - } - stmt, err := e.env.Environment().Parser().ParseStrictDDL(showCreateTable) - if err != nil { - return nil, err - } - createTable, ok := stmt.(*sqlparser.CreateTable) - if !ok { - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "expected CREATE TABLE. Got %v", sqlparser.CanonicalString(stmt)) - } - return createTable, nil -} - // analyzeInstantDDL takes declarative CreateTable and AlterTable, as well as a server version, and checks whether it is possible to run the ALTER // using ALGORITHM=INSTANT for that version. func analyzeInstantDDL(alterTable *sqlparser.AlterTable, createTable *sqlparser.CreateTable, capableOf capabilities.CapableOf) (*SpecialAlterPlan, error) { diff --git a/go/vt/vttablet/onlineddl/executor.go b/go/vt/vttablet/onlineddl/executor.go index 5c1e718b873..00998c453e1 100644 --- a/go/vt/vttablet/onlineddl/executor.go +++ b/go/vt/vttablet/onlineddl/executor.go @@ -604,6 +604,23 @@ func (e *Executor) showCreateTable(ctx context.Context, tableName string) (strin return row[1].ToString(), nil } +// getCreateTableStatement gets a formal AlterTable representation of the given table +func (e *Executor) getCreateTableStatement(ctx context.Context, tableName string) (*sqlparser.CreateTable, error) { + showCreateTable, err := e.showCreateTable(ctx, tableName) + if err != nil { + return nil, vterrors.Wrapf(err, "in Executor.getCreateTableStatement()") + } + stmt, err := e.env.Environment().Parser().ParseStrictDDL(showCreateTable) + if err != nil { + return nil, err + } + createTable, ok := stmt.(*sqlparser.CreateTable) + if !ok { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "expected CREATE TABLE. Got %v", sqlparser.CanonicalString(stmt)) + } + return createTable, nil +} + func (e *Executor) parseAlterOptions(ctx context.Context, onlineDDL *schema.OnlineDDL) string { // Temporary hack (2020-08-11) // Because sqlparser does not do full blown ALTER TABLE parsing, @@ -1387,20 +1404,11 @@ func (e *Executor) validateAndEditAlterTableStatement(capableOf capabilities.Cap // - The format CreateTable AST // - A new CreateTable AST, with the table renamed as `newTableName`, and with constraints renamed deterministically // - Map of renamed constraints -func (e *Executor) duplicateCreateTable(ctx context.Context, onlineDDL *schema.OnlineDDL, originalShowCreateTable string, newTableName string) ( - originalCreateTable *sqlparser.CreateTable, +func (e *Executor) duplicateCreateTable(ctx context.Context, onlineDDL *schema.OnlineDDL, originalCreateTable *sqlparser.CreateTable, newTableName string) ( newCreateTable *sqlparser.CreateTable, constraintMap map[string]string, err error, ) { - stmt, err := e.env.Environment().Parser().ParseStrictDDL(originalShowCreateTable) - if err != nil { - return nil, nil, nil, err - } - originalCreateTable, ok := stmt.(*sqlparser.CreateTable) - if !ok { - return nil, nil, nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "expected CreateTable statement, got: %v", sqlparser.CanonicalString(stmt)) - } newCreateTable = sqlparser.Clone(originalCreateTable) newCreateTable.SetTable(newCreateTable.GetTable().Qualifier.CompliantName(), newTableName) @@ -1426,32 +1434,32 @@ func (e *Executor) duplicateCreateTable(ctx context.Context, onlineDDL *schema.O // unique across the schema constraintMap, err = e.validateAndEditCreateTableStatement(onlineDDL, newCreateTable) if err != nil { - return nil, nil, nil, err + return nil, nil, err } - return originalCreateTable, newCreateTable, constraintMap, nil + return newCreateTable, constraintMap, nil } // createDuplicateTableLike creates the table named by `newTableName` in the likeness of onlineDDL.Table // This function emulates MySQL's `CREATE TABLE LIKE ...` statement. The difference is that this function takes control over the generated CONSTRAINT names, // if any, such that they are deterministic across shards, as well as preserve original names where possible. func (e *Executor) createDuplicateTableLike(ctx context.Context, newTableName string, onlineDDL *schema.OnlineDDL, conn *dbconnpool.DBConnection) ( - originalShowCreateTable string, + originalCreateTable *sqlparser.CreateTable, constraintMap map[string]string, err error, ) { - originalShowCreateTable, err = e.showCreateTable(ctx, onlineDDL.Table) + originalCreateTable, err = e.getCreateTableStatement(ctx, onlineDDL.Table) if err != nil { - return "", nil, err + return nil, nil, err } - _, vreplCreateTable, constraintMap, err := e.duplicateCreateTable(ctx, onlineDDL, originalShowCreateTable, newTableName) + vreplCreateTable, constraintMap, err := e.duplicateCreateTable(ctx, onlineDDL, originalCreateTable, newTableName) if err != nil { - return "", nil, err + return nil, nil, err } // Create the vrepl (shadow) table: if _, err := conn.ExecuteFetch(sqlparser.CanonicalString(vreplCreateTable), 0, false); err != nil { - return "", nil, err + return nil, nil, err } - return originalShowCreateTable, constraintMap, nil + return originalCreateTable, constraintMap, nil } // initVreplicationOriginalMigration performs the first steps towards running a VRepl ALTER migration: @@ -1476,7 +1484,7 @@ func (e *Executor) initVreplicationOriginalMigration(ctx context.Context, online if err := e.updateArtifacts(ctx, onlineDDL.UUID, vreplTableName); err != nil { return v, err } - originalShowCreateTable, constraintMap, err := e.createDuplicateTableLike(ctx, vreplTableName, onlineDDL, conn) + originalCreateTable, constraintMap, err := e.createDuplicateTableLike(ctx, vreplTableName, onlineDDL, conn) if err != nil { return nil, err } @@ -1505,12 +1513,12 @@ func (e *Executor) initVreplicationOriginalMigration(ctx context.Context, online } } - vreplShowCreateTable, err := e.showCreateTable(ctx, vreplTableName) + vreplCreateTable, err := e.getCreateTableStatement(ctx, vreplTableName) if err != nil { return v, err } - v = NewVRepl(e.env.Environment(), onlineDDL.UUID, e.keyspace, e.shard, e.dbName, onlineDDL.Table, vreplTableName, originalShowCreateTable, vreplShowCreateTable, onlineDDL.SQL, onlineDDL.StrategySetting().IsAnalyzeTableFlag()) + v = NewVRepl(e.env.Environment(), onlineDDL.UUID, e.keyspace, e.shard, e.dbName, onlineDDL.Table, vreplTableName, originalCreateTable, vreplCreateTable, alterTable, onlineDDL.StrategySetting().IsAnalyzeTableFlag()) return v, nil } @@ -1564,7 +1572,7 @@ func (e *Executor) initVreplicationRevertMigration(ctx context.Context, onlineDD if err := e.updateArtifacts(ctx, onlineDDL.UUID, vreplTableName); err != nil { return v, err } - v = NewVRepl(e.env.Environment(), onlineDDL.UUID, e.keyspace, e.shard, e.dbName, onlineDDL.Table, vreplTableName, "", "", "", false) + v = NewVRepl(e.env.Environment(), onlineDDL.UUID, e.keyspace, e.shard, e.dbName, onlineDDL.Table, vreplTableName, nil, nil, nil, false) v.pos = revertStream.pos return v, nil } diff --git a/go/vt/vttablet/onlineddl/executor_test.go b/go/vt/vttablet/onlineddl/executor_test.go index 81c8f4cb0f0..1dc5447bbb9 100644 --- a/go/vt/vttablet/onlineddl/executor_test.go +++ b/go/vt/vttablet/onlineddl/executor_test.go @@ -391,9 +391,13 @@ func TestDuplicateCreateTable(t *testing.T) { } for _, tcase := range tcases { t.Run(tcase.sql, func(t *testing.T) { - originalCreateTable, newCreateTable, constraintMap, err := e.duplicateCreateTable(ctx, onlineDDL, tcase.sql, tcase.newName) + stmt, err := e.env.Environment().Parser().ParseStrictDDL(tcase.sql) + require.NoError(t, err) + originalCreateTable, ok := stmt.(*sqlparser.CreateTable) + require.True(t, ok) + require.NotNil(t, originalCreateTable) + newCreateTable, constraintMap, err := e.duplicateCreateTable(ctx, onlineDDL, originalCreateTable, tcase.newName) assert.NoError(t, err) - assert.NotNil(t, originalCreateTable) assert.NotNil(t, newCreateTable) assert.NotNil(t, constraintMap) diff --git a/go/vt/vttablet/onlineddl/vrepl.go b/go/vt/vttablet/onlineddl/vrepl.go index 847e40e3fbc..cde2f276563 100644 --- a/go/vt/vttablet/onlineddl/vrepl.go +++ b/go/vt/vttablet/onlineddl/vrepl.go @@ -105,11 +105,11 @@ type VRepl struct { sourceTable string targetTable string pos string - alterQuery string + alterQuery *sqlparser.AlterTable tableRows int64 - originalShowCreateTable string - vreplShowCreateTable string + originalCreateTable *sqlparser.CreateTable + vreplCreateTable *sqlparser.CreateTable analyzeTable bool @@ -150,27 +150,27 @@ func NewVRepl( dbName string, sourceTable string, targetTable string, - originalShowCreateTable string, - vreplShowCreateTable string, - alterQuery string, + originalCreateTable *sqlparser.CreateTable, + vreplCreateTable *sqlparser.CreateTable, + alterQuery *sqlparser.AlterTable, analyzeTable bool, ) *VRepl { return &VRepl{ - env: env, - workflow: workflow, - keyspace: keyspace, - shard: shard, - dbName: dbName, - sourceTable: sourceTable, - targetTable: targetTable, - originalShowCreateTable: originalShowCreateTable, - vreplShowCreateTable: vreplShowCreateTable, - alterQuery: alterQuery, - analyzeTable: analyzeTable, - parser: vrepl.NewAlterTableParser(), - enumToTextMap: map[string]string{}, - intToEnumMap: map[string]bool{}, - convertCharset: map[string](*binlogdatapb.CharsetConversion){}, + env: env, + workflow: workflow, + keyspace: keyspace, + shard: shard, + dbName: dbName, + sourceTable: sourceTable, + targetTable: targetTable, + originalCreateTable: originalCreateTable, + vreplCreateTable: vreplCreateTable, + alterQuery: alterQuery, + analyzeTable: analyzeTable, + parser: vrepl.NewAlterTableParser(), + enumToTextMap: map[string]string{}, + intToEnumMap: map[string]bool{}, + convertCharset: map[string](*binlogdatapb.CharsetConversion){}, } } @@ -386,15 +386,13 @@ func (v *VRepl) applyColumnTypes(ctx context.Context, conn *dbconnpool.DBConnect } func (v *VRepl) analyzeAlter(ctx context.Context) error { - if v.alterQuery == "" { + if v.alterQuery == nil { // Happens for REVERT return nil } - if err := v.parser.ParseAlterStatement(v.alterQuery, v.env.Parser()); err != nil { - return err - } + v.parser.AnalyzeAlter(v.alterQuery) if v.parser.IsRenameTable() { - return fmt.Errorf("Renaming the table is not aupported in ALTER TABLE: %s", v.alterQuery) + return fmt.Errorf("Renaming the table is not supported in ALTER TABLE: %s", sqlparser.CanonicalString(v.alterQuery)) } return nil } @@ -461,7 +459,7 @@ func (v *VRepl) analyzeTables(ctx context.Context, conn *dbconnpool.DBConnection } v.addedUniqueKeys = vrepl.AddedUniqueKeys(sourceUniqueKeys, targetUniqueKeys, v.parser.ColumnRenameMap()) v.removedUniqueKeys = vrepl.RemovedUniqueKeys(sourceUniqueKeys, targetUniqueKeys, v.parser.ColumnRenameMap()) - v.removedForeignKeyNames, err = vrepl.RemovedForeignKeyNames(v.env, v.originalShowCreateTable, v.vreplShowCreateTable) + v.removedForeignKeyNames, err = vrepl.RemovedForeignKeyNames(v.env, v.originalCreateTable, v.vreplCreateTable) if err != nil { return err } diff --git a/go/vt/vttablet/onlineddl/vrepl/foreign_key.go b/go/vt/vttablet/onlineddl/vrepl/foreign_key.go index 79e2df614f4..006beb7345c 100644 --- a/go/vt/vttablet/onlineddl/vrepl/foreign_key.go +++ b/go/vt/vttablet/onlineddl/vrepl/foreign_key.go @@ -29,17 +29,17 @@ import ( // RemovedForeignKeyNames returns the names of removed foreign keys, ignoring mere name changes func RemovedForeignKeyNames( venv *vtenv.Environment, - originalCreateTable string, - vreplCreateTable string, + originalCreateTable *sqlparser.CreateTable, + vreplCreateTable *sqlparser.CreateTable, ) (names []string, err error) { - if originalCreateTable == "" || vreplCreateTable == "" { + if originalCreateTable == nil || vreplCreateTable == nil { return nil, nil } env := schemadiff.NewEnv(venv, venv.CollationEnv().DefaultConnectionCharset()) diffHints := schemadiff.DiffHints{ ConstraintNamesStrategy: schemadiff.ConstraintNamesIgnoreAll, } - diff, err := schemadiff.DiffCreateTablesQueries(env, originalCreateTable, vreplCreateTable, &diffHints) + diff, err := schemadiff.DiffTables(env, originalCreateTable, vreplCreateTable, &diffHints) if err != nil { return nil, err } diff --git a/go/vt/vttablet/onlineddl/vrepl/foreign_key_test.go b/go/vt/vttablet/onlineddl/vrepl/foreign_key_test.go index 95b2c84e66e..66775092dcb 100644 --- a/go/vt/vttablet/onlineddl/vrepl/foreign_key_test.go +++ b/go/vt/vttablet/onlineddl/vrepl/foreign_key_test.go @@ -24,7 +24,9 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtenv" ) @@ -68,7 +70,20 @@ func TestRemovedForeignKeyNames(t *testing.T) { } for _, tcase := range tcases { t.Run(tcase.before, func(t *testing.T) { - names, err := RemovedForeignKeyNames(vtenv.NewTestEnv(), tcase.before, tcase.after) + env := vtenv.NewTestEnv() + beforeStmt, err := env.Parser().ParseStrictDDL(tcase.before) + require.NoError(t, err) + beforeCreateTable, ok := beforeStmt.(*sqlparser.CreateTable) + require.True(t, ok) + require.NotNil(t, beforeCreateTable) + + afterStmt, err := env.Parser().ParseStrictDDL(tcase.after) + require.NoError(t, err) + afterCreateTable, ok := afterStmt.(*sqlparser.CreateTable) + require.True(t, ok) + require.NotNil(t, afterCreateTable) + + names, err := RemovedForeignKeyNames(env, beforeCreateTable, afterCreateTable) assert.NoError(t, err) assert.Equal(t, tcase.names, names) }) diff --git a/go/vt/vttablet/onlineddl/vrepl/parser.go b/go/vt/vttablet/onlineddl/vrepl/parser.go index b5648adeabe..f76f8735016 100644 --- a/go/vt/vttablet/onlineddl/vrepl/parser.go +++ b/go/vt/vttablet/onlineddl/vrepl/parser.go @@ -23,9 +23,7 @@ package vrepl import ( "strings" - "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" - "vitess.io/vitess/go/vt/vterrors" ) // AlterTableParser is a parser tool for ALTER TABLE statements @@ -48,13 +46,13 @@ func NewAlterTableParser() *AlterTableParser { // NewParserFromAlterStatement creates a new parser with a ALTER TABLE statement func NewParserFromAlterStatement(alterTable *sqlparser.AlterTable) *AlterTableParser { parser := NewAlterTableParser() - parser.analyzeAlter(alterTable) + parser.AnalyzeAlter(alterTable) return parser } -// analyzeAlter looks for specific changes in the AlterTable statement, that are relevant +// AnalyzeAlter looks for specific changes in the AlterTable statement, that are relevant // to OnlineDDL/VReplication -func (p *AlterTableParser) analyzeAlter(alterTable *sqlparser.AlterTable) { +func (p *AlterTableParser) AnalyzeAlter(alterTable *sqlparser.AlterTable) { for _, opt := range alterTable.AlterOptions { switch opt := opt.(type) { case *sqlparser.RenameTableName: @@ -77,20 +75,6 @@ func (p *AlterTableParser) analyzeAlter(alterTable *sqlparser.AlterTable) { } } -// ParseAlterStatement is the main function of th eparser, and parses an ALTER TABLE statement -func (p *AlterTableParser) ParseAlterStatement(alterQuery string, parser *sqlparser.Parser) (err error) { - stmt, err := parser.ParseStrictDDL(alterQuery) - if err != nil { - return err - } - alterTable, ok := stmt.(*sqlparser.AlterTable) - if !ok { - return vterrors.Errorf(vtrpc.Code_FAILED_PRECONDITION, "expected AlterTable statement, got %v", sqlparser.CanonicalString(stmt)) - } - p.analyzeAlter(alterTable) - return nil -} - // GetNonTrivialRenames gets a list of renamed column func (p *AlterTableParser) GetNonTrivialRenames() map[string]string { result := make(map[string]string) diff --git a/go/vt/vttablet/onlineddl/vrepl/parser_test.go b/go/vt/vttablet/onlineddl/vrepl/parser_test.go index 2a7031f3a98..93e2ef25a15 100644 --- a/go/vt/vttablet/onlineddl/vrepl/parser_test.go +++ b/go/vt/vttablet/onlineddl/vrepl/parser_test.go @@ -24,24 +24,33 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/vt/sqlparser" ) +func alterTableStatement(t *testing.T, sql string) *sqlparser.AlterTable { + stmt, err := sqlparser.NewTestParser().ParseStrictDDL(sql) + require.NoError(t, err) + alter, ok := stmt.(*sqlparser.AlterTable) + require.True(t, ok) + return alter +} + func TestParseAlterStatement(t *testing.T) { statement := "alter table t add column t int, engine=innodb" + alterStatement := alterTableStatement(t, statement) parser := NewAlterTableParser() - err := parser.ParseAlterStatement(statement, sqlparser.NewTestParser()) - assert.NoError(t, err) + parser.AnalyzeAlter(alterStatement) assert.False(t, parser.HasNonTrivialRenames()) assert.False(t, parser.IsAutoIncrementDefined()) } func TestParseAlterStatementTrivialRename(t *testing.T) { statement := "alter table t add column t int, change ts ts timestamp, engine=innodb" + alterStatement := alterTableStatement(t, statement) parser := NewAlterTableParser() - err := parser.ParseAlterStatement(statement, sqlparser.NewTestParser()) - assert.NoError(t, err) + parser.AnalyzeAlter(alterStatement) assert.False(t, parser.HasNonTrivialRenames()) assert.False(t, parser.IsAutoIncrementDefined()) assert.Equal(t, len(parser.columnRenameMap), 1) @@ -68,17 +77,17 @@ func TestParseAlterStatementWithAutoIncrement(t *testing.T) { for _, statement := range statements { parser := NewAlterTableParser() statement := "alter table t " + statement - err := parser.ParseAlterStatement(statement, sqlparser.NewTestParser()) - assert.NoError(t, err) + alterStatement := alterTableStatement(t, statement) + parser.AnalyzeAlter(alterStatement) assert.True(t, parser.IsAutoIncrementDefined()) } } func TestParseAlterStatementTrivialRenames(t *testing.T) { statement := "alter table t add column t int, change ts ts timestamp, CHANGE f `f` float, engine=innodb" + alterStatement := alterTableStatement(t, statement) parser := NewAlterTableParser() - err := parser.ParseAlterStatement(statement, sqlparser.NewTestParser()) - assert.NoError(t, err) + parser.AnalyzeAlter(alterStatement) assert.False(t, parser.HasNonTrivialRenames()) assert.False(t, parser.IsAutoIncrementDefined()) assert.Equal(t, len(parser.columnRenameMap), 2) @@ -99,9 +108,9 @@ func TestParseAlterStatementNonTrivial(t *testing.T) { for _, statement := range statements { statement := "alter table t " + statement + alterStatement := alterTableStatement(t, statement) parser := NewAlterTableParser() - err := parser.ParseAlterStatement(statement, sqlparser.NewTestParser()) - assert.NoError(t, err) + parser.AnalyzeAlter(alterStatement) assert.False(t, parser.IsAutoIncrementDefined()) renames := parser.GetNonTrivialRenames() assert.Equal(t, len(renames), 2) @@ -115,16 +124,16 @@ func TestParseAlterStatementDroppedColumns(t *testing.T) { { parser := NewAlterTableParser() statement := "alter table t drop column b" - err := parser.ParseAlterStatement(statement, sqlparser.NewTestParser()) - assert.NoError(t, err) + alterStatement := alterTableStatement(t, statement) + parser.AnalyzeAlter(alterStatement) assert.Equal(t, len(parser.droppedColumns), 1) assert.True(t, parser.droppedColumns["b"]) } { parser := NewAlterTableParser() statement := "alter table t drop column b, drop key c_idx, drop column `d`" - err := parser.ParseAlterStatement(statement, sqlparser.NewTestParser()) - assert.NoError(t, err) + alterStatement := alterTableStatement(t, statement) + parser.AnalyzeAlter(alterStatement) assert.Equal(t, len(parser.droppedColumns), 2) assert.True(t, parser.droppedColumns["b"]) assert.True(t, parser.droppedColumns["d"]) @@ -132,19 +141,13 @@ func TestParseAlterStatementDroppedColumns(t *testing.T) { { parser := NewAlterTableParser() statement := "alter table t drop column b, drop key c_idx, drop column `d`, drop `e`, drop primary key, drop foreign key fk_1" - err := parser.ParseAlterStatement(statement, sqlparser.NewTestParser()) - assert.NoError(t, err) + alterStatement := alterTableStatement(t, statement) + parser.AnalyzeAlter(alterStatement) assert.Equal(t, len(parser.droppedColumns), 3) assert.True(t, parser.droppedColumns["b"]) assert.True(t, parser.droppedColumns["d"]) assert.True(t, parser.droppedColumns["e"]) } - { - parser := NewAlterTableParser() - statement := "alter table t drop column b, drop bad statement, add column i int" - err := parser.ParseAlterStatement(statement, sqlparser.NewTestParser()) - assert.Error(t, err) - } } func TestParseAlterStatementRenameTable(t *testing.T) { @@ -179,8 +182,8 @@ func TestParseAlterStatementRenameTable(t *testing.T) { for _, tc := range tt { t.Run(tc.alter, func(t *testing.T) { parser := NewAlterTableParser() - err := parser.ParseAlterStatement(tc.alter, sqlparser.NewTestParser()) - assert.NoError(t, err) + alterStatement := alterTableStatement(t, tc.alter) + parser.AnalyzeAlter(alterStatement) assert.Equal(t, tc.isRename, parser.isRenameTable) }) }