diff --git a/models/db/common.go b/models/db/common.go index af6130c9f255b..aa529ef87e8dd 100644 --- a/models/db/common.go +++ b/models/db/common.go @@ -4,12 +4,20 @@ package db import ( + "context" + "fmt" + "reflect" "strings" + "time" + "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/util" "xorm.io/builder" + "xorm.io/xorm/convert" + "xorm.io/xorm/dialects" + "xorm.io/xorm/schemas" ) // BuildCaseInsensitiveLike returns a condition to check if the given value is like the given key case-insensitively. @@ -20,3 +28,409 @@ func BuildCaseInsensitiveLike(key, value string) builder.Cond { } return builder.Like{"UPPER(" + key + ")", strings.ToUpper(value)} } + +// InsertOnConflictDoNothing will attempt to insert the provided bean but if there is a conflict it will not error out +// This function will update the ID of the provided bean if there is an insertion +// This does not do all of the conversions that xorm would do automatically but it does quite a number of them +// once xorm has a working InsertOnConflictDoNothing this function could be removed. +func InsertOnConflictDoNothing(ctx context.Context, bean interface{}) (bool, error) { + e := GetEngine(ctx) + + tableName := x.TableName(bean, true) + table, err := x.TableInfo(bean) + if err != nil { + return false, err + } + + autoIncrCol := table.AutoIncrColumn() + + columns := table.Columns() + + colNames, values, zeroedColNames, zeroedValues, err := getColNamesAndValuesFromBean(bean, columns) + if err != nil { + return false, err + } + + if len(colNames) == 0 { + return false, fmt.Errorf("provided bean to insert has all empty values") + } + + // MSSQL needs to separately pass in the columns with the unique constraint and we need to + // include empty columns which are in the constraint in the insert for other dbs + uniqueColValMap, colNames, values := addInUniqueCols(colNames, values, zeroedColNames, zeroedValues, table) + if len(uniqueColValMap) == 0 { + return false, fmt.Errorf("provided bean has no unique constraints") + } + + var insertArgs []any + + switch { + case setting.Database.Type.IsSQLite3(): + insertArgs = generateInsertNoConflictSQLAndArgsForSQLite(tableName, colNames, values) + case setting.Database.Type.IsPostgreSQL(): + insertArgs = generateInsertNoConflictSQLAndArgsForPostgres(tableName, colNames, values, autoIncrCol) + case setting.Database.Type.IsMySQL(): + insertArgs = generateInsertNoConflictSQLAndArgsForMySQL(tableName, colNames, values) + case setting.Database.Type.IsMSSQL(): + insertArgs = generateInsertNoConflictSQLAndArgsForMSSQL(table, tableName, colNames, values, uniqueColValMap, autoIncrCol) + default: + return false, fmt.Errorf("database type not supported") + } + + if autoIncrCol != nil && (setting.Database.Type.IsPostgreSQL() || setting.Database.Type.IsMSSQL()) { + // Postgres and MSSQL do not use the LastInsertID mechanism + // Therefore use query rather than exec and read the last provided ID back in + + res, err := e.Query(insertArgs...) + if err != nil { + return false, fmt.Errorf("error in query: %s, %w", insertArgs[0], err) + } + if len(res) == 0 { + // this implies there was a conflict + return false, nil + } + + aiValue, err := table.AutoIncrColumn().ValueOf(bean) + if err != nil { + log.Error("unable to get value for autoincrcol of %#v %v", bean, err) + } + + if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { + return true, nil + } + + id := res[0][autoIncrCol.Name] + err = convert.AssignValue(*aiValue, id) + if err != nil { + return true, fmt.Errorf("error in assignvalue %v %v %w", id, res, err) + } + return true, nil + } + + res, err := e.Exec(insertArgs...) + if err != nil { + return false, err + } + + n, err := res.RowsAffected() + if err != nil { + return n != 0, err + } + + if n != 0 && autoIncrCol != nil { + id, err := res.LastInsertId() + if err != nil { + return true, err + } + reflect.ValueOf(bean).Elem().FieldByName(autoIncrCol.FieldName).SetInt(id) + } + + return n != 0, err +} + +// generateInsertNoConflictSQLAndArgsForSQLite will create the correct insert code for SQLite +func generateInsertNoConflictSQLAndArgsForSQLite(tableName string, colNames []string, args []any) (insertArgs []any) { + sb := &strings.Builder{} + + quote := x.Dialect().Quoter().Quote + write := func(args ...string) { + for _, arg := range args { + _, _ = sb.WriteString(arg) + } + } + write("INSERT INTO ", quote(tableName), " (") + _ = x.Dialect().Quoter().JoinWrite(sb, colNames, ",") + write(") VALUES (?") + for range colNames[1:] { + write(",?") + } + write(") ON CONFLICT DO NOTHING") + args[0] = sb.String() + return args +} + +// generateInsertNoConflictSQLAndArgsForPostgres will create the correct insert code for Postgres +func generateInsertNoConflictSQLAndArgsForPostgres(tableName string, colNames []string, args []any, autoIncrCol *schemas.Column) (insertArgs []any) { + sb := &strings.Builder{} + + quote := x.Dialect().Quoter().Quote + write := func(args ...string) { + for _, arg := range args { + _, _ = sb.WriteString(arg) + } + } + write("INSERT INTO ", quote(tableName), " (") + _ = x.Dialect().Quoter().JoinWrite(sb, colNames, ",") + write(") VALUES (?") + for range colNames[1:] { + write(",?") + } + write(") ON CONFLICT DO NOTHING") + if autoIncrCol != nil { + write(" RETURNING ", quote(autoIncrCol.Name)) + } + args[0] = sb.String() + return args +} + +// generateInsertNoConflictSQLAndArgsForMySQL will create the correct insert code for MySQL +func generateInsertNoConflictSQLAndArgsForMySQL(tableName string, colNames []string, args []any) (insertArgs []any) { + sb := &strings.Builder{} + + quote := x.Dialect().Quoter().Quote + write := func(args ...string) { + for _, arg := range args { + _, _ = sb.WriteString(arg) + } + } + write("INSERT IGNORE INTO ", quote(tableName), " (") + _ = x.Dialect().Quoter().JoinWrite(sb, colNames, ",") + write(") VALUES (?") + for range colNames[1:] { + write(",?") + } + write(")") + args[0] = sb.String() + return args +} + +// generateInsertNoConflictSQLAndArgsForMSSQL writes the INSERT ... ON CONFLICT sql variant for MSSQL +// MSSQL uses MERGE WITH ... but needs to pre-select the unique cols first +// then WHEN NOT MATCHED INSERT - this is kind of the opposite way round from INSERT ... ON CONFLICT +func generateInsertNoConflictSQLAndArgsForMSSQL(table *schemas.Table, tableName string, colNames []string, args []any, uniqueColValMap map[string]any, autoIncrCol *schemas.Column) (insertArgs []any) { + sb := &strings.Builder{} + + quote := x.Dialect().Quoter().Quote + write := func(args ...string) { + for _, arg := range args { + _, _ = sb.WriteString(arg) + } + } + uniqueCols := make([]string, 0, len(uniqueColValMap)) + for colName := range uniqueColValMap { + uniqueCols = append(uniqueCols, colName) + } + + write("MERGE ", quote(tableName), " WITH (HOLDLOCK) AS target USING (SELECT ? AS ") + _ = x.Dialect().Quoter().JoinWrite(sb, uniqueCols, ", ? AS ") + write(") AS src ON (") + countUniques := 0 + for _, index := range table.Indexes { + if index.Type != schemas.UniqueType { + continue + } + if countUniques > 0 { + write(" OR ") + } + countUniques++ + write("(") + write("src.", quote(index.Cols[0]), "= target.", quote(index.Cols[0])) + for _, col := range index.Cols[1:] { + write(" AND src.", quote(col), "= target.", quote(col)) + } + write(")") + } + write(") WHEN NOT MATCHED THEN INSERT (") + _ = x.Dialect().Quoter().JoinWrite(sb, colNames, ",") + write(") VALUES (?") + for range colNames[1:] { + write(", ?") + } + write(")") + if autoIncrCol != nil { + write(" OUTPUT INSERTED.", quote(autoIncrCol.Name)) + } + write(";") + uniqueArgs := make([]any, 0, len(uniqueColValMap)+len(args)) + uniqueArgs = append(uniqueArgs, sb.String()) + for _, col := range uniqueCols { + uniqueArgs = append(uniqueArgs, uniqueColValMap[col]) + } + return append(uniqueArgs, args[1:]...) +} + +// addInUniqueCols determines the columns that refer to unique constraints and creates slices for these +// as they're needed by MSSQL. In addition, any columns which are zero-valued but are part of a constraint +// are added back in to the colNames and args +func addInUniqueCols(colNames []string, args []any, zeroedColNames []string, emptyArgs []any, table *schemas.Table) (uniqueColValMap map[string]any, insertCols []string, insertArgs []any) { + uniqueColValMap = make(map[string]any) + + // Iterate across the indexes in the provided table + for _, index := range table.Indexes { + if index.Type != schemas.UniqueType { + continue + } + + // index is a Unique constraint + indexCol: + for _, iCol := range index.Cols { + if _, has := uniqueColValMap[iCol]; has { + // column is already included in uniqueCols so we don't need to add it again + continue indexCol + } + + // Now iterate across colNames and add to the uniqueCols + for i, col := range colNames { + if col == iCol { + uniqueColValMap[col] = args[i+1] + continue indexCol + } + } + + // If we still haven't found the column we need to look in the emptyColumns and add + // it back into colNames and args as well as uniqueCols/uniqueArgs + for i, col := range zeroedColNames { + if col == iCol { + // Always include empty unique columns in the insert statement as otherwise the insert no conflict will pass + colNames = append(colNames, col) + args = append(args, emptyArgs[i]) + uniqueColValMap[col] = emptyArgs[i] + continue indexCol + } + } + } + } + return uniqueColValMap, colNames, args +} + +// getColNamesAndValuesFromBean reads the provided bean, providing two pairs of linked slices: +// +// - colNames and values +// - zeroedColNames and zeroedValues +// +// colNames contains the names of the columns that have non-zero values in the provided bean +// values contains the values - with one exception - values is 1-based so that values[0] is deliberately left zero +// +// emptyyColNames and zeroedValues accounts for the other columns - with zeroedValues containing the zero values +func getColNamesAndValuesFromBean(bean interface{}, cols []*schemas.Column) (colNames []string, values []any, zeroedColNames []string, zeroedValues []any, err error) { + colNames = make([]string, len(cols)) + values = make([]any, len(cols)+1) // Leave args[0] to put the SQL in + maxNonEmpty := 0 + minEmpty := len(cols) + + val := reflect.ValueOf(bean) + elem := val.Elem() + for _, col := range cols { + if fieldIdx := col.FieldIndex; fieldIdx != nil { + fieldVal := elem.FieldByIndex(fieldIdx) + if col.IsCreated || col.IsUpdated { + result, err := setCurrentTime(fieldVal, col) + if err != nil { + return nil, nil, nil, nil, err + } + + colNames[maxNonEmpty] = col.Name + maxNonEmpty++ + values[maxNonEmpty] = result + continue + } + + val, err := getValueFromField(fieldVal, col) + if err != nil { + return nil, nil, nil, nil, err + } + if fieldVal.IsZero() { + values[minEmpty] = val // remember args is 1-based not 0-based + minEmpty-- + colNames[minEmpty] = col.Name + continue + } + colNames[maxNonEmpty] = col.Name + maxNonEmpty++ + values[maxNonEmpty] = val + } + } + + return colNames[:maxNonEmpty], values[:maxNonEmpty+1], colNames[maxNonEmpty:], values[maxNonEmpty+1:], nil +} + +func setCurrentTime(fieldVal reflect.Value, col *schemas.Column) (interface{}, error) { + t := time.Now() + result, err := dialects.FormatColumnTime(x.Dialect(), x.DatabaseTZ, col, t) + if err != nil { + return result, err + } + + switch fieldVal.Type().Kind() { + case reflect.Struct: + fieldVal.Set(reflect.ValueOf(t).Convert(fieldVal.Type())) + case reflect.Int, reflect.Int64, reflect.Int32: + fieldVal.SetInt(t.Unix()) + case reflect.Uint, reflect.Uint64, reflect.Uint32: + fieldVal.SetUint(uint64(t.Unix())) + } + return result, nil +} + +// getValueFromField extracts the reflected value from the provided fieldVal +// this keeps the type and makes such that zero values work in the SQL Insert above +func getValueFromField(fieldVal reflect.Value, col *schemas.Column) (any, error) { + // Handle pointers to convert.Conversion + if fieldVal.CanAddr() { + if fieldConvert, ok := fieldVal.Addr().Interface().(convert.Conversion); ok { + data, err := fieldConvert.ToDB() + if err != nil { + return nil, err + } + if data == nil { + if col.Nullable { + return nil, nil + } + data = []byte{} + } + if col.SQLType.IsBlob() { + return data, nil + } + return string(data), nil + } + } + + // Handle nil pointer to convert.Conversion + isNil := fieldVal.Kind() == reflect.Ptr && fieldVal.IsNil() + if !isNil { + if fieldConvert, ok := fieldVal.Interface().(convert.Conversion); ok { + data, err := fieldConvert.ToDB() + if err != nil { + return nil, err + } + if data == nil { + if col.Nullable { + return nil, nil + } + data = []byte{} + } + if col.SQLType.IsBlob() { + return data, nil + } + return string(data), nil + } + } + + // Handle common primitive types + switch fieldVal.Type().Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return fieldVal.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return fieldVal.Uint(), nil + case reflect.Float32, reflect.Float64: + return fieldVal.Float(), nil + case reflect.Complex64, reflect.Complex128: + return fieldVal.Complex(), nil + case reflect.String: + return fieldVal.String(), nil + case reflect.Bool: + valBool := fieldVal.Bool() + + if setting.Database.Type.IsMSSQL() { + if valBool { + return 1, nil + } + return 0, nil + } + return valBool, nil + default: + } + + // just return the interface + return fieldVal.Interface(), nil +} diff --git a/models/packages/package.go b/models/packages/package.go index 32f30fab9b407..7f84c5053785a 100644 --- a/models/packages/package.go +++ b/models/packages/package.go @@ -152,7 +152,10 @@ type Package struct { // TryInsertPackage inserts a package. If a package exists already, ErrDuplicatePackage is returned func TryInsertPackage(ctx context.Context, p *Package) (*Package, error) { - e := db.GetEngine(ctx) + inserted, err := db.InsertOnConflictDoNothing(ctx, p) + if err != nil || inserted { + return p, err + } key := &Package{ OwnerID: p.OwnerID, @@ -160,17 +163,16 @@ func TryInsertPackage(ctx context.Context, p *Package) (*Package, error) { LowerName: p.LowerName, } - has, err := e.Get(key) - if err != nil { - return nil, err - } + has, err := db.GetEngine(ctx).Get(key) if has { return key, ErrDuplicatePackage + } else if err != nil { + return key, err } - if _, err = e.Insert(p); err != nil { - return nil, err - } - return p, nil + // This really should never happen and can only happen if this function + // is being called outside of a transaction and between the on conflict insert failing + // the conlicting item is removed. + return p, fmt.Errorf("unable to insert on conflict but yet not able to get from the db") } // DeletePackageByID deletes a package by id diff --git a/models/packages/package_file.go b/models/packages/package_file.go index 97e7a0d4070a9..2f133989c3a20 100644 --- a/models/packages/package_file.go +++ b/models/packages/package_file.go @@ -5,6 +5,7 @@ package packages import ( "context" + "fmt" "strconv" "strings" "time" @@ -44,25 +45,27 @@ type PackageFile struct { // TryInsertFile inserts a file. If the file exists already ErrDuplicatePackageFile is returned func TryInsertFile(ctx context.Context, pf *PackageFile) (*PackageFile, error) { - e := db.GetEngine(ctx) + inserted, err := db.InsertOnConflictDoNothing(ctx, pf) + if err != nil || inserted { + return pf, err + } key := &PackageFile{ VersionID: pf.VersionID, LowerName: pf.LowerName, CompositeKey: pf.CompositeKey, } - - has, err := e.Get(key) - if err != nil { - return nil, err - } + has, err := db.GetEngine(ctx).Get(key) if has { - return pf, ErrDuplicatePackageFile + return key, ErrDuplicatePackageFile } - if _, err = e.Insert(pf); err != nil { - return nil, err + if err != nil { + return key, err } - return pf, nil + // This really should never happen and can only happen if this function + // is being called outside of a transaction and between the on conflict insert failing + // the conlicting item is removed. + return pf, fmt.Errorf("unable to insert on conflict but yet not able to get from the db") } // GetFilesByVersionID gets all files of a version diff --git a/models/packages/package_test.go b/models/packages/package_test.go index 735688a731e0e..45c2a7fe17a03 100644 --- a/models/packages/package_test.go +++ b/models/packages/package_test.go @@ -30,6 +30,7 @@ func TestHasOwnerPackages(t *testing.T) { p, err := packages_model.TryInsertPackage(db.DefaultContext, &packages_model.Package{ OwnerID: owner.ID, + Name: "package", LowerName: "package", }) assert.NotNil(t, p) @@ -42,6 +43,7 @@ func TestHasOwnerPackages(t *testing.T) { pv, err := packages_model.GetOrInsertVersion(db.DefaultContext, &packages_model.PackageVersion{ PackageID: p.ID, + Version: "internal", LowerVersion: "internal", IsInternal: true, }) @@ -55,6 +57,7 @@ func TestHasOwnerPackages(t *testing.T) { pv, err = packages_model.GetOrInsertVersion(db.DefaultContext, &packages_model.PackageVersion{ PackageID: p.ID, + Version: "normal", LowerVersion: "normal", IsInternal: false, }) diff --git a/models/packages/package_version.go b/models/packages/package_version.go index 759c20abed22b..e8ca8ed8ae63e 100644 --- a/models/packages/package_version.go +++ b/models/packages/package_version.go @@ -5,6 +5,7 @@ package packages import ( "context" + "fmt" "strconv" "strings" @@ -37,24 +38,27 @@ type PackageVersion struct { // GetOrInsertVersion inserts a version. If the same version exist already ErrDuplicatePackageVersion is returned func GetOrInsertVersion(ctx context.Context, pv *PackageVersion) (*PackageVersion, error) { - e := db.GetEngine(ctx) + inserted, err := db.InsertOnConflictDoNothing(ctx, pv) + if err != nil || inserted { + return pv, err + } key := &PackageVersion{ PackageID: pv.PackageID, LowerVersion: pv.LowerVersion, } - has, err := e.Get(key) - if err != nil { - return nil, err - } + has, err := db.GetEngine(ctx).Get(key) if has { return key, ErrDuplicatePackageVersion + } else if err != nil { + return key, err } - if _, err = e.Insert(pv); err != nil { - return nil, err - } - return pv, nil + // This really should never happen and can only happen if this function + // is being called outside of a transaction and between the on conflict insert failing + // the conlicting item is removed. + // + return pv, fmt.Errorf("unable to insert on conflict but yet not able to get from the db") } // UpdateVersion updates a version diff --git a/routers/api/packages/container/blob.go b/routers/api/packages/container/blob.go index f0457c55e19c3..ba969778246c0 100644 --- a/routers/api/packages/container/blob.go +++ b/routers/api/packages/container/blob.go @@ -10,7 +10,6 @@ import ( "fmt" "os" "strings" - "sync" "code.gitea.io/gitea/models/db" packages_model "code.gitea.io/gitea/models/packages" @@ -22,8 +21,6 @@ import ( packages_service "code.gitea.io/gitea/services/packages" ) -var uploadVersionMutex sync.Mutex - // saveAsPackageBlob creates a package blob from an upload // The uploaded blob gets stored in a special upload version to link them to the package/image func saveAsPackageBlob(hsr packages_module.HashedSizeReader, pci *packages_service.PackageCreationInfo) (*packages_model.PackageBlob, error) { @@ -93,9 +90,6 @@ func mountBlob(pi *packages_service.PackageInfo, pb *packages_model.PackageBlob) func getOrCreateUploadVersion(pi *packages_service.PackageInfo) (*packages_model.PackageVersion, error) { var uploadVersion *packages_model.PackageVersion - // FIXME: Replace usage of mutex with database transaction - // https://github.com/go-gitea/gitea/pull/21862 - uploadVersionMutex.Lock() err := db.WithTx(db.DefaultContext, func(ctx context.Context) error { created := true p := &packages_model.Package{ @@ -140,7 +134,6 @@ func getOrCreateUploadVersion(pi *packages_service.PackageInfo) (*packages_model return nil }) - uploadVersionMutex.Unlock() return uploadVersion, err } diff --git a/tests/integration/db_common_test.go b/tests/integration/db_common_test.go new file mode 100644 index 0000000000000..0b4c22aa9d6d6 --- /dev/null +++ b/tests/integration/db_common_test.go @@ -0,0 +1,220 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package integration + +import ( + "testing" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/tests" + + "github.com/stretchr/testify/assert" +) + +func TestInsertOnConflictDoNothing(t *testing.T) { + defer tests.PrepareTestEnv(t)() + + ctx := db.DefaultContext + e := db.GetEngine(ctx) + t.Run("NoUnique", func(t *testing.T) { + type NoUniques struct { + ID int64 `xorm:"pk autoincr"` + Data string + } + _ = e.Sync2(&NoUniques{}) + + // InsertOnConflictDoNothing does not work if there is no unique constraint + toInsert := &NoUniques{Data: "shouldErr"} + inserted, err := db.InsertOnConflictDoNothing(ctx, toInsert) + assert.Error(t, err) + assert.False(t, inserted) + assert.Equal(t, int64(0), toInsert.ID) + + // InsertOnConflictDoNothing does not work if there is no unique constraint + toInsert = &NoUniques{Data: ""} + inserted, err = db.InsertOnConflictDoNothing(ctx, toInsert) + assert.Error(t, err) + assert.False(t, inserted) + assert.Equal(t, int64(0), toInsert.ID) + }) + + t.Run("OneUnique", func(t *testing.T) { + type OneUnique struct { + ID int64 `xorm:"pk autoincr"` + Data string `xorm:"UNIQUE NOT NULL"` + } + + _ = e.Sync2(&OneUnique{}) + _, _ = e.Exec("DELETE FROM one_unique") + + // Cannot insert if the unique field is NULL + toInsert := &OneUnique{} + inserted, err := db.InsertOnConflictDoNothing(ctx, toInsert) + assert.Error(t, err) + assert.False(t, inserted) + assert.Equal(t, int64(0), toInsert.ID) + + // Successfully insert test + toInsert = &OneUnique{Data: "test"} + inserted, err = db.InsertOnConflictDoNothing(ctx, toInsert) + assert.NoError(t, err) + assert.True(t, inserted) + assert.NotEqual(t, int64(0), toInsert.ID) + + // Successfully insert test2 + toInsert = &OneUnique{Data: "test2"} + inserted, err = db.InsertOnConflictDoNothing(ctx, toInsert) + assert.NoError(t, err) + assert.True(t, inserted) + assert.NotEqual(t, int64(0), toInsert.ID) + + // Successfully not insert test + toInsert = &OneUnique{Data: "test"} + inserted, err = db.InsertOnConflictDoNothing(ctx, toInsert) + assert.NoError(t, err) + assert.False(t, inserted) + assert.Equal(t, int64(0), toInsert.ID) + }) + + t.Run("MultiUnique", func(t *testing.T) { + type MultiUnique struct { + ID int64 `xorm:"pk autoincr"` + NotUnique string + Data1 string `xorm:"UNIQUE(s) NOT NULL"` + Data2 string `xorm:"UNIQUE(s) NOT NULL"` + } + + _ = e.Sync2(&MultiUnique{}) + _, _ = e.Exec("DELETE FROM multi_unique") + + // Cannot insert if the unique fields are null + toInsert := &MultiUnique{} + inserted, err := db.InsertOnConflictDoNothing(ctx, toInsert) + assert.Error(t, err) + assert.False(t, inserted) + assert.Equal(t, int64(0), toInsert.ID) + + // successfully insert test, t1 + toInsert = &MultiUnique{Data1: "test", NotUnique: "t1"} + inserted, err = db.InsertOnConflictDoNothing(ctx, toInsert) + assert.NoError(t, err) + assert.True(t, inserted) + assert.NotEqual(t, int64(0), toInsert.ID) + + // successfully insert test2, t1 + toInsert = &MultiUnique{Data1: "test2", NotUnique: "t1"} + inserted, err = db.InsertOnConflictDoNothing(ctx, toInsert) + assert.NoError(t, err) + assert.True(t, inserted) + assert.NotEqual(t, int64(0), toInsert.ID) + + // successfully don't insert test2, t2 + toInsert = &MultiUnique{Data1: "test2", NotUnique: "t2"} + inserted, err = db.InsertOnConflictDoNothing(ctx, toInsert) + assert.NoError(t, err) + assert.False(t, inserted) + assert.Equal(t, int64(0), toInsert.ID) + + // successfully don't insert test, t2 + toInsert = &MultiUnique{Data1: "test", NotUnique: "t2"} + inserted, err = db.InsertOnConflictDoNothing(ctx, toInsert) + assert.NoError(t, err) + assert.False(t, inserted) + assert.Equal(t, int64(0), toInsert.ID) + + // successfully insert test/test2, t2 + toInsert = &MultiUnique{Data1: "test", Data2: "test2", NotUnique: "t1"} + inserted, err = db.InsertOnConflictDoNothing(ctx, toInsert) + assert.NoError(t, err) + assert.True(t, inserted) + assert.NotEqual(t, int64(0), toInsert.ID) + + // successfully don't insert test/test2, t2 + toInsert = &MultiUnique{Data1: "test", Data2: "test2", NotUnique: "t2"} + inserted, err = db.InsertOnConflictDoNothing(ctx, toInsert) + assert.NoError(t, err) + assert.False(t, inserted) + assert.Equal(t, int64(0), toInsert.ID) + }) + + t.Run("MultiMultiUnique", func(t *testing.T) { + type MultiMultiUnique struct { + ID int64 `xorm:"pk autoincr"` + Data0 string `xorm:"UNIQUE NOT NULL"` + Data1 string `xorm:"UNIQUE(s) NOT NULL"` + Data2 string `xorm:"UNIQUE(s) NOT NULL"` + } + + _ = e.Sync2(&MultiMultiUnique{}) + _, _ = e.Exec("DELETE FROM multi_multi_unique") + + inserted, err := db.InsertOnConflictDoNothing(ctx, &MultiMultiUnique{}) + assert.Error(t, err) + assert.False(t, inserted) + + inserted, err = db.InsertOnConflictDoNothing(ctx, &MultiMultiUnique{Data1: "test", Data0: "t1"}) + assert.NoError(t, err) + assert.True(t, inserted) + + inserted, err = db.InsertOnConflictDoNothing(ctx, &MultiMultiUnique{Data2: "test2", Data0: "t1"}) + assert.NoError(t, err) + assert.False(t, inserted) + + inserted, err = db.InsertOnConflictDoNothing(ctx, &MultiMultiUnique{Data2: "test2", Data0: "t2"}) + assert.NoError(t, err) + assert.True(t, inserted) + + inserted, err = db.InsertOnConflictDoNothing(ctx, &MultiMultiUnique{Data2: "test2", Data0: "t2"}) + assert.NoError(t, err) + assert.False(t, inserted) + + inserted, err = db.InsertOnConflictDoNothing(ctx, &MultiMultiUnique{Data1: "test", Data0: "t2"}) + assert.NoError(t, err) + assert.False(t, inserted) + + inserted, err = db.InsertOnConflictDoNothing(ctx, &MultiMultiUnique{Data1: "test", Data2: "test2", Data0: "t3"}) + assert.NoError(t, err) + assert.True(t, inserted) + + inserted, err = db.InsertOnConflictDoNothing(ctx, &MultiMultiUnique{Data1: "test", Data2: "test2", Data0: "t2"}) + assert.NoError(t, err) + assert.False(t, inserted) + }) + + t.Run("NoPK", func(t *testing.T) { + type NoPrimaryKey struct { + NotID int64 + Uniqued string `xorm:"UNIQUE"` + } + + err := e.Sync2(&NoPrimaryKey{}) + assert.NoError(t, err) + _, _ = e.Exec("DELETE FROM no_primary_key") + + empty := &NoPrimaryKey{} + inserted, err := db.InsertOnConflictDoNothing(ctx, empty) + assert.Error(t, err) + assert.False(t, inserted) + + inserted, err = db.InsertOnConflictDoNothing(ctx, &NoPrimaryKey{Uniqued: "1"}) + assert.NoError(t, err) + assert.True(t, inserted) + + inserted, err = db.InsertOnConflictDoNothing(ctx, &NoPrimaryKey{NotID: 1}) + assert.NoError(t, err) + assert.True(t, inserted) + + inserted, err = db.InsertOnConflictDoNothing(ctx, &NoPrimaryKey{NotID: 2}) + assert.NoError(t, err) + assert.False(t, inserted) + + inserted, err = db.InsertOnConflictDoNothing(ctx, &NoPrimaryKey{NotID: 2, Uniqued: "2"}) + assert.NoError(t, err) + assert.True(t, inserted) + + inserted, err = db.InsertOnConflictDoNothing(ctx, &NoPrimaryKey{NotID: 1, Uniqued: "2"}) + assert.NoError(t, err) + assert.False(t, inserted) + }) +}