Skip to content

Commit

Permalink
enable AutoMigrate when using non default sqlserver schemas (#50)
Browse files Browse the repository at this point in the history
* enable AutoMigrate when using custom  sqlserver schemas

Sqlserver supports multiple schemas per database. This feature works in gorm (custom TableName/NamingStrategy) but fails at AutoMigrate.
This commit fixes the migrator to also support non-default schemas.

* fix constraint lookup when using explicit schema names

* bugfixes due to upgrade

* fix bug in foreign-key query

* Update migrator_test.go

Co-authored-by: Mueller Manuel (LWE) <[email protected]>
Co-authored-by: Jinzhu <[email protected]>
  • Loading branch information
3 people authored Dec 4, 2022
1 parent 88a3f7b commit 07d0daa
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 7 deletions.
62 changes: 55 additions & 7 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)

type Migrator struct {
Expand All @@ -19,12 +20,46 @@ func (m Migrator) GetTables() (tableList []string, err error) {
return tableList, m.DB.Raw("SELECT table_name FROM INFORMATION_SCHEMA.tables WHERE table_catalog = ?", m.CurrentDatabase()).Scan(&tableList).Error
}

func getTableSchemaName(schema *schema.Schema) string {
//return the schema name if it is explicitly provided in the table name
//otherwise return a sql wildcard -> use any table_schema
if schema == nil || !strings.Contains(schema.Table, ".") {
return ""
}
_, schemaName, _ := splitFullQualifiedName(schema.Table)
return schemaName
}

func splitFullQualifiedName(name string) (string, string, string) {
nameParts := strings.Split(name, ".")
if len(nameParts) == 1 { //[table_name]
return "", "", nameParts[0]
} else if len(nameParts) == 2 { //[table_schema].[table_name]
return "", nameParts[0], nameParts[1]
} else if len(nameParts) == 3 { //[table_catalog].[table_schema].[table_name]
return nameParts[0], nameParts[1], nameParts[2]
}
return "", "", ""
}

func getFullQualifiedTableName(stmt *gorm.Statement) string {
fullQualifiedTableName := stmt.Table
if schemaName := getTableSchemaName(stmt.Schema); schemaName != "" {
fullQualifiedTableName = schemaName + "." + fullQualifiedTableName
}
return fullQualifiedTableName
}

func (m Migrator) HasTable(value interface{}) bool {
var count int
m.RunWithValue(value, func(stmt *gorm.Statement) error {
schemaName := getTableSchemaName(stmt.Schema)
if schemaName == "" {
schemaName = "%"
}
return m.DB.Raw(
"SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?",
stmt.Table, m.CurrentDatabase(),
"SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ? and table_schema like ? AND table_type = ?",
stmt.Table, m.CurrentDatabase(), schemaName, "BASE TABLE",
).Row().Scan(&count)
})
return count > 0
Expand All @@ -40,7 +75,7 @@ func (m Migrator) DropTable(values ...interface{}) error {
Parent string
}
var constraints []constraint
err := tx.Raw("SELECT name, OBJECT_NAME(parent_object_id) as parent FROM sys.foreign_keys WHERE referenced_object_id = object_id(?)", stmt.Table).Scan(&constraints).Error
err := tx.Raw("SELECT name, OBJECT_NAME(parent_object_id) as parent FROM sys.foreign_keys WHERE referenced_object_id = object_id(?)", getFullQualifiedTableName(stmt)).Scan(&constraints).Error

for _, c := range constraints {
if err == nil {
Expand Down Expand Up @@ -150,7 +185,7 @@ var defaultValueTrimRegexp = regexp.MustCompile("^\\('?([^']*)'?\\)$")
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
rows, err := m.DB.Session(&gorm.Session{}).Table(getFullQualifiedTableName(stmt)).Limit(1).Rows()
if err != nil {
return err
}
Expand Down Expand Up @@ -259,7 +294,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {

return m.DB.Raw(
"SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)",
name, stmt.Table,
name, getFullQualifiedTableName(stmt),
).Row().Scan(&count)
})
return count > 0
Expand All @@ -285,9 +320,17 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
name = chk.Name
}

tableCatalog, schema, tableName := splitFullQualifiedName(table)
if tableCatalog == "" {
tableCatalog = m.CurrentDatabase()
}
if schema == "" {
schema = "%"
}

return m.DB.Raw(
`SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ? AND T.Name = ? AND I.TABLE_CATALOG = ?;`,
name, table, m.CurrentDatabase(),
`SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ? AND I.TABLE_NAME = ? AND I.TABLE_SCHEMA like ? AND I.TABLE_CATALOG = ?;`,
name, tableName, schema, tableCatalog,
).Row().Scan(&count)
})
return count > 0
Expand All @@ -297,3 +340,8 @@ func (m Migrator) CurrentDatabase() (name string) {
m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name)
return
}

func (m Migrator) DefaultSchema() (name string) {
m.DB.Raw("SELECT SCHEMA_NAME() AS [Default Schema]").Row().Scan(&name)
return
}
98 changes: 98 additions & 0 deletions migrator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package sqlserver_test

import (
"os"
"testing"

"gorm.io/driver/sqlserver"
"gorm.io/gorm"
)

var sqlserverDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"

func init() {
if dbDSN := os.Getenv("GORM_DSN"); dbDSN != "" {
sqlserverDSN = dbDSN
}
}

type Testtable struct {
Test uint64 `gorm:"index"`
}

type Testtable2 struct {
Test uint64 `gorm:"index"`
Test2 uint64
}

func (*Testtable2) TableName() string { return "testtables" }

type Testtable3 struct {
Test3 uint64
}

func (*Testtable3) TableName() string { return "testschema1.Testtables" }

type Testtable4 struct {
Test4 uint64
}

func (*Testtable4) TableName() string { return "testschema2.Testtables" }

type Testtable5 struct {
Test4 uint64
Test5 uint64 `gorm:"index"`
}

func (*Testtable5) TableName() string { return "testschema2.Testtables" }

func TestAutomigrateTablesWithoutDefaultSchema(t *testing.T) {
db, err := gorm.Open(sqlserver.Open(sqlserverDSN))
if err != nil {
t.Error(err)
}

if tx := db.Exec("create schema testschema1"); tx.Error != nil {
t.Error("couldn't create schema testschema1", tx.Error)
}
if tx := db.Exec("create schema testschema2"); tx.Error != nil {
t.Error("couldn't create schema testschema2", tx.Error)
}

if err = db.AutoMigrate(&Testtable{}); err != nil {
t.Error("couldn't create a table at user default schema", err)
}
if err = db.AutoMigrate(&Testtable2{}); err != nil {
t.Error("couldn't update a table at user default schema", err)
}
if err = db.AutoMigrate(&Testtable3{}); err != nil {
t.Error("couldn't create a table at schema testschema1", err)
}
if err = db.AutoMigrate(&Testtable4{}); err != nil {
t.Error("couldn't create a table at schema testschema2", err)
}
if err = db.AutoMigrate(&Testtable5{}); err != nil {
t.Error("couldn't update a table at schema testschema2", err)
}

if tx := db.Exec("drop table testtables"); tx.Error != nil {
t.Error("couldn't drop table testtable at user default schema", tx.Error)
}

if tx := db.Exec("drop table testschema1.testtables"); tx.Error != nil {
t.Error("couldn't drop table testschema1.testtable", tx.Error)
}

if tx := db.Exec("drop table testschema2.testtables"); tx.Error != nil {
t.Error("couldn't drop table testschema2.testtable", tx.Error)
}

if tx := db.Exec("drop schema testschema1"); tx.Error != nil {
t.Error("couldn't drop schema testschema1", tx.Error)
}

if tx := db.Exec("drop schema testschema2"); tx.Error != nil {
t.Error("couldn't drop schema testschema2", tx.Error)
}

}

0 comments on commit 07d0daa

Please sign in to comment.