diff --git a/dm/pkg/schema/tracker.go b/dm/pkg/schema/tracker.go index 1d2e0052bdb..1d0640cc6f9 100644 --- a/dm/pkg/schema/tracker.go +++ b/dm/pkg/schema/tracker.go @@ -370,6 +370,22 @@ func (tr *Tracker) CreateTableIfNotExists(table *filter.Table, ti *model.TableIn return tr.dom.DDL().CreateTableWithInfo(tr.se, schemaName, ti, ddl.OnExistIgnore) } +func (tr *Tracker) BatchCreateTableIfNotExist(tablesToCreate map[string]map[string]*model.TableInfo) error { + for schema, tableNameInfo := range tablesToCreate { + var cloneTis []*model.TableInfo + for table, ti := range tableNameInfo { + cloneTi := cloneTableInfo(ti) // clone TableInfo w.r.t the warning of the CreateTable function + cloneTi.Name = model.NewCIStr(table) // TableInfo has no `TableName` + cloneTis = append(cloneTis, cloneTi) + } + schemaName := model.NewCIStr(schema) + if err := tr.dom.DDL().BatchCreateTableWithInfo(tr.se, schemaName, cloneTis, ddl.OnExistIgnore); err != nil { + return err + } + } + return nil +} + // GetSystemVar gets a variable from schema tracker. func (tr *Tracker) GetSystemVar(name string) (string, bool) { return tr.se.GetSessionVars().GetSystemVar(name) diff --git a/dm/pkg/schema/tracker_test.go b/dm/pkg/schema/tracker_test.go index 5e6d9f12e97..b67ac50483c 100644 --- a/dm/pkg/schema/tracker_test.go +++ b/dm/pkg/schema/tracker_test.go @@ -445,6 +445,118 @@ func (s *trackerSuite) TestCreateTableIfNotExists(c *C) { c.Assert(duration.Seconds(), Less, float64(30)) } +func (s *trackerSuite) TestBatchCreateTableIfNotExist(c *C) { + log.SetLevel(zapcore.ErrorLevel) + tracker, err := NewTracker(context.Background(), "test-tracker", defaultTestSessionCfg, s.dbConn) + c.Assert(err, IsNil) + err = tracker.CreateSchemaIfNotExists("testdb") + c.Assert(err, IsNil) + err = tracker.CreateSchemaIfNotExists("testdb2") + c.Assert(err, IsNil) + + tables := []*filter.Table{ + { + Schema: "testdb", + Name: "foo", + }, + { + Schema: "testdb", + Name: "foo1", + }, + { + Schema: "testdb2", + Name: "foo3", + }, + } + execStmt := []string{ + `create table foo( + a int primary key auto_increment, + b int as (c+1) not null, + c int comment 'some cmt', + d text, + key dk(d(255)) + ) comment 'more cmt' partition by range columns (a) ( + partition x41 values less than (41), + partition x82 values less than (82), + partition rest values less than maxvalue comment 'part cmt' + );`, + `create table foo1( + a int primary key, + b text not null, + d datetime, + e varchar(5) + );`, + `create table foo3( + a int, + b int, + primary key(a));`, + } + tiInfos := make([]*model.TableInfo, len(tables)) + for i := range tables { + ctx := context.Background() + err = tracker.Exec(ctx, tables[i].Schema, execStmt[i]) + c.Assert(err, IsNil) + tiInfos[i], err = tracker.GetTableInfo(tables[i]) + c.Assert(err, IsNil) + c.Assert(tiInfos[i], NotNil) + c.Assert(tiInfos[i].Name.O, Equals, tables[i].Name) + tiInfos[i] = tiInfos[i].Clone() + clearVolatileInfo(tiInfos[i]) + } + // drop all tables and recover + // 1. drop + for i := range tables { + err = tracker.DropTable(tables[i]) + c.Assert(err, IsNil) + _, err = tracker.GetTableInfo(tables[i]) + c.Assert(err, ErrorMatches, `.*Table 'testdb.*\.foo.*' doesn't exist`) // drop table success + } + // 2. recover + tablesToCreate := map[string]map[string]*model.TableInfo{} + tablesToCreate["testdb"] = map[string]*model.TableInfo{} + tablesToCreate["testdb2"] = map[string]*model.TableInfo{} + for i := range tables { + tablesToCreate[tables[i].Schema][tables[i].Name] = tiInfos[i] + } + err = tracker.BatchCreateTableIfNotExist(tablesToCreate) + c.Assert(err, IsNil) + // 3. check all create success + for i := range tables { + var ti *model.TableInfo + ti, err = tracker.GetTableInfo(tables[i]) + c.Assert(err, IsNil) + cloneTi := ti.Clone() + clearVolatileInfo(cloneTi) + c.Assert(cloneTi, DeepEquals, tiInfos[i]) + } + + // drop two tables and create all three + // expect: silently succeed + // 1. drop table + err = tracker.DropTable(tables[2]) + c.Assert(err, IsNil) + err = tracker.DropTable(tables[0]) + c.Assert(err, IsNil) + // 2. batch create + err = tracker.BatchCreateTableIfNotExist(tablesToCreate) + c.Assert(err, IsNil) + // 3. check + for i := range tables { + var ti *model.TableInfo + ti, err = tracker.GetTableInfo(tables[i]) + c.Assert(err, IsNil) + clearVolatileInfo(ti) + c.Assert(ti, DeepEquals, tiInfos[i]) + } + + // drop schema and raise error + ctx := context.Background() + err = tracker.Exec(ctx, "", `drop database testdb`) + c.Assert(err, IsNil) + err = tracker.BatchCreateTableIfNotExist(tablesToCreate) + c.Assert(err, NotNil) +} + func (s *trackerSuite) TestAllSchemas(c *C) { log.SetLevel(zapcore.ErrorLevel) ctx := context.Background() diff --git a/dm/syncer/checkpoint.go b/dm/syncer/checkpoint.go index 49d83a344b2..0873d607c27 100644 --- a/dm/syncer/checkpoint.go +++ b/dm/syncer/checkpoint.go @@ -844,6 +844,7 @@ func (cp *RemoteCheckPoint) Rollback(schemaTracker *schema.Tracker) { cp.RLock() defer cp.RUnlock() cp.globalPoint.rollback(schemaTracker, "") + tablesToCreate := make(map[string]map[string]*model.TableInfo) for schemaName, mSchema := range cp.points { for tableName, point := range mSchema { table := &filter.Table{ @@ -864,13 +865,18 @@ func (cp *RemoteCheckPoint) Rollback(schemaTracker *schema.Tracker) { if err := schemaTracker.CreateSchemaIfNotExists(schemaName); err != nil { logger.Error("failed to rollback schema on schema tracker: cannot create schema", log.ShortError(err)) } - if err := schemaTracker.CreateTableIfNotExists(table, point.savedPoint.ti); err != nil { - logger.Error("failed to rollback schema on schema tracker: cannot create table", log.ShortError(err)) + if _, ok := tablesToCreate[schemaName]; !ok { + tablesToCreate[schemaName] = map[string]*model.TableInfo{} } + tablesToCreate[schemaName][tableName] = point.savedPoint.ti } } } } + logger := cp.logCtx.L().WithFields(zap.Reflect("batch create table", tablesToCreate)) + if err := schemaTracker.BatchCreateTableIfNotExist(tablesToCreate); err != nil { + logger.Error("failed to rollback schema on schema tracker: cannot create table", log.ShortError(err)) + } // drop any tables in the tracker if no corresponding checkpoint exists. for _, schema := range schemaTracker.AllSchemas() {