diff --git a/migrator/migrator.go b/migrator/migrator.go index 64a5a4b52..d97fbf35c 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -110,15 +110,20 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { return } +func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) { + queryTx = m.DB.Session(&gorm.Session{}) + execTx = queryTx + if m.DB.DryRun { + queryTx.DryRun = false + execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) + } + return queryTx, execTx +} + // AutoMigrate auto migrate values func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { - queryTx := m.DB.Session(&gorm.Session{}) - execTx := queryTx - if m.DB.DryRun { - queryTx.DryRun = false - execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) - } + queryTx, execTx := m.GetQueryAndExecTx() if !queryTx.Migrator().HasTable(value) { if err := execTx.Migrator().CreateTable(value); err != nil { return err @@ -268,7 +273,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { } if constraint := rel.ParseConstraint(); constraint != nil { if constraint.Schema == stmt.Schema { - sql, vars := buildConstraint(constraint) + sql, vars := constraint.Build() createTableSQL += sql + "," values = append(values, vars...) } @@ -276,6 +281,11 @@ func (m Migrator) CreateTable(values ...interface{}) error { } } + for _, uni := range stmt.Schema.ParseUniqueConstraints() { + createTableSQL += "CONSTRAINT ? UNIQUE (?)," + values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)}) + } + for _, chk := range stmt.Schema.ParseCheckConstraints() { createTableSQL += "CONSTRAINT ? CHECK (?)," values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) @@ -439,6 +449,10 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error // MigrateColumn migrate column func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + if field.IgnoreMigration { + return nil + } + // found, smart migrate fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) @@ -499,7 +513,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } // check unique - if unique, ok := columnType.Unique(); ok && unique != field.Unique { + if unique, ok := columnType.Unique(); ok && unique != (field.Unique || field.UniqueIndex != "") { // not primary key if !field.PrimaryKey { alterColumn = true @@ -630,37 +644,36 @@ func (m Migrator) DropView(name string) error { return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error } -func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) { - sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" - if constraint.OnDelete != "" { - sql += " ON DELETE " + constraint.OnDelete - } - - if constraint.OnUpdate != "" { - sql += " ON UPDATE " + constraint.OnUpdate - } - - var foreignKeys, references []interface{} - for _, field := range constraint.ForeignKeys { - foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) - } - - for _, field := range constraint.References { - references = append(references, clause.Column{Name: field.DBName}) - } - results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) - return -} - // GuessConstraintAndTable guess statement's constraint and it's table based on name -func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { +// +// Deprecated: use GuessConstraintInterfaceAndTable instead. +func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) { + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) + switch c := constraint.(type) { + case *schema.Constraint: + return c, nil, table + case *schema.CheckConstraint: + return nil, c, table + default: + return nil, nil, table + } +} + +// GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name +// nolint:cyclop +func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) { if stmt.Schema == nil { - return nil, nil, stmt.Table + return nil, stmt.Table } checkConstraints := stmt.Schema.ParseCheckConstraints() if chk, ok := checkConstraints[name]; ok { - return nil, &chk, stmt.Table + return &chk, stmt.Table + } + + uniqueConstraints := stmt.Schema.ParseUniqueConstraints() + if uni, ok := uniqueConstraints[name]; ok { + return &uni, stmt.Table } getTable := func(rel *schema.Relationship) string { @@ -675,7 +688,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { - return constraint, nil, getTable(rel) + return constraint, getTable(rel) } } @@ -683,40 +696,39 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ for k := range checkConstraints { if checkConstraints[k].Field == field { v := checkConstraints[k] - return nil, &v, stmt.Table + return &v, stmt.Table + } + } + + for k := range uniqueConstraints { + if uniqueConstraints[k].Field == field { + v := uniqueConstraints[k] + return &v, stmt.Table } } for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field { - return constraint, nil, getTable(rel) + return constraint, getTable(rel) } } } - return nil, nil, stmt.Schema.Table + return nil, stmt.Schema.Table } // CreateConstraint create constraint func (m Migrator) CreateConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) - if chk != nil { - return m.DB.Exec( - "ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)", - m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, - ).Error - } - + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { vars := []interface{}{clause.Table{Name: table}} if stmt.TableExpr != nil { vars[0] = stmt.TableExpr } - sql, values := buildConstraint(constraint) + sql, values := constraint.Build() return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error } - return nil }) } @@ -724,11 +736,9 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { // DropConstraint drop constraint func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { - name = constraint.Name - } else if chk != nil { - name = chk.Name + name = constraint.GetName() } return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error }) @@ -739,11 +749,9 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { - name = constraint.Name - } else if chk != nil { - name = chk.Name + name = constraint.GetName() } return m.DB.Raw( diff --git a/schema/check.go b/schema/check.go deleted file mode 100644 index 89e732d36..000000000 --- a/schema/check.go +++ /dev/null @@ -1,35 +0,0 @@ -package schema - -import ( - "regexp" - "strings" -) - -// reg match english letters and midline -var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$") - -type Check struct { - Name string - Constraint string // length(phone) >= 10 - *Field -} - -// ParseCheckConstraints parse schema check constraints -func (schema *Schema) ParseCheckConstraints() map[string]Check { - checks := map[string]Check{} - for _, field := range schema.FieldsByDBName { - if chk := field.TagSettings["CHECK"]; chk != "" { - names := strings.Split(chk, ",") - if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) { - checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} - } else { - if names[0] == "" { - chk = strings.Join(names[1:], ",") - } - name := schema.namer.CheckerName(schema.Table, field.DBName) - checks[name] = Check{Name: name, Constraint: chk, Field: field} - } - } - } - return checks -} diff --git a/schema/constraint.go b/schema/constraint.go new file mode 100644 index 000000000..5f6beb89c --- /dev/null +++ b/schema/constraint.go @@ -0,0 +1,66 @@ +package schema + +import ( + "regexp" + "strings" + + "gorm.io/gorm/clause" +) + +// reg match english letters and midline +var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$") + +type CheckConstraint struct { + Name string + Constraint string // length(phone) >= 10 + *Field +} + +func (chk *CheckConstraint) GetName() string { return chk.Name } + +func (chk *CheckConstraint) Build() (sql string, vars []interface{}) { + return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}} +} + +// ParseCheckConstraints parse schema check constraints +func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint { + checks := map[string]CheckConstraint{} + for _, field := range schema.FieldsByDBName { + if chk := field.TagSettings["CHECK"]; chk != "" { + names := strings.Split(chk, ",") + if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) { + checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} + } else { + if names[0] == "" { + chk = strings.Join(names[1:], ",") + } + name := schema.namer.CheckerName(schema.Table, field.DBName) + checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field} + } + } + } + return checks +} + +type UniqueConstraint struct { + Name string + Field *Field +} + +func (uni *UniqueConstraint) GetName() string { return uni.Name } + +func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) { + return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}} +} + +// ParseUniqueConstraints parse schema unique constraints +func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint { + uniques := make(map[string]UniqueConstraint) + for _, field := range schema.Fields { + if field.Unique { + name := schema.namer.UniqueName(schema.Table, field.DBName) + uniques[name] = UniqueConstraint{Name: name, Field: field} + } + } + return uniques +} diff --git a/schema/check_test.go b/schema/constraint_test.go similarity index 59% rename from schema/check_test.go rename to schema/constraint_test.go index eda043b7f..6fcb1b855 100644 --- a/schema/check_test.go +++ b/schema/constraint_test.go @@ -6,6 +6,7 @@ import ( "testing" "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" ) type UserCheck struct { @@ -20,7 +21,7 @@ func TestParseCheck(t *testing.T) { t.Fatalf("failed to parse user check, got error %v", err) } - results := map[string]schema.Check{ + results := map[string]schema.CheckConstraint{ "name_checker": { Name: "name_checker", Constraint: "name <> 'jinzhu'", @@ -53,3 +54,31 @@ func TestParseCheck(t *testing.T) { } } } + +func TestParseUniqueConstraints(t *testing.T) { + type UserUnique struct { + Name1 string `gorm:"unique"` + Name2 string `gorm:"uniqueIndex"` + } + + user, err := schema.Parse(&UserUnique{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user unique, got error %v", err) + } + constraints := user.ParseUniqueConstraints() + + results := map[string]schema.UniqueConstraint{ + "uni_user_uniques_name1": { + Name: "uni_user_uniques_name1", + Field: &schema.Field{Name: "Name1", Unique: true}, + }, + } + for k, result := range results { + v, ok := constraints[k] + if !ok { + t.Errorf("Failed to found unique constraint %v from parsed constraints %+v", k, constraints) + } + tests.AssertObjEqual(t, result, v, "Name") + tests.AssertObjEqual(t, result.Field, v.Field, "Name", "Unique", "UniqueIndex") + } +} diff --git a/schema/field.go b/schema/field.go index 657e0a4bd..91e4c0abf 100644 --- a/schema/field.go +++ b/schema/field.go @@ -89,6 +89,12 @@ type Field struct { Set func(context.Context, reflect.Value, interface{}) error Serializer SerializerInterface NewValuePool FieldNewValuePool + + // In some db (e.g. MySQL), Unique and UniqueIndex are indistinguishable. + // When a column has a (not Mul) UniqueIndex, Migrator always reports its gorm.ColumnType is Unique. + // It causes field unnecessarily migration. + // Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique. + UniqueIndex string } func (field *Field) BindName() string { diff --git a/schema/index.go b/schema/index.go index f5ac5dd21..f4f367510 100644 --- a/schema/index.go +++ b/schema/index.go @@ -13,8 +13,8 @@ type Index struct { Type string // btree, hash, gist, spgist, gin, and brin Where string Comment string - Option string // WITH PARSER parser_name - Fields []IndexOption + Option string // WITH PARSER parser_name + Fields []IndexOption // Note: IndexOption's Field maybe the same } type IndexOption struct { @@ -67,7 +67,7 @@ func (schema *Schema) ParseIndexes() map[string]Index { } for _, index := range indexes { if index.Class == "UNIQUE" && len(index.Fields) == 1 { - index.Fields[0].Field.Unique = true + index.Fields[0].Field.UniqueIndex = index.Name } } return indexes diff --git a/schema/index_test.go b/schema/index_test.go index 890327de6..2f1e36afc 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -1,11 +1,11 @@ package schema_test import ( - "reflect" "sync" "testing" "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" ) type UserIndex struct { @@ -19,6 +19,7 @@ type UserIndex struct { OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` MemberNumber string `gorm:"index:idx_id,priority:1"` Name7 string `gorm:"index:type"` + Name8 string `gorm:"index:,length:10;index:,collate:utf8"` // Composite Index: Flattened structure. Data0A string `gorm:"index:,composite:comp_id0"` @@ -65,7 +66,7 @@ func TestParseIndex(t *testing.T) { "idx_name": { Name: "idx_name", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", Unique: true}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: "idx_name"}}}, }, "idx_user_indices_name3": { Name: "idx_user_indices_name3", @@ -81,7 +82,7 @@ func TestParseIndex(t *testing.T) { "idx_user_indices_name4": { Name: "idx_user_indices_name4", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", Unique: true}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: "idx_user_indices_name4"}}}, }, "idx_user_indices_name5": { Name: "idx_user_indices_name5", @@ -102,18 +103,27 @@ func TestParseIndex(t *testing.T) { }, "idx_id": { Name: "idx_id", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", Unique: true}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}}, }, "idx_oid": { Name: "idx_oid", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", Unique: true}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}}, }, "type": { Name: "type", Type: "", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}}, }, + "idx_user_indices_name8": { + Name: "idx_user_indices_name8", + Type: "", + Fields: []schema.IndexOption{ + {Field: &schema.Field{Name: "Name8"}, Length: 10}, + // Note: Duplicate Columns + {Field: &schema.Field{Name: "Name8"}, Collate: "utf8"}, + }, + }, "idx_user_indices_comp_id0": { Name: "idx_user_indices_comp_id0", Type: "", @@ -146,40 +156,109 @@ func TestParseIndex(t *testing.T) { }, } - indices := user.ParseIndexes() + CheckIndices(t, results, user.ParseIndexes()) +} - for k, result := range results { - v, ok := indices[k] - if !ok { - t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices) - } +func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) { + type IndexTest struct { + FieldA string `gorm:"unique;index"` // unique and index + FieldB string `gorm:"unique"` // unique - for _, name := range []string{"Name", "Class", "Type", "Where", "Comment", "Option"} { - if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() { - t.Errorf( - "index %v %v should equal, expects %v, got %v", - k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(), - ) - } - } + FieldC string `gorm:"index:,unique"` // uniqueIndex + FieldD string `gorm:"uniqueIndex;index"` // uniqueIndex and index - for idx, ef := range result.Fields { - rf := v.Fields[idx] - if rf.Field.Name != ef.Field.Name { - t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name) + FieldE1 string `gorm:"uniqueIndex:uniq_field_e1_e2"` // mul uniqueIndex + FieldE2 string `gorm:"uniqueIndex:uniq_field_e1_e2"` + + FieldF1 string `gorm:"uniqueIndex:uniq_field_f1_f2;index"` // mul uniqueIndex and index + FieldF2 string `gorm:"uniqueIndex:uniq_field_f1_f2;"` + + FieldG string `gorm:"unique;uniqueIndex"` // unique and uniqueIndex + + FieldH1 string `gorm:"unique;uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex + FieldH2 string `gorm:"uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex + } + indexSchema, err := schema.Parse(&IndexTest{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user index, got error %v", err) + } + indices := indexSchema.ParseIndexes() + CheckIndices(t, map[string]schema.Index{ + "idx_index_tests_field_a": { + Name: "idx_index_tests_field_a", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldA", Unique: true}}}, + }, + "idx_index_tests_field_c": { + Name: "idx_index_tests_field_c", + Class: "UNIQUE", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: "idx_index_tests_field_c"}}}, + }, + "idx_index_tests_field_d": { + Name: "idx_index_tests_field_d", + Class: "UNIQUE", + Fields: []schema.IndexOption{ + {Field: &schema.Field{Name: "FieldD"}}, + // Note: Duplicate Columns + {Field: &schema.Field{Name: "FieldD"}}, + }, + }, + "uniq_field_e1_e2": { + Name: "uniq_field_e1_e2", + Class: "UNIQUE", + Fields: []schema.IndexOption{ + {Field: &schema.Field{Name: "FieldE1"}}, + {Field: &schema.Field{Name: "FieldE2"}}, + }, + }, + "idx_index_tests_field_f1": { + Name: "idx_index_tests_field_f1", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldF1"}}}, + }, + "uniq_field_f1_f2": { + Name: "uniq_field_f1_f2", + Class: "UNIQUE", + Fields: []schema.IndexOption{ + {Field: &schema.Field{Name: "FieldF1"}}, + {Field: &schema.Field{Name: "FieldF2"}}, + }, + }, + "idx_index_tests_field_g": { + Name: "idx_index_tests_field_g", + Class: "UNIQUE", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: "idx_index_tests_field_g"}}}, + }, + "uniq_field_h1_h2": { + Name: "uniq_field_h1_h2", + Class: "UNIQUE", + Fields: []schema.IndexOption{ + {Field: &schema.Field{Name: "FieldH1", Unique: true}}, + {Field: &schema.Field{Name: "FieldH2"}}, + }, + }, + }, indices) +} + +func CheckIndices(t *testing.T, expected, actual map[string]schema.Index) { + for k, ei := range expected { + t.Run(k, func(t *testing.T) { + ai, ok := actual[k] + if !ok { + t.Errorf("expected index %q but actual missing", k) + return } - if rf.Field.Unique != ef.Field.Unique { - t.Fatalf("index field '%s' should equal, expects %v, got %v", rf.Field.Name, rf.Field.Unique, ef.Field.Unique) + tests.AssertObjEqual(t, ai, ei, "Name", "Class", "Type", "Where", "Comment", "Option") + if len(ei.Fields) != len(ai.Fields) { + t.Errorf("expected index %q field length is %d but actual %d", k, len(ei.Fields), len(ai.Fields)) + return } - - for _, name := range []string{"Expression", "Sort", "Collate", "Length"} { - if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() { - t.Errorf( - "index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name, - reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(), - ) - } + for i, ef := range ei.Fields { + af := ai.Fields[i] + tests.AssertObjEqual(t, af, ef, "Name", "Unique", "UniqueIndex", "Expression", "Sort", "Collate", "Length") } - } + }) + delete(actual, k) + } + for k := range actual { + t.Errorf("unexpected index %q", k) } } diff --git a/schema/interfaces.go b/schema/interfaces.go index a75a33c0d..306d4f4e0 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -4,6 +4,12 @@ import ( "gorm.io/gorm/clause" ) +// ConstraintInterface database constraint interface +type ConstraintInterface interface { + GetName() string + Build() (sql string, vars []interface{}) +} + // GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDataType() string diff --git a/schema/relationship.go b/schema/relationship.go index 57167859c..2e94fc2cb 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -605,6 +605,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu } } +// Constraint is ForeignKey Constraint type Constraint struct { Name string Field *Field @@ -616,6 +617,31 @@ type Constraint struct { OnUpdate string } +func (constraint *Constraint) GetName() string { return constraint.Name } + +func (constraint *Constraint) Build() (sql string, vars []interface{}) { + sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" + if constraint.OnDelete != "" { + sql += " ON DELETE " + constraint.OnDelete + } + + if constraint.OnUpdate != "" { + sql += " ON UPDATE " + constraint.OnUpdate + } + + foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys)) + for _, field := range constraint.ForeignKeys { + foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) + } + + references := make([]interface{}, 0, len(constraint.References)) + for _, field := range constraint.References { + references = append(references, clause.Column{Name: field.DBName}) + } + vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) + return +} + func (rel *Relationship) ParseConstraint() *Constraint { str := rel.Field.TagSettings["CONSTRAINT"] if str == "-" {