diff --git a/go.mod b/go.mod index 02200c851..df1989e79 100644 --- a/go.mod +++ b/go.mod @@ -75,6 +75,7 @@ require ( github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/golang/snappy v0.0.4 // indirect + github.com/google/go-cmp v0.5.9 // indirect github.com/gorilla/websocket v1.4.2 // indirect github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect @@ -140,6 +141,7 @@ require ( gopkg.in/ini.v1 v1.62.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + gotest.tools v2.2.0+incompatible moul.io/http2curl v1.0.0 // indirect vimagination.zapto.org/memio v0.0.0-20200222190306-588ebc67b97d // indirect ) diff --git a/go.sum b/go.sum index d45cb0fca..cc28fd98f 100644 --- a/go.sum +++ b/go.sum @@ -1421,6 +1421,7 @@ gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/pkg/datasource/sql/async_worker.go b/pkg/datasource/sql/async_worker.go index 7702a29da..6d1d4a186 100644 --- a/pkg/datasource/sql/async_worker.go +++ b/pkg/datasource/sql/async_worker.go @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package sql import ( diff --git a/pkg/datasource/sql/datasource/base/meta_cache.go b/pkg/datasource/sql/datasource/base/meta_cache.go index b5ae04e42..bc0eeb1b5 100644 --- a/pkg/datasource/sql/datasource/base/meta_cache.go +++ b/pkg/datasource/sql/datasource/base/meta_cache.go @@ -19,6 +19,7 @@ package base import ( "context" + "database/sql" "errors" "sync" "time" @@ -29,7 +30,7 @@ import ( type ( // trigger trigger interface { - LoadOne(table string) (types.TableMeta, error) + LoadOne(ctx context.Context, dbName string, table string, conn *sql.Conn) (*types.TableMeta, error) LoadAll() ([]types.TableMeta, error) } @@ -46,13 +47,14 @@ type BaseTableMetaCache struct { expireDuration time.Duration capity int32 size int32 + dbName string cache map[string]*entry cancel context.CancelFunc trigger trigger } // NewBaseCache -func NewBaseCache(capity int32, expireDuration time.Duration, trigger trigger) (*BaseTableMetaCache, error) { +func NewBaseCache(capity int32, dbName string, expireDuration time.Duration, trigger trigger) *BaseTableMetaCache { ctx, cancel := context.WithCancel(context.Background()) c := &BaseTableMetaCache{ @@ -61,15 +63,14 @@ func NewBaseCache(capity int32, expireDuration time.Duration, trigger trigger) ( size: 0, expireDuration: expireDuration, cache: map[string]*entry{}, + dbName: dbName, cancel: cancel, trigger: trigger, } - if err := c.Init(ctx); err != nil { - return nil, err - } + c.Init(ctx) - return c, nil + return c } // init @@ -135,32 +136,31 @@ func (c *BaseTableMetaCache) scanExpire(ctx context.Context) { } // GetTableMeta -func (c *BaseTableMetaCache) GetTableMeta(table string) (types.TableMeta, error) { +func (c *BaseTableMetaCache) GetTableMeta(ctx context.Context, tableName string, conn *sql.Conn) (types.TableMeta, error) { c.lock.Lock() defer c.lock.Unlock() - v, ok := c.cache[table] - + v, ok := c.cache[tableName] if !ok { - meta, err := c.trigger.LoadOne(table) + meta, err := c.trigger.LoadOne(ctx, c.dbName, tableName, conn) if err != nil { return types.TableMeta{}, err } - if !meta.IsEmpty() { - c.cache[table] = &entry{ - value: meta, + if meta != nil && !meta.IsEmpty() { + c.cache[tableName] = &entry{ + value: *meta, lastAccess: time.Now(), } - return meta, nil + return *meta, nil } return types.TableMeta{}, errors.New("not found table metadata") } v.lastAccess = time.Now() - c.cache[table] = v + c.cache[tableName] = v return v.value, nil } diff --git a/pkg/datasource/sql/datasource/datasource_manager.go b/pkg/datasource/sql/datasource/datasource_manager.go index 9dc42fd9d..739964137 100644 --- a/pkg/datasource/sql/datasource/datasource_manager.go +++ b/pkg/datasource/sql/datasource/datasource_manager.go @@ -176,7 +176,7 @@ type TableMetaCache interface { // Init Init(ctx context.Context, conn *sql.DB) error // GetTableMeta - GetTableMeta(table string) (types.TableMeta, error) + GetTableMeta(ctx context.Context, table string, conn *sql.Conn) (*types.TableMeta, error) // Destroy Destroy() error } diff --git a/pkg/datasource/sql/datasource/mysql/meta_cache.go b/pkg/datasource/sql/datasource/mysql/meta_cache.go index 48df6f03a..e3e517ec5 100644 --- a/pkg/datasource/sql/datasource/mysql/meta_cache.go +++ b/pkg/datasource/sql/datasource/mysql/meta_cache.go @@ -20,13 +20,36 @@ package mysql import ( "context" "database/sql" + "sync" + "time" + + "github.com/pkg/errors" "github.com/seata/seata-go/pkg/datasource/sql/datasource/base" "github.com/seata/seata-go/pkg/datasource/sql/types" ) +var ( + capacity int32 = 1024 + EexpireTime = 15 * time.Minute + tableMetaInstance *tableMetaCache + tableMetaOnce sync.Once + DBName = "seata" +) + type tableMetaCache struct { - cache *base.BaseTableMetaCache + tableMetaCache *base.BaseTableMetaCache +} + +func GetTableMetaInstance() *tableMetaCache { + // Todo constant.DBName get from config + tableMetaOnce.Do(func() { + tableMetaInstance = &tableMetaCache{ + tableMetaCache: base.NewBaseCache(capacity, DBName, EexpireTime, NewMysqlTrigger()), + } + }) + + return tableMetaInstance } // Init @@ -34,9 +57,18 @@ func (c *tableMetaCache) Init(ctx context.Context, conn *sql.DB) error { return nil } -// GetTableMeta -func (c *tableMetaCache) GetTableMeta(table string) (types.TableMeta, error) { - return types.TableMeta{}, nil +// GetTableMeta get table info from cache or information schema +func (c *tableMetaCache) GetTableMeta(ctx context.Context, tableName string, conn *sql.Conn) (*types.TableMeta, error) { + if tableName == "" { + return nil, errors.New("TableMeta cannot be fetched without tableName") + } + + tableMeta, err := c.tableMetaCache.GetTableMeta(ctx, tableName, conn) + if err != nil { + return nil, err + } + + return &tableMeta, nil } // Destroy diff --git a/pkg/datasource/sql/datasource/mysql/meta_cache_test.go b/pkg/datasource/sql/datasource/mysql/meta_cache_test.go new file mode 100644 index 000000000..df0c46417 --- /dev/null +++ b/pkg/datasource/sql/datasource/mysql/meta_cache_test.go @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mysql + +import ( + "context" + "database/sql" + _ "github.com/go-sql-driver/mysql" + "gotest.tools/assert" + "testing" +) + +// TestGetTableMeta +func TestGetTableMeta(t *testing.T) { + // local test can annotation t.SkipNow() + t.SkipNow() + + testTableMeta := func() { + metaInstance := GetTableMetaInstance() + + db, err := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3306)/seata?multiStatements=true") + if err != nil { + t.Fatal(err) + } + + defer db.Close() + + ctx := context.Background() + conn, _ := db.Conn(ctx) + + tableMeta, err := metaInstance.GetTableMeta(ctx, "undo_log", conn) + assert.NilError(t, err) + + t.Logf("%+v", tableMeta) + } + + t.Run("testTableMeta", func(t *testing.T) { + testTableMeta() + }) +} diff --git a/pkg/datasource/sql/datasource/mysql/trigger.go b/pkg/datasource/sql/datasource/mysql/trigger.go new file mode 100644 index 000000000..538839766 --- /dev/null +++ b/pkg/datasource/sql/datasource/mysql/trigger.go @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mysql + +import ( + "context" + "database/sql" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/datasource/sql/undo/executor" +) + +type mysqlTrigger struct { +} + +func NewMysqlTrigger() *mysqlTrigger { + return &mysqlTrigger{} +} + +// LoadOne get table meta column and index +func (m *mysqlTrigger) LoadOne(ctx context.Context, dbName string, tableName string, conn *sql.Conn) (*types.TableMeta, error) { + tableMeta := types.TableMeta{ + Name: tableName, + Columns: make(map[string]types.ColumnMeta), + Indexs: make(map[string]types.IndexMeta), + } + + colMetas, err := m.getColumns(ctx, dbName, tableName, conn) + if err != nil { + return nil, errors.Wrapf(err, "Could not found any column in the table: %s", tableName) + } + + var columns []string + for _, column := range colMetas { + tableMeta.Columns[column.ColumnName] = column + columns = append(columns, column.ColumnName) + } + tableMeta.ColumnNames = columns + + indexes, err := m.getIndexes(ctx, dbName, tableName, conn) + if err != nil { + return nil, errors.Wrapf(err, "Could not found any index in the table: %s", tableName) + } + for _, index := range indexes { + col := tableMeta.Columns[index.ColumnName] + idx, ok := tableMeta.Indexs[index.Name] + if ok { + idx.Values = append(idx.Values, col) + } else { + index.Values = append(index.Values, col) + tableMeta.Indexs[index.Name] = index + } + } + if len(tableMeta.Indexs) == 0 { + return nil, errors.Errorf("Could not found any index in the table: %s", tableName) + } + + return &tableMeta, nil +} + +// LoadAll +func (m *mysqlTrigger) LoadAll() ([]types.TableMeta, error) { + return []types.TableMeta{}, nil +} + +// getColumns get tableMeta column +func (m *mysqlTrigger) getColumns(ctx context.Context, dbName string, table string, conn *sql.Conn) ([]types.ColumnMeta, error) { + table = executor.DelEscape(table, types.DBTypeMySQL) + + var result []types.ColumnMeta + + columnSchemaSql := "select TABLE_CATALOG, TABLE_NAME, TABLE_SCHEMA, COLUMN_NAME, DATA_TYPE, COLUMN_TYPE, COLUMN_KEY, " + + " IS_NULLABLE, EXTRA from INFORMATION_SCHEMA.COLUMNS where `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + + stmt, err := conn.PrepareContext(ctx, columnSchemaSql) + if err != nil { + return nil, err + } + + rows, err := stmt.QueryContext(ctx, dbName, table) + if err != nil { + return nil, err + } + + for rows.Next() { + var ( + tableCatalog string + tableName string + tableSchema string + columnName string + dataType string + columnType string + columnKey string + isNullable string + extra string + ) + + col := types.ColumnMeta{} + + if err = rows.Scan( + &tableCatalog, + &tableName, + &tableSchema, + &columnName, + &dataType, + &columnType, + &columnKey, + &isNullable, + &extra); err != nil { + return nil, err + } + + col.Schema = tableSchema + col.Table = tableName + col.ColumnName = strings.Trim(columnName, "` ") + col.DataType = types.GetSqlDataType(dataType) + col.ColumnType = columnType + col.ColumnKey = columnKey + if strings.ToLower(isNullable) == "yes" { + col.IsNullable = 1 + } else { + col.IsNullable = 0 + } + col.Extra = extra + col.Autoincrement = strings.Contains(strings.ToLower(extra), "auto_increment") + + result = append(result, col) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + if err = rows.Close(); err != nil { + return nil, err + } + + if len(result) == 0 { + return nil, errors.New("can't find column") + } + + return result, nil +} + +// getIndex get tableMetaIndex +func (m *mysqlTrigger) getIndexes(ctx context.Context, dbName string, tableName string, conn *sql.Conn) ([]types.IndexMeta, error) { + tableName = executor.DelEscape(tableName, types.DBTypeMySQL) + + result := make([]types.IndexMeta, 0) + + indexSchemaSql := "SELECT `INDEX_NAME`, `COLUMN_NAME`, `NON_UNIQUE`, `INDEX_TYPE`, `COLLATION`, `CARDINALITY` " + + "FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + + stmt, err := conn.PrepareContext(ctx, indexSchemaSql) + if err != nil { + return nil, err + } + + rows, err := stmt.QueryContext(ctx, dbName, tableName) + if err != nil { + return nil, err + } + + defer rows.Close() + + for rows.Next() { + var ( + indexName, columnName, nonUnique, indexType, collation string + cardinality int + ) + + if err = rows.Scan( + &indexName, + &columnName, + &nonUnique, + &indexType, + &collation, + &cardinality); err != nil { + return nil, err + } + + index := types.IndexMeta{ + Schema: dbName, + Table: tableName, + Name: indexName, + ColumnName: columnName, + Values: make([]types.ColumnMeta, 0), + } + + if nonUnique == "1" || "yes" == strings.ToLower(nonUnique) { + index.NonUnique = true + } + + if "primary" == strings.ToLower(indexName) { + index.IType = types.IndexPrimary + } else if !index.NonUnique { + index.IType = types.IndexUnique + } else { + index.IType = types.IndexNormal + } + + result = append(result, index) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return result, nil +} diff --git a/pkg/datasource/sql/types/meta.go b/pkg/datasource/sql/types/meta.go index 8ad8b2048..05081ff1e 100644 --- a/pkg/datasource/sql/types/meta.go +++ b/pkg/datasource/sql/types/meta.go @@ -17,9 +17,7 @@ package types -import ( - "database/sql" -) +import "database/sql" // ColumnMeta type ColumnMeta struct { @@ -27,10 +25,16 @@ type ColumnMeta struct { Schema string // Table Table string - // Autoincrement - Autoincrement bool // Info Info sql.ColumnType + // Autoincrement + Autoincrement bool + ColumnName string + ColumnType string + DataType int32 + ColumnKey string + IsNullable int8 + Extra string } // IndexMeta @@ -38,9 +42,10 @@ type IndexMeta struct { // Schema Schema string // Table - Table string - // Name - Name string + Table string + Name string + ColumnName string + NonUnique bool // IType IType IndexType // Values @@ -56,7 +61,8 @@ type TableMeta struct { // Columns Columns map[string]ColumnMeta // Indexs - Indexs map[string]IndexMeta + Indexs map[string]IndexMeta + ColumnNames []string } func (m TableMeta) IsEmpty() bool { diff --git a/pkg/datasource/sql/types/mysql_ketword_checker.go b/pkg/datasource/sql/types/mysql_ketword_checker.go new file mode 100644 index 000000000..3dd614cf2 --- /dev/null +++ b/pkg/datasource/sql/types/mysql_ketword_checker.go @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package types + +var MysqlKeyWord map[string]string + +func GetMysqlKeyWord() map[string]string { + // lazy init + if MysqlKeyWord == nil { + MysqlKeyWord = map[string]string{ + "ACCESSIBLE": "ACCESSIBLE", + "ADD": "ADD", + "ALL": "ALL", + "ALTER": "ALTER", + "ANALYZE": "ANALYZE", + "AND": "AND", + "ARRAY": "ARRAY", + "AS": "AS", + "ASC": "ASC", + "ASENSITIVE": "ASENSITIVE", + "BEFORE": "BEFORE", + "BETWEEN": "BETWEEN", + "BIGINT": "BIGINT", + "BINARY": "BINARY", + "BLOB": "BLOB", + "BOTH": "BOTH", + "BY": "BY", + "CALL": "CALL", + "CASCADE": "CASCADE", + "CASE": "CASE", + "CHANGE": "CHANGE", + "CHAR": "CHAR", + "CHARACTER": "CHARACTER", + "CHECK": "CHECK", + "COLLATE": "COLLATE", + "COLUMN": "COLUMN", + "CONDITION": "CONDITION", + "CONSTRAINT": "CONSTRAINT", + "CONTINUE": "CONTINUE", + "CONVERT": "CONVERT", + "CREATE": "CREATE", + "CROSS": "CROSS", + "CUBE": "CUBE", + "CUME_DIST": "CUME_DIST", + "CURRENT_DATE": "CURRENT_DATE", + "CURRENT_TIME": "CURRENT_TIME", + "CURRENT_TIMESTAMP": "CURRENT_TIMESTAMP", + "CURRENT_USER": "CURRENT_USER", + "CURSOR": "CURSOR", + "DATABASE": "DATABASE", + "DATABASES": "DATABASES", + "DAY_HOUR": "DAY_HOUR", + "DAY_MICROSECOND": "DAY_MICROSECOND", + "DAY_MINUTE": "DAY_MINUTE", + "DAY_SECOND": "DAY_SECOND", + "DEC": "DEC", + "DECIMAL": "DECIMAL", + "DECLARE": "DECLARE", + "DEFAULT": "DEFAULT", + "DELAYED": "DELAYED", + "DELETE": "DELETE", + "DENSE_RANK": "DENSE_RANK", + "DESC": "DESC", + "DESCRIBE": "DESCRIBE", + "DETERMINISTIC": "DETERMINISTIC", + "DISTINCT": "DISTINCT", + "DISTINCTROW": "DISTINCTROW", + "DIV": "DIV", + "DOUBLE": "DOUBLE", + "DROP": "DROP", + "DUAL": "DUAL", + "EACH": "EACH", + "ELSE": "ELSE", + "ELSEIF": "ELSEIF", + "EMPTY": "EMPTY", + "ENCLOSED": "ENCLOSED", + "ESCAPED": "ESCAPED", + "EXCEPT": "EXCEPT", + "EXISTS": "EXISTS", + "EXIT": "EXIT", + "EXPLAIN": "EXPLAIN", + "FALSE": "FALSE", + "FETCH": "FETCH", + "FIRST_VALUE": "FIRST_VALUE", + "FLOAT": "FLOAT", + "FLOAT4": "FLOAT4", + "FLOAT8": "FLOAT8", + "FOR": "FOR", + "FORCE": "FORCE", + "FOREIGN": "FOREIGN", + "FROM": "FROM", + "FULLTEXT": "FULLTEXT", + "FUNCTION": "FUNCTION", + "GENERATED": "GENERATED", + "GET": "GET", + "GRANT": "GRANT", + "GROUP": "GROUP", + "GROUPING": "GROUPING", + "GROUPS": "GROUPS", + "HAVING": "HAVING", + "HIGH_PRIORITY": "HIGH_PRIORITY", + "HOUR_MICROSECOND": "HOUR_MICROSECOND", + "HOUR_MINUTE": "HOUR_MINUTE", + "HOUR_SECOND": "HOUR_SECOND", + "IF": "IF", + "IGNORE": "IGNORE", + "IN": "IN", + "INDEX": "INDEX", + "INFILE": "INFILE", + "INNER": "INNER", + "INOUT": "INOUT", + "INSENSITIVE": "INSENSITIVE", + "INSERT": "INSERT", + "INT": "INT", + "INT1": "INT1", + "INT2": "INT2", + "INT3": "INT3", + "INT4": "INT4", + "INT8": "INT8", + "INTEGER": "INTEGER", + "INTERVAL": "INTERVAL", + "INTO": "INTO", + "IO_AFTER_GTIDS": "IO_AFTER_GTIDS", + "IO_BEFORE_GTIDS": "IO_BEFORE_GTIDS", + "IS": "IS", + "ITERATE": "ITERATE", + "JOIN": "JOIN", + "JSON_TABLE": "JSON_TABLE", + "KEY": "KEY", + "KEYS": "KEYS", + "KILL": "KILL", + "LAG": "LAG", + "LAST_VALUE": "LAST_VALUE", + "LATERAL": "LATERAL", + "LEAD": "LEAD", + "LEADING": "LEADING", + "LEAVE": "LEAVE", + "LEFT": "LEFT", + "LIKE": "LIKE", + "LIMIT": "LIMIT", + "LINEAR": "LINEAR", + "LINES": "LINES", + "LOAD": "LOAD", + "LOCALTIME": "LOCALTIME", + "LOCALTIMESTAMP": "LOCALTIMESTAMP", + "LOCK": "LOCK", + "LONG": "LONG", + "LONGBLOB": "LONGBLOB", + "LONGTEXT": "LONGTEXT", + "LOOP": "LOOP", + "LOW_PRIORITY": "LOW_PRIORITY", + "MASTER_BIND": "MASTER_BIND", + "MASTER_SSL_VERIFY_SERVER_CERT": "MASTER_SSL_VERIFY_SERVER_CERT", + "MATCH": "MATCH", + "MAXVALUE": "MAXVALUE", + "MEDIUMBLOB": "MEDIUMBLOB", + "MEDIUMINT": "MEDIUMINT", + "MEDIUMTEXT": "MEDIUMTEXT", + "MEMBER": "MEMBER", + "MIDDLEINT": "MIDDLEINT", + "MINUTE_MICROSECOND": "MINUTE_MICROSECOND", + "MINUTE_SECOND": "MINUTE_SECOND", + "MOD": "MOD", + "MODIFIES": "MODIFIES", + "NATURAL": "NATURAL", + "NOT": "NOT", + "NO_WRITE_TO_BINLOG": "NO_WRITE_TO_BINLOG", + "NTH_VALUE": "NTH_VALUE", + "NTILE": "NTILE", + "NULL": "NULL", + "NUMERIC": "NUMERIC", + "OF": "OF", + "ON": "ON", + "OPTIMIZE": "OPTIMIZE", + "OPTIMIZER_COSTS": "OPTIMIZER_COSTS", + "OPTION": "OPTION", + "OPTIONALLY": "OPTIONALLY", + "OR": "OR", + "ORDER": "ORDER", + "OUT": "OUT", + "OUTER": "OUTER", + "OUTFILE": "OUTFILE", + "OVER": "OVER", + "PARTITION": "PARTITION", + "PERCENT_RANK": "PERCENT_RANK", + "PRECISION": "PRECISION", + "PRIMARY": "PRIMARY", + "PROCEDURE": "PROCEDURE", + "PURGE": "PURGE", + "RANGE": "RANGE", + "RANK": "RANK", + "READ": "READ", + "READS": "READS", + "READ_WRITE": "READ_WRITE", + "REAL": "REAL", + "RECURSIVE": "RECURSIVE", + "REFERENCES": "REFERENCES", + "REGEXP": "REGEXP", + "RELEASE": "RELEASE", + "RENAME": "RENAME", + "REPEAT": "REPEAT", + "REPLACE": "REPLACE", + "REQUIRE": "REQUIRE", + "RESIGNAL": "RESIGNAL", + "RESTRICT": "RESTRICT", + "RETURN": "RETURN", + "REVOKE": "REVOKE", + "RIGHT": "RIGHT", + "RLIKE": "RLIKE", + "ROW": "ROW", + "ROWS": "ROWS", + "ROW_NUMBER": "ROW_NUMBER", + "SCHEMA": "SCHEMA", + "SCHEMAS": "SCHEMAS", + "SECOND_MICROSECOND": "SECOND_MICROSECOND", + "SELECT": "SELECT", + "SENSITIVE": "SENSITIVE", + "SEPARATOR": "SEPARATOR", + "SET": "SET", + "SHOW": "SHOW", + "SIGNAL": "SIGNAL", + "SMALLINT": "SMALLINT", + "SPATIAL": "SPATIAL", + "SPECIFIC": "SPECIFIC", + "SQL": "SQL", + "SQLEXCEPTION": "SQLEXCEPTION", + "SQLSTATE": "SQLSTATE", + "SQLWARNING": "SQLWARNING", + "SQL_BIG_RESULT": "SQL_BIG_RESULT", + "SQL_CALC_FOUND_ROWS": "SQL_CALC_FOUND_ROWS", + "SQL_SMALL_RESULT": "SQL_SMALL_RESULT", + "SSL": "SSL", + "STARTING": "STARTING", + "STORED": "STORED", + "STRAIGHT_JOIN": "STRAIGHT_JOIN", + "SYSTEM": "SYSTEM", + "TABLE": "TABLE", + "TERMINATED": "TERMINATED", + "THEN": "THEN", + "TINYBLOB": "TINYBLOB", + "TINYINT": "TINYINT", + "TINYTEXT": "TINYTEXT", + "TO": "TO", + "TRAILING": "TRAILING", + "TRIGGER": "TRIGGER", + "TRUE": "TRUE", + "UNDO": "UNDO", + "UNION": "UNION", + "UNIQUE": "UNIQUE", + "UNLOCK": "UNLOCK", + "UNSIGNED": "UNSIGNED", + "UPDATE": "UPDATE", + "USAGE": "USAGE", + "USE": "USE", + "USING": "USING", + "UTC_DATE": "UTC_DATE", + "UTC_TIME": "UTC_TIME", + "UTC_TIMESTAMP": "UTC_TIMESTAMP", + "VALUES": "VALUES", + "VARBINARY": "VARBINARY", + "VARCHAR": "VARCHAR", + "VARCHARACTER": "VARCHARACTER", + "VARYING": "VARYING", + "VIRTUAL": "VIRTUAL", + "WHEN": "WHEN", + "WHERE": "WHERE", + "WHILE": "WHILE", + "WINDOW": "WINDOW", + "WITH": "WITH", + "WRITE": "WRITE", + "XOR": "XOR", + "YEAR_MONTH": "YEAR_MONTH", + "ZEROFILL": "ZEROFILL", + } + } + + return MysqlKeyWord +} diff --git a/pkg/datasource/sql/types/sql_data_type.go b/pkg/datasource/sql/types/sql_data_type.go new file mode 100644 index 000000000..e5659cf82 --- /dev/null +++ b/pkg/datasource/sql/types/sql_data_type.go @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package types + +import "strings" + +var SqlDataTypes = map[string]int32{ + "BIT": -7, + "TINYINT": -6, + "SMALLINT": 5, + "INTEGER": 4, + "BIGINT": -5, + "FLOAT": 6, + "REAL": 7, + "DOUBLE": 8, + "NUMERIC": 2, + "DECIMAL": 3, + "CHAR": 1, + "VARCHAR": 12, + "LONGVARCHAR": -1, + "DATE": 91, + "TIME": 92, + "TIMESTAMP": 93, + "BINARY": -2, + "VARBINARY": -3, + "LONGVARBINARY": -4, + "NULL": 0, + "OTHER": 1111, + "JAVA_OBJECT": 2000, + "DISTINCT": 2001, + "STRUCT": 2002, + "ARRAY": 2003, + "BLOB": 2004, + "CLOB": 2005, + "REF": 2006, + "DATALINK": 70, + "BOOLEAN": 16, + "ROWID": -8, + "NCHAR": -15, + "NVARCHAR": -9, + "LONGNVARCHAR": -16, + "NCLOB": 2011, + "SQLXML": 2009, + "REF_CURSOR": 2012, + "TIME_WITH_TIMEZONE": 2013, + "TIMESTAMP_WITH_TIMEZONE": 2014, +} + +func GetSqlDataType(dataType string) int32 { + return SqlDataTypes[strings.ToUpper(dataType)] +} diff --git a/pkg/datasource/sql/undo/builder/mysql_delete_undo_log_builder.go b/pkg/datasource/sql/undo/builder/mysql_delete_undo_log_builder.go index 7f893177c..ab0fded26 100644 --- a/pkg/datasource/sql/undo/builder/mysql_delete_undo_log_builder.go +++ b/pkg/datasource/sql/undo/builder/mysql_delete_undo_log_builder.go @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package builder import ( diff --git a/pkg/datasource/sql/undo/builder/mysql_delete_undo_log_builder_test.go b/pkg/datasource/sql/undo/builder/mysql_delete_undo_log_builder_test.go index 35f126b3f..d8db7027f 100644 --- a/pkg/datasource/sql/undo/builder/mysql_delete_undo_log_builder_test.go +++ b/pkg/datasource/sql/undo/builder/mysql_delete_undo_log_builder_test.go @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package builder import ( diff --git a/pkg/datasource/sql/undo/executor/sql.go b/pkg/datasource/sql/undo/executor/sql.go new file mode 100644 index 000000000..d9b0a48d6 --- /dev/null +++ b/pkg/datasource/sql/undo/executor/sql.go @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package executor + +import ( + "strings" + + "github.com/seata/seata-go/pkg/datasource/sql/types" +) + +const ( + Dot = "." + EscapeStandard = "\"" + EscapeMysql = "`" +) + +// DelEscape del escape by db type +func DelEscape(colName string, dbType types.DBType) string { + newColName := delEscape(colName, EscapeStandard) + if dbType == types.DBTypeMySQL { + newColName = delEscape(newColName, EscapeMysql) + } + + return newColName +} + +// delEscape +func delEscape(colName string, escape string) string { + if colName == "" { + return "" + } + + if string(colName[0]) == escape && string(colName[len(colName)-1]) == escape { + // like "scheme"."id" `scheme`.`id` + str := escape + Dot + escape + index := strings.Index(colName, str) + if index > -1 { + return colName[1:index] + Dot + colName[index+len(str):len(colName)-1] + } + + return colName[1 : len(colName)-1] + } else { + // like "scheme".id `scheme`.id + str := escape + Dot + index := strings.Index(colName, str) + if index > -1 && string(colName[0]) == escape { + return colName[1:index] + Dot + colName[index+len(str):] + } + + // like scheme."id" scheme.`id` + str = Dot + escape + index = strings.Index(colName, str) + if index > -1 && string(colName[len(colName)-1]) == escape { + return colName[0:index] + Dot + colName[index+len(str):len(colName)-1] + } + } + + return colName +} + +// AddEscape if necessary, add escape by db type +func AddEscape(colName string, dbType types.DBType) string { + if dbType == types.DBTypeMySQL { + return addEscape(colName, dbType, EscapeMysql) + } + + return addEscape(colName, dbType, EscapeStandard) +} + +func addEscape(colName string, dbType types.DBType, escape string) string { + if colName == "" { + return colName + } + + if string(colName[0]) == escape && string(colName[len(colName)-1]) == escape { + return colName + } + + if !checkEscape(colName, dbType) { + return colName + } + + if strings.Contains(colName, Dot) { + // like "scheme".id `scheme`.id + str := escape + Dot + dotIndex := strings.Index(colName, str) + if dotIndex > -1 { + tempStr := strings.Builder{} + tempStr.WriteString(colName[0 : dotIndex+len(str)]) + tempStr.WriteString(escape) + tempStr.WriteString(colName[dotIndex+len(str):]) + tempStr.WriteString(escape) + + return tempStr.String() + } + + // like scheme."id" scheme.`id` + str = Dot + escape + dotIndex = strings.Index(colName, str) + if dotIndex > -1 { + tempStr := strings.Builder{} + tempStr.WriteString(escape) + tempStr.WriteString(colName[0:dotIndex]) + tempStr.WriteString(escape) + tempStr.WriteString(colName[dotIndex:]) + + return tempStr.String() + } + + str = Dot + dotIndex = strings.Index(colName, str) + if dotIndex > -1 { + tempStr := strings.Builder{} + tempStr.WriteString(escape) + tempStr.WriteString(colName[0:dotIndex]) + tempStr.WriteString(escape) + tempStr.WriteString(Dot) + tempStr.WriteString(escape) + tempStr.WriteString(colName[dotIndex+len(str):]) + tempStr.WriteString(escape) + + return tempStr.String() + } + } + + buf := make([]byte, len(colName)+2) + buf[0], buf[len(buf)-1] = escape[0], escape[0] + + for key, _ := range colName { + buf[key+1] = colName[key] + } + + return string(buf) +} + +// checkEscape check whether given field or table name use keywords. the method has database special logic. +func checkEscape(colName string, dbType types.DBType) bool { + switch dbType { + case types.DBTypeMySQL: + if _, ok := types.GetMysqlKeyWord()[strings.ToUpper(colName)]; ok { + return true + } + + return false + // TODO impl Oracle PG SQLServer ... + default: + return true + } +} diff --git a/pkg/datasource/sql/undo/executor/sql_test.go b/pkg/datasource/sql/undo/executor/sql_test.go new file mode 100644 index 000000000..6acde67a9 --- /dev/null +++ b/pkg/datasource/sql/undo/executor/sql_test.go @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package executor + +import ( + "log" + "testing" + + "github.com/seata/seata-go/pkg/datasource/sql/types" + + "github.com/stretchr/testify/assert" +) + +// TestDelEscape +func TestDelEscape(t *testing.T) { + strSlice := []string{`"scheme"."id"`, "`scheme`.`id`", `"scheme".id`, `scheme."id"`, `scheme."id"`, "scheme.`id`"} + + for k, v := range strSlice { + res := DelEscape(v, types.DBTypeMySQL) + log.Printf("val_%d: %s, res_%d: %s\n", k, v, k, res) + assert.Equal(t, "scheme.id", res) + } +} + +// TestAddEscape +func TestAddEscape(t *testing.T) { + strSlice := []string{`"scheme".id`, "`scheme`.id", `scheme."id"`, "scheme.`id`"} + + for k, v := range strSlice { + res := AddEscape(v, types.DBTypeMySQL) + log.Printf("val_%d: %s, res_%d: %s\n", k, v, k, res) + assert.Equal(t, v, res) + } + + strSlice1 := []string{"ALTER", "ANALYZE"} + for k, v := range strSlice1 { + res := AddEscape(v, types.DBTypeMySQL) + log.Printf("val_%d: %s, res_%d: %s\n", k, v, k, res) + assert.Equal(t, "`"+v+"`", res) + } +} diff --git a/pkg/util/fanout/fanout.go b/pkg/util/fanout/fanout.go index c3c449c62..a111c6b96 100644 --- a/pkg/util/fanout/fanout.go +++ b/pkg/util/fanout/fanout.go @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package fanout import ( diff --git a/pkg/util/fanout/fanout_test.go b/pkg/util/fanout/fanout_test.go index 66afeae44..040358e06 100644 --- a/pkg/util/fanout/fanout_test.go +++ b/pkg/util/fanout/fanout_test.go @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package fanout import (