diff --git a/pkg/datasource/sql/conn.go b/pkg/datasource/sql/conn.go index 95c265c07..297acf7ce 100644 --- a/pkg/datasource/sql/conn.go +++ b/pkg/datasource/sql/conn.go @@ -37,7 +37,6 @@ type Conn struct { autoCommit bool dbName string dbType types.DBType - superConn *Conn } // ResetSession is called prior to executing a query on the connection diff --git a/pkg/datasource/sql/conn_at.go b/pkg/datasource/sql/conn_at.go index 63c800be7..e444d9774 100644 --- a/pkg/datasource/sql/conn_at.go +++ b/pkg/datasource/sql/conn_at.go @@ -127,7 +127,7 @@ func (c *ATConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, c.txCtx.TxOpt = opts if tm.IsGlobalTx(ctx) { - c.txCtx.XaID = tm.GetXID(ctx) + c.txCtx.XID = tm.GetXID(ctx) c.txCtx.TransType = types.ATMode } @@ -145,7 +145,6 @@ func (c *ATConn) createOnceTxContext(ctx context.Context) bool { if onceTx { c.txCtx = types.NewTxCtx() c.txCtx.DBType = c.res.dbType - c.txCtx.XaID = tm.GetXID(ctx) c.txCtx.XID = tm.GetXID(ctx) c.txCtx.TransType = types.ATMode c.txCtx.GlobalLockRequire = true diff --git a/pkg/datasource/sql/conn_at_test.go b/pkg/datasource/sql/conn_at_test.go index 90b0145d7..d2a9b1d93 100644 --- a/pkg/datasource/sql/conn_at_test.go +++ b/pkg/datasource/sql/conn_at_test.go @@ -85,8 +85,8 @@ func TestATConn_ExecContext(t *testing.T) { t.Logf("set xid=%s", tm.GetXID(ctx)) beforeHook := func(_ context.Context, execCtx *types.ExecContext) { - t.Logf("on exec xid=%s", execCtx.TxCtx.XaID) - assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID) + t.Logf("on exec xid=%s", execCtx.TxCtx.XID) + assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XID) assert.Equal(t, types.ATMode, execCtx.TxCtx.TransType) } mi.before = beforeHook @@ -111,7 +111,7 @@ func TestATConn_ExecContext(t *testing.T) { t.Run("not xid", func(t *testing.T) { mi.before = func(_ context.Context, execCtx *types.ExecContext) { - assert.Equal(t, "", execCtx.TxCtx.XaID) + assert.Equal(t, "", execCtx.TxCtx.XID) assert.Equal(t, types.Local, execCtx.TxCtx.TransType) } @@ -148,7 +148,7 @@ func TestATConn_BeginTx(t *testing.T) { assert.NoError(t, err) mi.before = func(_ context.Context, execCtx *types.ExecContext) { - assert.Equal(t, "", execCtx.TxCtx.XaID) + assert.Equal(t, "", execCtx.TxCtx.XID) assert.Equal(t, types.Local, execCtx.TxCtx.TransType) } @@ -174,7 +174,7 @@ func TestATConn_BeginTx(t *testing.T) { assert.NoError(t, err) mi.before = func(_ context.Context, execCtx *types.ExecContext) { - assert.Equal(t, "", execCtx.TxCtx.XaID) + assert.Equal(t, "", execCtx.TxCtx.XID) assert.Equal(t, types.Local, execCtx.TxCtx.TransType) } @@ -202,7 +202,7 @@ func TestATConn_BeginTx(t *testing.T) { assert.NoError(t, err) mi.before = func(_ context.Context, execCtx *types.ExecContext) { - assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID) + assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XID) assert.Equal(t, types.ATMode, execCtx.TxCtx.TransType) } diff --git a/pkg/datasource/sql/conn_xa.go b/pkg/datasource/sql/conn_xa.go index 017c36b32..aeaed975e 100644 --- a/pkg/datasource/sql/conn_xa.go +++ b/pkg/datasource/sql/conn_xa.go @@ -74,7 +74,7 @@ func (c *XAConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, if tm.IsGlobalTx(ctx) { c.txCtx.TransType = types.XAMode - c.txCtx.XaID = tm.GetXID(ctx) + c.txCtx.XID = tm.GetXID(ctx) } tx, err := c.Conn.BeginTx(ctx, opts) @@ -91,7 +91,7 @@ func (c *XAConn) createOnceTxContext(ctx context.Context) bool { if onceTx { c.txCtx = types.NewTxCtx() c.txCtx.DBType = c.res.dbType - c.txCtx.XaID = tm.GetXID(ctx) + c.txCtx.XID = tm.GetXID(ctx) c.txCtx.TransType = types.XAMode } diff --git a/pkg/datasource/sql/conn_xa_test.go b/pkg/datasource/sql/conn_xa_test.go index f2d0085ba..4977eec05 100644 --- a/pkg/datasource/sql/conn_xa_test.go +++ b/pkg/datasource/sql/conn_xa_test.go @@ -136,8 +136,8 @@ func TestXAConn_ExecContext(t *testing.T) { t.Logf("set xid=%s", tm.GetXID(ctx)) before := func(_ context.Context, execCtx *types.ExecContext) { - t.Logf("on exec xid=%s", execCtx.TxCtx.XaID) - assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID) + t.Logf("on exec xid=%s", execCtx.TxCtx.XID) + assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XID) assert.Equal(t, types.XAMode, execCtx.TxCtx.TransType) } mi.before = before @@ -163,7 +163,7 @@ func TestXAConn_ExecContext(t *testing.T) { t.Run("not xid", func(t *testing.T) { before := func(_ context.Context, execCtx *types.ExecContext) { - assert.Equal(t, "", execCtx.TxCtx.XaID) + assert.Equal(t, "", execCtx.TxCtx.XID) assert.Equal(t, types.Local, execCtx.TxCtx.TransType) } mi.before = before @@ -202,7 +202,7 @@ func TestXAConn_BeginTx(t *testing.T) { assert.NoError(t, err) mi.before = func(_ context.Context, execCtx *types.ExecContext) { - assert.Equal(t, "", execCtx.TxCtx.XaID) + assert.Equal(t, "", execCtx.TxCtx.XID) assert.Equal(t, types.Local, execCtx.TxCtx.TransType) } @@ -228,7 +228,7 @@ func TestXAConn_BeginTx(t *testing.T) { assert.NoError(t, err) mi.before = func(_ context.Context, execCtx *types.ExecContext) { - assert.Equal(t, "", execCtx.TxCtx.XaID) + assert.Equal(t, "", execCtx.TxCtx.XID) assert.Equal(t, types.Local, execCtx.TxCtx.TransType) } @@ -256,7 +256,7 @@ func TestXAConn_BeginTx(t *testing.T) { assert.NoError(t, err) mi.before = func(_ context.Context, execCtx *types.ExecContext) { - assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID) + assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XID) assert.Equal(t, types.XAMode, execCtx.TxCtx.TransType) } diff --git a/pkg/datasource/sql/hook/logger_hook.go b/pkg/datasource/sql/hook/logger_hook.go index 2cc30d317..e020d3677 100644 --- a/pkg/datasource/sql/hook/logger_hook.go +++ b/pkg/datasource/sql/hook/logger_hook.go @@ -44,7 +44,7 @@ func (h *loggerSQLHook) Before(ctx context.Context, execCtx *types.ExecContext) } fields := []zap.Field{ zap.String("tx-id", txID), - zap.String("xid", execCtx.TxCtx.XaID), + zap.String("xid", execCtx.TxCtx.XID), zap.String("sql", execCtx.Query), } diff --git a/pkg/datasource/sql/tx.go b/pkg/datasource/sql/tx.go index 9de4b27be..0b01ff066 100644 --- a/pkg/datasource/sql/tx.go +++ b/pkg/datasource/sql/tx.go @@ -154,7 +154,7 @@ func (tx *Tx) register(ctx *types.TransactionContext) error { lockKey += k + ";" } request := rm.BranchRegisterParam{ - Xid: ctx.XaID, + Xid: ctx.XID, BranchType: ctx.TransType.GetBranchType(), ResourceId: ctx.ResourceID, LockKeys: lockKey, @@ -176,7 +176,7 @@ func (tx *Tx) report(success bool) error { } status := getStatus(success) request := message.BranchReportRequest{ - Xid: tx.tranCtx.XaID, + Xid: tx.tranCtx.XID, BranchId: int64(tx.tranCtx.BranchID), ResourceId: tx.tranCtx.ResourceID, Status: status, @@ -187,7 +187,7 @@ func (tx *Tx) report(success bool) error { err := dataSourceManager.BranchReport(context.Background(), request) if err != nil { retry-- - log.Infof("Failed to report [%s / %s] commit done [%s] Retry Countdown: %s", tx.tranCtx.BranchID, tx.tranCtx.XaID, success, retry) + log.Infof("Failed to report [%s / %s] commit done [%s] Retry Countdown: %s", tx.tranCtx.BranchID, tx.tranCtx.XID, success, retry) if retry == 0 { log.Infof("Failed to report branch status: %s", err.Error()) return err diff --git a/pkg/datasource/sql/types/image.go b/pkg/datasource/sql/types/image.go index ef9782a6c..b275849d1 100644 --- a/pkg/datasource/sql/types/image.go +++ b/pkg/datasource/sql/types/image.go @@ -97,7 +97,7 @@ type RecordImage struct { // Rows data row Rows []RowImage `json:"rows"` // TableMeta table information schema - TableMeta TableMeta + TableMeta TableMeta `json:"-"` } // RowImage Mirror data information information diff --git a/pkg/datasource/sql/types/meta.go b/pkg/datasource/sql/types/meta.go index 96aeac356..c98eeab85 100644 --- a/pkg/datasource/sql/types/meta.go +++ b/pkg/datasource/sql/types/meta.go @@ -95,6 +95,18 @@ func (m TableMeta) IsEmpty() bool { return m.TableName == "" } +func (m TableMeta) GetPrimaryKeyMap() map[string]ColumnMeta { + pk := make(map[string]ColumnMeta) + for _, index := range m.Indexs { + if index.IType == IndexTypePrimaryKey { + for _, column := range index.Columns { + pk[column.ColumnName] = column + } + } + } + return pk +} + func (m TableMeta) GetPrimaryKeyOnlyName() []string { keys := make([]string, 0) for _, index := range m.Indexs { @@ -104,6 +116,5 @@ func (m TableMeta) GetPrimaryKeyOnlyName() []string { } } } - return keys } diff --git a/pkg/datasource/sql/types/sql.go b/pkg/datasource/sql/types/sql.go index fba21eb4a..40f60d996 100644 --- a/pkg/datasource/sql/types/sql.go +++ b/pkg/datasource/sql/types/sql.go @@ -54,3 +54,69 @@ const ( SQLTypeDropIndex SQLTypeMulti ) + +func (s SQLType) MarshalText() (text []byte, err error) { + switch s { + case SQLTypeSelect: + return []byte("SELECT"), nil + case SQLTypeInsert: + return []byte("INSERT"), nil + case SQLTypeUpdate: + return []byte("UPDATE"), nil + case SQLTypeDelete: + return []byte("DELETE"), nil + case SQLTypeSelectForUpdate: + return []byte("SELECT_FOR_UPDATE"), nil + case SQLTypeReplace: + return []byte("REPLACE"), nil + case SQLTypeTruncate: + return []byte("TRUNCATE"), nil + case SQLTypeCreate: + return []byte("CREATE"), nil + case SQLTypeDrop: + return []byte("DROP"), nil + case SQLTypeLoad: + return []byte("LOAD"), nil + case SQLTypeMerge: + return []byte("MERGE"), nil + case SQLTypeShow: + return []byte("SHOW"), nil + case SQLTypeAlter: + return []byte("ALTER"), nil + case SQLTypeRename: + return []byte("RENAME"), nil + case SQLTypeDump: + return []byte("DUMP"), nil + case SQLTypeDebug: + return []byte("DEBUG"), nil + case SQLTypeExplain: + return []byte("EXPLAIN"), nil + case SQLTypeDesc: + return []byte("DESC"), nil + case SQLTypeSet: + return []byte("SET"), nil + case SQLTypeReload: + return []byte("RELOAD"), nil + case SQLTypeSelectUnion: + return []byte("SELECT_UNION"), nil + case SQLTypeCreateTable: + return []byte("CREATE_TABLE"), nil + case SQLTypeDropTable: + return []byte("DROP_TABLE"), nil + case SQLTypeAlterTable: + return []byte("ALTER_TABLE"), nil + case SQLTypeSelectFromUpdate: + return []byte("SELECT_FROM_UPDATE"), nil + case SQLTypeMultiDelete: + return []byte("MULTI_DELETE"), nil + case SQLTypeMultiUpdate: + return []byte("MULTI_UPDATE"), nil + case SQLTypeCreateIndex: + return []byte("CREATE_INDEX"), nil + case SQLTypeDropIndex: + return []byte("DROP_INDEX"), nil + case SQLTypeMulti: + return []byte("MULTI"), nil + } + return []byte("INVALID_SQLTYPE"), nil +} diff --git a/pkg/datasource/sql/types/types.go b/pkg/datasource/sql/types/types.go index 4ff0a38ac..d8bce89ea 100644 --- a/pkg/datasource/sql/types/types.go +++ b/pkg/datasource/sql/types/types.go @@ -19,6 +19,7 @@ package types import ( "database/sql/driver" + "fmt" "strings" "github.com/seata/seata-go/pkg/protocol/branch" @@ -42,6 +43,27 @@ const ( IndexTypePrimaryKey IndexType = 1 ) +func (i IndexType) MarshalText() (text []byte, err error) { + switch i { + case IndexTypePrimaryKey: + return []byte("PRIMARY_KEY"), nil + } + return []byte("NULL"), nil +} + +func (i *IndexType) UnmarshalText(text []byte) error { + switch string(text) { + case "PRIMARY_KEY": + *i = IndexTypePrimaryKey + return nil + case "NULL": + *i = IndexTypeNull + return nil + default: + return fmt.Errorf("invalid index type") + } +} + const ( _ DBType = iota DBTypeUnknown @@ -109,8 +131,6 @@ type TransactionContext struct { ResourceID string // BranchID transaction branch unique id BranchID uint64 - // XaID XA id - XaID string // todo delete // XID global transaction id XID string // GlobalLockRequire diff --git a/pkg/datasource/sql/undo/base/undo.go b/pkg/datasource/sql/undo/base/undo.go index 54f8d714a..aee41c008 100644 --- a/pkg/datasource/sql/undo/base/undo.go +++ b/pkg/datasource/sql/undo/base/undo.go @@ -43,7 +43,7 @@ var ( ErrorDeleteUndoLogParamsFault = errors.New("xid or branch_id can't nil") ) -const ( +var ( checkUndoLogTableExistSql = "SELECT 1 FROM " + constant.UndoLogTableName + " LIMIT 1" insertUndoLogSql = "INSERT INTO " + constant.UndoLogTableName + "(branch_id,xid,context,rollback_info,log_status,log_created,log_modified) VALUES (?, ?, ?, ?, ?, now(6), now(6))" ) @@ -89,7 +89,7 @@ func (m *BaseUndoLogManager) InsertUndoLog(record undo.UndologRecord, conn drive if err != nil { return err } - _, err = stmt.Exec([]driver.Value{record.BranchID, record.XID, record.Context, record.RollbackInfo, record.LogStatus}) + _, err = stmt.Exec([]driver.Value{record.BranchID, record.XID, record.Context, record.RollbackInfo, int64(record.LogStatus)}) if err != nil { return err } @@ -187,7 +187,7 @@ func (m *BaseUndoLogManager) FlushUndoLog(tranCtx *types.TransactionContext, con } // use defalut encode - undoLogContent, err := json.Marshal(branchUndoLog) + rollbackInfo, err := json.Marshal(branchUndoLog) if err != nil { return err } @@ -195,7 +195,7 @@ func (m *BaseUndoLogManager) FlushUndoLog(tranCtx *types.TransactionContext, con parseContext := make(map[string]string, 0) parseContext[SerializerKey] = "jackson" parseContext[CompressorTypeKey] = "NONE" - rollbackInfo, err := json.Marshal(parseContext) + undoLogContent, err := json.Marshal(parseContext) if err != nil { return err } diff --git a/pkg/datasource/sql/undo/builder/basic_undo_log_builder.go b/pkg/datasource/sql/undo/builder/basic_undo_log_builder.go index 2da248e2b..dcb19a88b 100644 --- a/pkg/datasource/sql/undo/builder/basic_undo_log_builder.go +++ b/pkg/datasource/sql/undo/builder/basic_undo_log_builder.go @@ -174,8 +174,8 @@ func (b *BasicUndoLogBuilder) buildRecordImages(rowsi driver.Rows, tableMetaData columnMeta := tableMetaData.Columns[name] keyType := types.IndexTypeNull - if data, ok := tableMetaData.Indexs[name]; ok { - keyType = data.IType + if _, ok := tableMetaData.GetPrimaryKeyMap()[name]; ok { + keyType = types.IndexTypePrimaryKey } jdbcType := types.GetJDBCTypeByTypeName(columnMeta.ColumnTypeInfo.DatabaseTypeName()) diff --git a/pkg/datasource/sql/undo/builder/mysql_multi_update_undo_log_builder.go b/pkg/datasource/sql/undo/builder/mysql_multi_update_undo_log_builder.go index 8e81b65fc..7583fa4da 100644 --- a/pkg/datasource/sql/undo/builder/mysql_multi_update_undo_log_builder.go +++ b/pkg/datasource/sql/undo/builder/mysql_multi_update_undo_log_builder.go @@ -100,6 +100,8 @@ func (u *MySQLMultiUpdateUndoLogBuilder) BeforeImage(ctx context.Context, execCt return nil, err } + image.SQLType = execCtx.ParseContext.SQLType + return []*types.RecordImage{image}, nil } @@ -130,6 +132,7 @@ func (u *MySQLMultiUpdateUndoLogBuilder) AfterImage(ctx context.Context, execCtx return nil, err } + image.SQLType = execCtx.ParseContext.SQLType return []*types.RecordImage{image}, nil } diff --git a/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder.go b/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder.go index e570c53b9..ba0312991 100644 --- a/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder.go +++ b/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder.go @@ -96,6 +96,7 @@ func (u *MySQLUpdateUndoLogBuilder) BeforeImage(ctx context.Context, execCtx *ty lockKey := u.buildLockKey2(image, *metaData) execCtx.TxCtx.LockKeys[lockKey] = struct{}{} + image.SQLType = execCtx.ParseContext.SQLType return []*types.RecordImage{image}, nil } @@ -136,6 +137,7 @@ func (u *MySQLUpdateUndoLogBuilder) AfterImage(ctx context.Context, execCtx *typ lockKey := u.buildLockKey(rows, *metaData) execCtx.TxCtx.LockKeys[lockKey] = struct{}{} + image.SQLType = execCtx.ParseContext.SQLType return []*types.RecordImage{image}, nil } diff --git a/pkg/datasource/sql/undo/mysql/default.go b/pkg/datasource/sql/undo/mysql/default.go index 039df7c6c..8de1cadb4 100644 --- a/pkg/datasource/sql/undo/mysql/default.go +++ b/pkg/datasource/sql/undo/mysql/default.go @@ -20,10 +20,11 @@ package mysql import ( "github.com/pkg/errors" "github.com/seata/seata-go/pkg/datasource/sql/undo" + "github.com/seata/seata-go/pkg/datasource/sql/undo/base" ) func init() { - if err := undo.RegisterUndoLogManager(&undoLogManager{}); err != nil { + if err := undo.RegisterUndoLogManager(&undoLogManager{Base: base.NewBaseUndoLogManager()}); err != nil { panic(errors.WithStack(err)) } } diff --git a/pkg/datasource/sql/undo/undo.go b/pkg/datasource/sql/undo/undo.go index abe445a2b..633e19416 100644 --- a/pkg/datasource/sql/undo/undo.go +++ b/pkg/datasource/sql/undo/undo.go @@ -142,11 +142,10 @@ func (b *BranchUndoLog) Reverse() { // SQLUndoLog type SQLUndoLog struct { - SQLType types.SQLType - TableName string - Images types.RoundRecordImage - BeforeImage *types.RecordImage - AfterImage *types.RecordImage + SQLType types.SQLType `json:"sqlType"` + TableName string `json:"tableName"` + BeforeImage *types.RecordImage `json:"beforeImage"` + AfterImage *types.RecordImage `json:"afterImage"` } func (s SQLUndoLog) SetTableMeta(tableMeta types.TableMeta) {