diff --git a/models/activities/notification.go b/models/activities/notification.go index d20a53a41d860..321d543c736ae 100644 --- a/models/activities/notification.go +++ b/models/activities/notification.go @@ -141,7 +141,7 @@ func CountNotifications(ctx context.Context, opts *FindNotificationOptions) (int // CreateRepoTransferNotification creates notification for the user a repository was transferred to func CreateRepoTransferNotification(ctx context.Context, doer, newOwner *user_model.User, repo *repo_model.Repository) error { - return db.AutoTx(ctx, func(ctx context.Context) error { + return db.WithTx(ctx, func(ctx context.Context) error { var notify []*Notification if newOwner.IsOrganization() { diff --git a/models/db/context.go b/models/db/context.go index 3db8b16528da4..455f3d1c5de74 100644 --- a/models/db/context.go +++ b/models/db/context.go @@ -71,6 +71,14 @@ type Engined interface { // GetEngine will get a db Engine from this context or return an Engine restricted to this context func GetEngine(ctx context.Context) Engine { + if e := getEngine(ctx); e != nil { + return e + } + return x.Context(ctx) +} + +// getEngine will get a db Engine from this context or return nil +func getEngine(ctx context.Context) Engine { if engined, ok := ctx.(Engined); ok { return engined.Engine() } @@ -78,7 +86,7 @@ func GetEngine(ctx context.Context) Engine { if enginedInterface != nil { return enginedInterface.(Engined).Engine() } - return x.Context(ctx) + return nil } // Committer represents an interface to Commit or Close the Context @@ -87,10 +95,22 @@ type Committer interface { Close() error } -// TxContext represents a transaction Context +// halfCommitter is a wrapper of Committer. +// It can be closed early, but can't be committed early, it is useful for reusing a transaction. +type halfCommitter struct { + Committer +} + +func (*halfCommitter) Commit() error { + // do nothing + return nil +} + +// TxContext represents a transaction Context, +// it will reuse the existing transaction in the parent context or create a new one. func TxContext(parentCtx context.Context) (*Context, Committer, error) { - if InTransaction(parentCtx) { - return nil, nil, ErrAlreadyInTransaction + if sess, ok := inTransaction(parentCtx); ok { + return newContext(parentCtx, sess, true), &halfCommitter{Committer: sess}, nil } sess := x.NewSession() @@ -102,20 +122,11 @@ func TxContext(parentCtx context.Context) (*Context, Committer, error) { return newContext(DefaultContext, sess, true), sess, nil } -// WithTx represents executing database operations on a transaction -// This function will always open a new transaction, if a transaction exist in parentCtx return an error. -func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error { - if InTransaction(parentCtx) { - return ErrAlreadyInTransaction - } - return txWithNoCheck(parentCtx, f) -} - -// AutoTx represents executing database operations on a transaction, if the transaction exist, +// WithTx represents executing database operations on a transaction, if the transaction exist, // this function will reuse it otherwise will create a new one and close it when finished. -func AutoTx(parentCtx context.Context, f func(ctx context.Context) error) error { - if InTransaction(parentCtx) { - return f(newContext(parentCtx, GetEngine(parentCtx), true)) +func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error { + if sess, ok := inTransaction(parentCtx); ok { + return f(newContext(parentCtx, sess, true)) } return txWithNoCheck(parentCtx, f) } @@ -202,25 +213,25 @@ func EstimateCount(ctx context.Context, bean interface{}) (int64, error) { // InTransaction returns true if the engine is in a transaction otherwise return false func InTransaction(ctx context.Context) bool { - var e Engine - if engined, ok := ctx.(Engined); ok { - e = engined.Engine() - } else { - enginedInterface := ctx.Value(enginedContextKey) - if enginedInterface != nil { - e = enginedInterface.(Engined).Engine() - } - } + _, ok := inTransaction(ctx) + return ok +} + +func inTransaction(ctx context.Context) (*xorm.Session, bool) { + e := getEngine(ctx) if e == nil { - return false + return nil, false } switch t := e.(type) { case *xorm.Engine: - return false + return nil, false case *xorm.Session: - return t.IsInTx() + if t.IsInTx() { + return t, true + } + return nil, false default: - return false + return nil, false } } diff --git a/models/db/context_test.go b/models/db/context_test.go index e7518a50d8d29..95a01d4a26eb2 100644 --- a/models/db/context_test.go +++ b/models/db/context_test.go @@ -25,8 +25,62 @@ func TestInTransaction(t *testing.T) { assert.NoError(t, err) defer committer.Close() assert.True(t, db.InTransaction(ctx)) - assert.Error(t, db.WithTx(ctx, func(ctx context.Context) error { + assert.NoError(t, db.WithTx(ctx, func(ctx context.Context) error { assert.True(t, db.InTransaction(ctx)) return nil })) } + +func TestTxContext(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + + { // create new transaction + ctx, committer, err := db.TxContext(db.DefaultContext) + assert.NoError(t, err) + assert.True(t, db.InTransaction(ctx)) + assert.NoError(t, committer.Commit()) + } + + { // reuse the transaction created by TxContext and commit it + ctx, committer, err := db.TxContext(db.DefaultContext) + engine := db.GetEngine(ctx) + assert.NoError(t, err) + assert.True(t, db.InTransaction(ctx)) + { + ctx, committer, err := db.TxContext(ctx) + assert.NoError(t, err) + assert.True(t, db.InTransaction(ctx)) + assert.Equal(t, engine, db.GetEngine(ctx)) + assert.NoError(t, committer.Commit()) + } + assert.NoError(t, committer.Commit()) + } + + { // reuse the transaction created by TxContext and close it + ctx, committer, err := db.TxContext(db.DefaultContext) + engine := db.GetEngine(ctx) + assert.NoError(t, err) + assert.True(t, db.InTransaction(ctx)) + { + ctx, committer, err := db.TxContext(ctx) + assert.NoError(t, err) + assert.True(t, db.InTransaction(ctx)) + assert.Equal(t, engine, db.GetEngine(ctx)) + assert.NoError(t, committer.Close()) + } + assert.NoError(t, committer.Close()) + } + + { // reuse the transaction created by WithTx + assert.NoError(t, db.WithTx(db.DefaultContext, func(ctx context.Context) error { + assert.True(t, db.InTransaction(ctx)) + { + ctx, committer, err := db.TxContext(ctx) + assert.NoError(t, err) + assert.True(t, db.InTransaction(ctx)) + assert.NoError(t, committer.Commit()) + } + return nil + })) + } +} diff --git a/models/db/error.go b/models/db/error.go index 5860cb4a071d1..edc8e80a9c605 100644 --- a/models/db/error.go +++ b/models/db/error.go @@ -4,14 +4,11 @@ package db import ( - "errors" "fmt" "code.gitea.io/gitea/modules/util" ) -var ErrAlreadyInTransaction = errors.New("database connection has already been in a transaction") - // ErrCancelled represents an error due to context cancellation type ErrCancelled struct { Message string diff --git a/models/issues/issue.go b/models/issues/issue.go index f45e635c0ecea..417d6a1557892 100644 --- a/models/issues/issue.go +++ b/models/issues/issue.go @@ -2365,7 +2365,7 @@ func CountOrphanedIssues(ctx context.Context) (int64, error) { // DeleteOrphanedIssues delete issues without a repo func DeleteOrphanedIssues(ctx context.Context) error { var attachmentPaths []string - err := db.AutoTx(ctx, func(ctx context.Context) error { + err := db.WithTx(ctx, func(ctx context.Context) error { var ids []int64 if err := db.GetEngine(ctx).Table("issue").Distinct("issue.repo_id"). diff --git a/models/project/project.go b/models/project/project.go index 0a07cfe22ad1b..f432d0bc4c5bd 100644 --- a/models/project/project.go +++ b/models/project/project.go @@ -300,7 +300,7 @@ func changeProjectStatus(ctx context.Context, p *Project, isClosed bool) error { // DeleteProjectByID deletes a project from a repository. if it's not in a database // transaction, it will start a new database transaction func DeleteProjectByID(ctx context.Context, id int64) error { - return db.AutoTx(ctx, func(ctx context.Context) error { + return db.WithTx(ctx, func(ctx context.Context) error { p, err := GetProjectByID(ctx, id) if err != nil { if IsErrProjectNotExist(err) { diff --git a/models/repo/collaboration.go b/models/repo/collaboration.go index 29bcab70f36a3..7989e5bdf9c5d 100644 --- a/models/repo/collaboration.go +++ b/models/repo/collaboration.go @@ -105,7 +105,7 @@ func ChangeCollaborationAccessMode(ctx context.Context, repo *Repository, uid in return nil } - return db.AutoTx(ctx, func(ctx context.Context) error { + return db.WithTx(ctx, func(ctx context.Context) error { e := db.GetEngine(ctx) collaboration := &Collaboration{ diff --git a/models/repo_transfer.go b/models/repo_transfer.go index 19e6c0662727f..27a77f9b8cc9f 100644 --- a/models/repo_transfer.go +++ b/models/repo_transfer.go @@ -155,7 +155,7 @@ func TestRepositoryReadyForTransfer(status repo_model.RepositoryStatus) error { // CreatePendingRepositoryTransfer transfer a repo from one owner to a new one. // it marks the repository transfer as "pending" func CreatePendingRepositoryTransfer(ctx context.Context, doer, newOwner *user_model.User, repoID int64, teams []*organization.Team) error { - return db.AutoTx(ctx, func(ctx context.Context) error { + return db.WithTx(ctx, func(ctx context.Context) error { repo, err := repo_model.GetRepositoryByID(ctx, repoID) if err != nil { return err diff --git a/modules/notification/ui/ui.go b/modules/notification/ui/ui.go index 63a3ffd199989..bc66c3d5a3db4 100644 --- a/modules/notification/ui/ui.go +++ b/modules/notification/ui/ui.go @@ -243,7 +243,7 @@ func (ns *notificationService) NotifyPullReviewRequest(ctx context.Context, doer } func (ns *notificationService) NotifyRepoPendingTransfer(ctx context.Context, doer, newOwner *user_model.User, repo *repo_model.Repository) { - err := db.AutoTx(ctx, func(ctx context.Context) error { + err := db.WithTx(ctx, func(ctx context.Context) error { return activities_model.CreateRepoTransferNotification(ctx, doer, newOwner, repo) }) if err != nil { diff --git a/modules/repository/collaborator.go b/modules/repository/collaborator.go index f2b95151879ff..f5cdc35045353 100644 --- a/modules/repository/collaborator.go +++ b/modules/repository/collaborator.go @@ -14,7 +14,7 @@ import ( ) func AddCollaborator(ctx context.Context, repo *repo_model.Repository, u *user_model.User) error { - return db.AutoTx(ctx, func(ctx context.Context) error { + return db.WithTx(ctx, func(ctx context.Context) error { collaboration := &repo_model.Collaboration{ RepoID: repo.ID, UserID: u.ID, diff --git a/services/issue/comments.go b/services/issue/comments.go index 46c0daf70a336..1323fb47aa440 100644 --- a/services/issue/comments.go +++ b/services/issue/comments.go @@ -123,7 +123,7 @@ func UpdateComment(ctx context.Context, c *issues_model.Comment, doer *user_mode // DeleteComment deletes the comment func DeleteComment(ctx context.Context, doer *user_model.User, comment *issues_model.Comment) error { - err := db.AutoTx(ctx, func(ctx context.Context) error { + err := db.WithTx(ctx, func(ctx context.Context) error { return issues_model.DeleteComment(ctx, comment) }) if err != nil {