From e6a9d31199da2473bd40c193d71dca9dd23b0171 Mon Sep 17 00:00:00 2001 From: Jakub Date: Thu, 23 Mar 2023 15:11:46 +0100 Subject: [PATCH 1/2] feat(GODT-2500): Add panic handlers everywhere. --- builder.go | 18 +++++++++++++----- internal/backend/backend.go | 7 +++++-- internal/backend/update_injector.go | 5 +++-- internal/backend/user.go | 11 +++++++++-- internal/session/command.go | 2 +- internal/session/handle.go | 6 ++++++ internal/session/handle_idle.go | 2 +- internal/session/session.go | 13 ++++++++++++- internal/state/mailbox_fetch.go | 2 ++ internal/state/mailbox_search.go | 2 ++ internal/state/state.go | 13 +++++++++++-- logging/logging.go | 16 ++++++++++++++-- option.go | 13 +++++++++++++ queue/queued_channel.go | 12 ++++++++++-- queue/queued_channel_test.go | 4 ++-- server.go | 20 +++++++++++++++----- store/semaphore.go | 18 ++++++++++++++++++ tests/full_state_test.go | 5 +++-- wait/wg.go | 19 ++++++++++++++++++- watcher/watcher.go | 4 ++-- watcher/watcher_test.go | 2 ++ 21 files changed, 162 insertions(+), 32 deletions(-) diff --git a/builder.go b/builder.go index 76a8b665..4a6296a2 100644 --- a/builder.go +++ b/builder.go @@ -2,14 +2,13 @@ package gluon import ( "crypto/tls" - "github.com/ProtonMail/gluon/internal/db" - "github.com/sirupsen/logrus" "io" "os" "time" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/backend" + "github.com/ProtonMail/gluon/internal/db" "github.com/ProtonMail/gluon/internal/session" "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/profiling" @@ -17,6 +16,7 @@ import ( "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/store" "github.com/ProtonMail/gluon/version" + "github.com/sirupsen/logrus" ) type serverBuilder struct { @@ -35,6 +35,7 @@ type serverBuilder struct { disableParallelism bool imapLimits limits.IMAP uidValidityGenerator imap.UIDValidityGenerator + panicHandler queue.PanicHandler } func newBuilder() (*serverBuilder, error) { @@ -46,6 +47,7 @@ func newBuilder() (*serverBuilder, error) { idleBulkTime: 500 * time.Millisecond, imapLimits: limits.DefaultLimits(), uidValidityGenerator: imap.DefaultEpochUIDValidityGenerator(), + panicHandler: queue.NoopPanicHandler{}, }, nil } @@ -83,6 +85,7 @@ func (builder *serverBuilder) build() (*Server, error) { builder.delim, builder.loginJailTime, builder.imapLimits, + builder.panicHandler, ) if err != nil { return nil, err @@ -94,12 +97,12 @@ func (builder *serverBuilder) build() (*Server, error) { logrus.WithError(err).Error("Failed to remove old database files") } - return &Server{ + s := &Server{ dataDir: builder.dataDir, databaseDir: builder.databaseDir, backend: backend, sessions: make(map[int]*session.Session), - serveErrCh: queue.NewQueuedChannel[error](1, 1), + serveErrCh: queue.NewQueuedChannel[error](1, 1, builder.panicHandler), serveDoneCh: make(chan struct{}), inLogger: builder.inLogger, outLogger: builder.outLogger, @@ -111,5 +114,10 @@ func (builder *serverBuilder) build() (*Server, error) { reporter: builder.reporter, disableParallelism: builder.disableParallelism, uidValidityGenerator: builder.uidValidityGenerator, - }, nil + panicHandler: builder.panicHandler, + } + + s.serveWG.SetPanicHandler(builder.panicHandler) + + return s, nil } diff --git a/internal/backend/backend.go b/internal/backend/backend.go index 2d20b25b..6a8bd4fa 100644 --- a/internal/backend/backend.go +++ b/internal/backend/backend.go @@ -14,6 +14,7 @@ import ( "github.com/ProtonMail/gluon/internal/db/ent/mailbox" "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/limits" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/store" "github.com/google/uuid" @@ -49,9 +50,11 @@ type Backend struct { loginWG sync.WaitGroup imapLimits limits.IMAP + + panicHandler queue.PanicHandler } -func New(dataDir, databaseDir string, storeBuilder store.Builder, delim string, loginJailTime time.Duration, imapLimits limits.IMAP) (*Backend, error) { +func New(dataDir, databaseDir string, storeBuilder store.Builder, delim string, loginJailTime time.Duration, imapLimits limits.IMAP, panicHandler queue.PanicHandler) (*Backend, error) { return &Backend{ dataDir: dataDir, databaseDir: databaseDir, @@ -91,7 +94,7 @@ func (b *Backend) AddUser(ctx context.Context, userID string, conn connector.Con return false, err } - user, err := newUser(ctx, userID, db, conn, storeBuilder, b.delim, b.imapLimits, uidValidityGenerator) + user, err := newUser(ctx, userID, db, conn, storeBuilder, b.delim, b.imapLimits, uidValidityGenerator, b.panicHandler) if err != nil { return false, err } diff --git a/internal/backend/update_injector.go b/internal/backend/update_injector.go index d8cb1f81..5e6851be 100644 --- a/internal/backend/update_injector.go +++ b/internal/backend/update_injector.go @@ -7,6 +7,7 @@ import ( "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/logging" + "github.com/ProtonMail/gluon/queue" ) // updateInjector allows anyone to publish custom imap updates alongside the updates that are generated from the @@ -23,7 +24,7 @@ type updateInjector struct { // newUpdateInjector creates a new updateInjector. // // nolint:contextcheck -func newUpdateInjector(connector connector.Connector, userID string) *updateInjector { +func newUpdateInjector(connector connector.Connector, userID string, panicHandler queue.PanicHandler) *updateInjector { injector := &updateInjector{ updatesCh: make(chan imap.Update), forwardQuitCh: make(chan struct{}), @@ -31,7 +32,7 @@ func newUpdateInjector(connector connector.Connector, userID string) *updateInje injector.forwardWG.Add(1) - logging.GoAnnotated(context.Background(), func(ctx context.Context) { + logging.GoAnnotated(context.Background(), panicHandler, func(ctx context.Context) { injector.forward(ctx, connector.GetUpdates()) }, logging.Labels{ "Action": "Forwarding updates", diff --git a/internal/backend/user.go b/internal/backend/user.go index b3662ef2..0bd3fa2a 100644 --- a/internal/backend/user.go +++ b/internal/backend/user.go @@ -13,6 +13,7 @@ import ( "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/logging" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/store" "github.com/bradenaw/juniper/xslices" @@ -42,6 +43,8 @@ type user struct { imapLimits limits.IMAP uidValidityGenerator imap.UIDValidityGenerator + + panicHandler queue.PanicHandler } func newUser( @@ -53,6 +56,7 @@ func newUser( delimiter string, imapLimits limits.IMAP, uidValidityGenerator imap.UIDValidityGenerator, + panicHandler queue.PanicHandler, ) (*user, error) { if err := database.Init(ctx); err != nil { return nil, err @@ -84,7 +88,7 @@ func newUser( userID: userID, connector: conn, - updateInjector: newUpdateInjector(conn, userID), + updateInjector: newUpdateInjector(conn, userID, panicHandler), store: store.NewWriteControlledStore(st), delimiter: delimiter, @@ -98,6 +102,8 @@ func newUser( imapLimits: imapLimits, uidValidityGenerator: uidValidityGenerator, + + panicHandler: panicHandler, } if err := user.deleteAllMessagesMarkedDeleted(ctx); err != nil { @@ -119,7 +125,7 @@ func newUser( user.updateWG.Add(1) // nolint:contextcheck - logging.GoAnnotated(context.Background(), func(ctx context.Context) { + logging.GoAnnotated(context.Background(), panicHandler, func(ctx context.Context) { defer user.updateWG.Done() updateCh := user.updateInjector.GetUpdates() @@ -227,6 +233,7 @@ func (user *user) newState() (*state.State, error) { newStateUserInterfaceImpl(user, newStateConnectorImpl(user)), user.delimiter, user.imapLimits, + user.panicHandler, ) user.states[newState.StateID] = newState diff --git a/internal/session/command.go b/internal/session/command.go index a52f9e45..0d907aa3 100644 --- a/internal/session/command.go +++ b/internal/session/command.go @@ -21,7 +21,7 @@ type commandResult struct { func (s *Session) startCommandReader(ctx context.Context) <-chan commandResult { cmdCh := make(chan commandResult) - logging.GoAnnotated(ctx, func(ctx context.Context) { + logging.GoAnnotated(ctx, s.panicHandler, func(ctx context.Context) { defer close(cmdCh) tlsHeaders := [][]byte{ diff --git a/internal/session/handle.go b/internal/session/handle.go index 6935fc03..b115802a 100644 --- a/internal/session/handle.go +++ b/internal/session/handle.go @@ -10,6 +10,12 @@ import ( "github.com/ProtonMail/gluon/logging" ) +func (s *Session) handlePanic() { + if s.panicHandler != nil { + s.panicHandler.HandlePanic() + } +} + func (s *Session) handleOther( ctx context.Context, tag string, diff --git a/internal/session/handle_idle.go b/internal/session/handle_idle.go index f4ae0f4e..daaa6a01 100644 --- a/internal/session/handle_idle.go +++ b/internal/session/handle_idle.go @@ -22,7 +22,7 @@ func (s *Session) handleIdle(ctx context.Context, tag string, _ *command.Idle, c } return s.state.Idle(ctx, func(pending []response.Response, resCh chan response.Response) error { - logging.GoAnnotated(ctx, func(ctx context.Context) { + logging.GoAnnotated(ctx, s.panicHandler, func(ctx context.Context) { if s.idleBulkTime != 0 { sendResponsesInBulks(s, resCh, s.idleBulkTime) } else { diff --git a/internal/session/session.go b/internal/session/session.go index b8c1611e..512b267a 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -22,6 +22,7 @@ import ( "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/profiling" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/rfcparser" "github.com/ProtonMail/gluon/version" @@ -88,6 +89,8 @@ type Session struct { errorCount int imapLimits limits.IMAP + + panicHandler queue.PanicHandler } func New( @@ -98,11 +101,12 @@ func New( profiler profiling.CmdProfilerBuilder, eventCh chan<- events.Event, idleBulkTime time.Duration, + panicHandler queue.PanicHandler, ) *Session { inputCollector := command.NewInputCollector(bufio.NewReader(conn)) scanner := rfcparser.NewScannerWithReader(inputCollector) - return &Session{ + s := &Session{ conn: conn, inputCollector: inputCollector, scanner: scanner, @@ -113,7 +117,12 @@ func New( idleBulkTime: idleBulkTime, version: version, cmdProfilerBuilder: profiler, + panicHandler: panicHandler, } + + s.handleWG.SetPanicHandler(panicHandler) + + return s } func (s *Session) SetIncomingLogger(w io.Writer) { @@ -217,6 +226,8 @@ func (s *Session) serve(ctx context.Context) error { for res := range respCh { if err := res.Send(s); err != nil { go func() { + s.handlePanic() + for range respCh { // Consume all invalid input on error that is still being produced by the ongoing // command. diff --git a/internal/state/mailbox_fetch.go b/internal/state/mailbox_fetch.go index 04f05785..5548988d 100644 --- a/internal/state/mailbox_fetch.go +++ b/internal/state/mailbox_fetch.go @@ -114,6 +114,8 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons } if err := parallel.DoContext(ctx, parallelism, len(snapMessages), func(ctx context.Context, i int) error { + defer m.state.handlePanic() + msg := snapMessages[i] message, err := db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (*ent.Message, error) { return db.GetMessage(ctx, client, msg.ID.InternalID) diff --git a/internal/state/mailbox_search.go b/internal/state/mailbox_search.go index f0cd49e4..3b9847a8 100644 --- a/internal/state/mailbox_search.go +++ b/internal/state/mailbox_search.go @@ -57,6 +57,8 @@ func (m *Mailbox) Search(ctx context.Context, keys []command.SearchKey, decoder } if err := parallel.DoContext(ctx, parallelism, msgCount, func(ctx context.Context, i int) error { + defer m.state.handlePanic() + msg, ok := m.snap.messages.getWithSeqID(imap.SeqID(i + 1)) if !ok { return nil diff --git a/internal/state/state.go b/internal/state/state.go index 8eb68942..cf871994 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -49,6 +49,8 @@ type State struct { invalid bool imapLimits limits.IMAP + + panicHandler queue.PanicHandler } var stateIDGenerator int64 @@ -57,15 +59,22 @@ func nextStateID() StateID { return StateID(atomic.AddInt64(&stateIDGenerator, 1)) } -func NewState(user UserInterface, delimiter string, imapLimits limits.IMAP) *State { +func NewState(user UserInterface, delimiter string, imapLimits limits.IMAP, panicHandler queue.PanicHandler) *State { return &State{ user: user, StateID: nextStateID(), doneCh: make(chan struct{}), snap: nil, delimiter: delimiter, - updatesQueue: queue.NewQueuedChannel[Update](32, 128), + updatesQueue: queue.NewQueuedChannel[Update](32, 128, panicHandler), imapLimits: imapLimits, + panicHandler: panicHandler, + } +} + +func (state *State) handlePanic() { + if state.panicHandler != nil { + state.panicHandler.HandlePanic() } } diff --git a/logging/logging.go b/logging/logging.go index 4a3bd8a5..817582c5 100644 --- a/logging/logging.go +++ b/logging/logging.go @@ -20,9 +20,21 @@ const ( LineKey = "line" ) -func GoAnnotated(ctx context.Context, fn func(context.Context), labelMap ...Labels) { +type PanicHandler interface { + HandlePanic() +} + +func GoAnnotated(ctx context.Context, panicHandler PanicHandler, fn func(context.Context), labelMap ...Labels) { pprofDo(ctx, toLabelSet(labelMap...), func(ctx context.Context) { - go fn(ctx) + go func() { + defer func() { + if panicHandler != nil { + panicHandler.HandlePanic() + } + }() + + fn(ctx) + }() }) } diff --git a/option.go b/option.go index cded5967..787689b1 100644 --- a/option.go +++ b/option.go @@ -8,6 +8,7 @@ import ( "github.com/ProtonMail/gluon/imap" limits2 "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/profiling" + "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/store" "github.com/ProtonMail/gluon/version" @@ -192,6 +193,18 @@ func WithDisableParallelism() Option { return &withDisableParallelism{} } +type withPanicHandler struct { + panicHandler queue.PanicHandler +} + +func (opt *withPanicHandler) config(builder *serverBuilder) { + builder.panicHandler = opt.panicHandler +} + +func WithPanicHandler(panicHandler queue.PanicHandler) Option { + return &withPanicHandler{panicHandler} +} + type withIMAPLimits struct { limits limits2.IMAP } diff --git a/queue/queued_channel.go b/queue/queued_channel.go index 9b02e035..8623012a 100644 --- a/queue/queued_channel.go +++ b/queue/queued_channel.go @@ -18,7 +18,15 @@ type QueuedChannel[T any] struct { closed atomicBool // Should use atomic.Bool once we use Go 1.19! } -func NewQueuedChannel[T any](chanBufferSize, queueCapacity int) *QueuedChannel[T] { +type PanicHandler interface { + HandlePanic() +} + +type NoopPanicHandler struct{} + +func (n NoopPanicHandler) HandlePanic() {} + +func NewQueuedChannel[T any](chanBufferSize, queueCapacity int, panicHandler PanicHandler) *QueuedChannel[T] { queue := &QueuedChannel[T]{ ch: make(chan T, chanBufferSize), stopCh: make(chan struct{}), @@ -30,7 +38,7 @@ func NewQueuedChannel[T any](chanBufferSize, queueCapacity int) *QueuedChannel[T queue.closed.store(false) // Start the queue consumer. - logging.GoAnnotated(context.Background(), func(ctx context.Context) { + logging.GoAnnotated(context.Background(), panicHandler, func(ctx context.Context) { defer close(queue.ch) for { diff --git a/queue/queued_channel_test.go b/queue/queued_channel_test.go index 382a3f49..b6dfdd8a 100644 --- a/queue/queued_channel_test.go +++ b/queue/queued_channel_test.go @@ -11,7 +11,7 @@ func TestQueuedChannel(t *testing.T) { defer goleak.VerifyNone(t) // Create a new queued channel. - queue := NewQueuedChannel[int](3, 3) + queue := NewQueuedChannel[int](3, 3, nil) // Push some items to the queue. require.True(t, queue.Enqueue(1, 2, 3)) @@ -43,7 +43,7 @@ func TestQueuedChannelDoesNotLeakIfThereAreNoReadersOnCloseAndDiscard(t *testing defer goleak.VerifyNone(t) // Create a new queued channel. - queue := NewQueuedChannel[int](1, 3) + queue := NewQueuedChannel[int](1, 3, nil) // Push some items to the queue. require.True(t, queue.Enqueue(1, 2, 3)) diff --git a/server.go b/server.go index 5149bb72..afd984a0 100644 --- a/server.go +++ b/server.go @@ -87,6 +87,8 @@ type Server struct { disableParallelism bool uidValidityGenerator imap.UIDValidityGenerator + + panicHandler queue.PanicHandler } // New creates a new server with the given options. @@ -161,7 +163,7 @@ func (s *Server) AddWatcher(ofType ...events.Event) <-chan events.Event { s.watchersLock.Lock() defer s.watchersLock.Unlock() - watcher := watcher.New(ofType...) + watcher := watcher.New(s.panicHandler, ofType...) s.watchers = append(s.watchers, watcher) @@ -183,7 +185,7 @@ func (s *Server) Serve(ctx context.Context, l net.Listener) error { Addr: l.Addr(), }) - s.serve(ctx, newConnCh(l)) + s.serve(ctx, newConnCh(l, s.panicHandler)) }) return nil @@ -194,6 +196,8 @@ func (s *Server) serve(ctx context.Context, connCh <-chan net.Conn) { var connWG wait.Group defer connWG.Wait() + connWG.SetPanicHandler(s.panicHandler) + for { select { case <-ctx.Done(): @@ -288,7 +292,7 @@ func (s *Server) addSession(ctx context.Context, conn net.Conn) (*session.Sessio nextID := s.getNextID() - s.sessions[nextID] = session.New(conn, s.backend, nextID, s.versionInfo, s.cmdExecProfBuilder, s.newEventCh(ctx), s.idleBulkTime) + s.sessions[nextID] = session.New(conn, s.backend, nextID, s.versionInfo, s.cmdExecProfBuilder, s.newEventCh(ctx), s.idleBulkTime, s.panicHandler) if s.tlsConfig != nil { s.sessions[nextID].SetTLSConfig(s.tlsConfig) @@ -334,7 +338,7 @@ func (s *Server) getNextID() int { func (s *Server) newEventCh(ctx context.Context) chan events.Event { eventCh := make(chan events.Event) - logging.GoAnnotated(ctx, func(ctx context.Context) { + logging.GoAnnotated(ctx, s.panicHandler, func(ctx context.Context) { for event := range eventCh { s.publish(event) } @@ -360,10 +364,16 @@ func (s *Server) publish(event events.Event) { // newConnCh accepts connections from the given listener. // It returns a channel of all accepted connections which is closed when the listener is closed. -func newConnCh(l net.Listener) <-chan net.Conn { +func newConnCh(l net.Listener, panicHandler queue.PanicHandler) <-chan net.Conn { connCh := make(chan net.Conn) go func() { + defer func() { + if panicHandler != nil { + panicHandler.HandlePanic() + } + }() + defer close(connCh) for { diff --git a/store/semaphore.go b/store/semaphore.go index 2aebb431..30260d84 100644 --- a/store/semaphore.go +++ b/store/semaphore.go @@ -2,11 +2,17 @@ package store import "sync" +type PanicHandler interface { + HandlePanic() +} + // Semaphore implements a type used to limit concurrent operations. type Semaphore struct { ch chan struct{} wg sync.WaitGroup rw sync.RWMutex + + panicHandler PanicHandler } // NewSemaphore constructs a new semaphore with the given limit. @@ -48,8 +54,20 @@ func (sem *Semaphore) Do(fn func()) { fn() } +func (sem *Semaphore) SetPanicHandler(panicHandler PanicHandler) { + sem.panicHandler = panicHandler +} + +func (sem *Semaphore) handlePanic() { + if sem.panicHandler != nil { + sem.panicHandler.HandlePanic() + } +} + // Go executes the given function asynchronously. func (sem *Semaphore) Go(fn func()) { + defer sem.handlePanic() + sem.Lock() sem.wg.Add(1) diff --git a/tests/full_state_test.go b/tests/full_state_test.go index 0b116676..14622bcf 100644 --- a/tests/full_state_test.go +++ b/tests/full_state_test.go @@ -9,6 +9,7 @@ import ( "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/logging" + "github.com/ProtonMail/gluon/queue" goimap "github.com/emersion/go-imap" "github.com/emersion/go-imap/client" "github.com/stretchr/testify/require" @@ -111,7 +112,7 @@ func TestReceptionOnIdle(t *testing.T) { wg.Add(2) // idling. - logging.GoAnnotated(context.Background(), func(ctx context.Context) { + logging.GoAnnotated(context.Background(), queue.NoopPanicHandler{}, func(ctx context.Context) { defer wg.Done() done <- c.Idle(stop, nil) }, logging.Labels{ @@ -120,7 +121,7 @@ func TestReceptionOnIdle(t *testing.T) { }) // receiving messages from another client. - logging.GoAnnotated(context.Background(), func(ctx context.Context) { + logging.GoAnnotated(context.Background(), queue.NoopPanicHandler{}, func(ctx context.Context) { defer wg.Done() cli := sess.newClient() diff --git a/wait/wg.go b/wait/wg.go index f1ed50fe..fe7599cb 100644 --- a/wait/wg.go +++ b/wait/wg.go @@ -2,14 +2,31 @@ package wait import "sync" +type PanicHandler interface { + HandlePanic() +} + type Group struct { - wg sync.WaitGroup + wg sync.WaitGroup + panicHandler PanicHandler +} + +func (wg *Group) SetPanicHandler(panicHandler PanicHandler) { + wg.panicHandler = panicHandler +} + +func (wg *Group) handlePanic() { + if wg.panicHandler != nil { + wg.panicHandler.HandlePanic() + } } func (wg *Group) Go(f func()) { wg.wg.Add(1) go func() { + defer wg.handlePanic() + defer wg.wg.Done() f() }() diff --git a/watcher/watcher.go b/watcher/watcher.go index 0f684437..a70f7310 100644 --- a/watcher/watcher.go +++ b/watcher/watcher.go @@ -11,7 +11,7 @@ type Watcher[T any] struct { eventCh *queue.QueuedChannel[T] } -func New[T any](ofType ...T) *Watcher[T] { +func New[T any](panicHandler queue.PanicHandler, ofType ...T) *Watcher[T] { types := make(map[reflect.Type]struct{}, len(ofType)) for _, t := range ofType { @@ -20,7 +20,7 @@ func New[T any](ofType ...T) *Watcher[T] { return &Watcher[T]{ types: types, - eventCh: queue.NewQueuedChannel[T](1, 1), + eventCh: queue.NewQueuedChannel[T](1, 1, panicHandler), } } diff --git a/watcher/watcher_test.go b/watcher/watcher_test.go index d0570bd9..842b24ea 100644 --- a/watcher/watcher_test.go +++ b/watcher/watcher_test.go @@ -4,11 +4,13 @@ import ( "testing" "github.com/ProtonMail/gluon/events" + "github.com/ProtonMail/gluon/queue" "github.com/stretchr/testify/require" ) func TestWatcher(t *testing.T) { watcher := New[events.Event]( + queue.NoopPanicHandler{}, events.ListenerAdded{}, events.ListenerRemoved{}, ) From e23a7a1be2a88652418632173b401fa3d462cb41 Mon Sep 17 00:00:00 2001 From: Jakub Date: Thu, 30 Mar 2023 17:49:43 +0200 Subject: [PATCH 2/2] refactor(GODT-2500): Reorganise async methods. --- {queue => async}/bool.go | 2 +- async/context.go | 65 +++++++ async/group.go | 214 ++++++++++++++++++++++++ async/panic_handler.go | 15 ++ {queue => async}/queued_channel.go | 10 +- {queue => async}/queued_channel_test.go | 6 +- async/wait_group.go | 27 +++ builder.go | 11 +- internal/backend/backend.go | 6 +- internal/backend/update_injector.go | 4 +- internal/backend/user.go | 6 +- internal/session/handle.go | 6 - internal/session/session.go | 18 +- internal/state/mailbox_fetch.go | 3 +- internal/state/mailbox_search.go | 3 +- internal/state/state.go | 16 +- option.go | 6 +- server.go | 16 +- store/semaphore.go | 26 +-- store/semaphore_test.go | 3 +- store/store_test.go | 13 +- tests/full_state_test.go | 6 +- tests/login_test.go | 4 +- wait/wg.go | 37 ---- watcher/watcher.go | 8 +- watcher/watcher_test.go | 4 +- 26 files changed, 392 insertions(+), 143 deletions(-) rename {queue => async}/bool.go (97%) create mode 100644 async/context.go create mode 100644 async/group.go create mode 100644 async/panic_handler.go rename {queue => async}/queued_channel.go (94%) rename {queue => async}/queued_channel_test.go (89%) create mode 100644 async/wait_group.go delete mode 100644 wait/wg.go diff --git a/queue/bool.go b/async/bool.go similarity index 97% rename from queue/bool.go rename to async/bool.go index 1b7de9cf..daa8b068 100644 --- a/queue/bool.go +++ b/async/bool.go @@ -1,4 +1,4 @@ -package queue +package async import "sync/atomic" diff --git a/async/context.go b/async/context.go new file mode 100644 index 00000000..8f216c57 --- /dev/null +++ b/async/context.go @@ -0,0 +1,65 @@ +package async + +import ( + "context" + "sync" +) + +// Abortable collects groups of functions that can be aborted by calling Abort. +type Abortable struct { + abortFunc []context.CancelFunc + abortLock sync.RWMutex +} + +func (a *Abortable) Do(ctx context.Context, fn func(context.Context)) { + fn(a.newCancelCtx(ctx)) +} + +func (a *Abortable) Abort() { + a.abortLock.RLock() + defer a.abortLock.RUnlock() + + for _, fn := range a.abortFunc { + fn() + } +} + +func (a *Abortable) newCancelCtx(ctx context.Context) context.Context { + a.abortLock.Lock() + defer a.abortLock.Unlock() + + ctx, cancel := context.WithCancel(ctx) + + a.abortFunc = append(a.abortFunc, cancel) + + return ctx +} + +// RangeContext iterates over the given channel until the context is canceled or the +// channel is closed. +func RangeContext[T any](ctx context.Context, ch <-chan T, fn func(T)) { + for { + select { + case v, ok := <-ch: + if !ok { + return + } + + fn(v) + + case <-ctx.Done(): + return + } + } +} + +// ForwardContext forwards all values from the src channel to the dst channel until the +// context is canceled or the src channel is closed. +func ForwardContext[T any](ctx context.Context, dst chan<- T, src <-chan T) { + RangeContext(ctx, src, func(v T) { + select { + case dst <- v: + case <-ctx.Done(): + } + }) +} diff --git a/async/group.go b/async/group.go new file mode 100644 index 00000000..7e045d95 --- /dev/null +++ b/async/group.go @@ -0,0 +1,214 @@ +package async + +import ( + "context" + "math/rand" + "sync" + "time" +) + +// Group is forked and improved version of "github.com/bradenaw/juniper/xsync.Group". +// +// It manages a group of goroutines. The main change to original is posibility +// to wait passed function to finish without canceling it's context and adding +// PanicHandler. +type Group struct { + baseCtx context.Context + ctx context.Context + jobCtx context.Context + cancel context.CancelFunc + finish context.CancelFunc + wg sync.WaitGroup + + panicHandler PanicHandler +} + +// NewGroup returns a Group ready for use. The context passed to any of the f functions will be a +// descendant of ctx. +func NewGroup(ctx context.Context, panicHandler PanicHandler) *Group { + bgCtx, cancel := context.WithCancel(ctx) + jobCtx, finish := context.WithCancel(ctx) + + return &Group{ + baseCtx: ctx, + ctx: bgCtx, + jobCtx: jobCtx, + cancel: cancel, + finish: finish, + panicHandler: panicHandler, + } +} + +// Once calls f once from another goroutine. +func (g *Group) Once(f func(ctx context.Context)) { + g.wg.Add(1) + + go func() { + defer HandlePanic(g.panicHandler) + + f(g.ctx) + g.wg.Done() + }() +} + +// jitterDuration returns a random duration in [d - jitter, d + jitter]. +func jitterDuration(d time.Duration, jitter time.Duration) time.Duration { + return d + time.Duration(float64(jitter)*((rand.Float64()*2)-1)) //nolint:gosec +} + +// Periodic spawns a goroutine that calls f once per interval +/- jitter. +func (g *Group) Periodic( + interval time.Duration, + jitter time.Duration, + f func(ctx context.Context), +) { + g.wg.Add(1) + + go func() { + defer HandlePanic(g.panicHandler) + + defer g.wg.Done() + + t := time.NewTimer(jitterDuration(interval, jitter)) + defer t.Stop() + + for { + if g.ctx.Err() != nil { + return + } + + select { + case <-g.jobCtx.Done(): + return + case <-t.C: + } + + t.Reset(jitterDuration(interval, jitter)) + f(g.ctx) + } + }() +} + +// Trigger spawns a goroutine which calls f whenever the returned function is called. If f is +// already running when triggered, f will run again immediately when it finishes. +func (g *Group) Trigger(f func(ctx context.Context)) func() { + c := make(chan struct{}, 1) + + g.wg.Add(1) + + go func() { + defer HandlePanic(g.panicHandler) + + defer g.wg.Done() + + for { + if g.ctx.Err() != nil { + return + } + select { + case <-g.jobCtx.Done(): + return + case <-c: + } + f(g.ctx) + } + }() + + return func() { + select { + case c <- struct{}{}: + default: + } + } +} + +// PeriodicOrTrigger spawns a goroutine which calls f whenever the returned function is called. If +// f is already running when triggered, f will run again immediately when it finishes. Also calls f +// when it has been interval+/-jitter since the last trigger. +func (g *Group) PeriodicOrTrigger( + interval time.Duration, + jitter time.Duration, + f func(ctx context.Context), +) func() { + c := make(chan struct{}, 1) + + g.wg.Add(1) + + go func() { + defer HandlePanic(g.panicHandler) + + defer g.wg.Done() + + t := time.NewTimer(jitterDuration(interval, jitter)) + defer t.Stop() + + for { + if g.ctx.Err() != nil { + return + } + select { + case <-g.jobCtx.Done(): + return + case <-t.C: + t.Reset(jitterDuration(interval, jitter)) + case <-c: + if !t.Stop() { + <-t.C + } + + t.Reset(jitterDuration(interval, jitter)) + } + f(g.ctx) + } + }() + + return func() { + select { + case c <- struct{}{}: + default: + } + } +} + +func (g *Group) resetCtx() { + g.jobCtx, g.finish = context.WithCancel(g.baseCtx) + g.ctx, g.cancel = context.WithCancel(g.baseCtx) +} + +// Cancel is send to all of the spawn goroutines and ends periodic +// or trigger routines. +func (g *Group) Cancel() { + g.cancel() + g.finish() + g.resetCtx() +} + +// Finish will ends all periodic or polls routines. It will let +// currently running functions to finish (cancel is not sent). +// +// It is not safe to call Wait concurrently with any other method on g. +func (g *Group) Finish() { + g.finish() + g.jobCtx, g.finish = context.WithCancel(g.baseCtx) +} + +// CancelAndWait cancels the context passed to any of the spawned goroutines and waits for all spawned +// goroutines to exit. +// +// It is not safe to call Wait concurrently with any other method on g. +func (g *Group) CancelAndWait() { + g.finish() + g.cancel() + g.wg.Wait() + g.resetCtx() +} + +// WaitToFinish will ends all periodic or polls routines. It will wait for +// currently running functions to finish (cancel is not sent). +// +// It is not safe to call Wait concurrently with any other method on g. +func (g *Group) WaitToFinish() { + g.finish() + g.wg.Wait() + g.jobCtx, g.finish = context.WithCancel(g.baseCtx) +} diff --git a/async/panic_handler.go b/async/panic_handler.go new file mode 100644 index 00000000..144a9dab --- /dev/null +++ b/async/panic_handler.go @@ -0,0 +1,15 @@ +package async + +type PanicHandler interface { + HandlePanic() +} + +type NoopPanicHandler struct{} + +func (n NoopPanicHandler) HandlePanic() {} + +func HandlePanic(panicHandler PanicHandler) { + if panicHandler != nil { + panicHandler.HandlePanic() + } +} diff --git a/queue/queued_channel.go b/async/queued_channel.go similarity index 94% rename from queue/queued_channel.go rename to async/queued_channel.go index 8623012a..1b87d561 100644 --- a/queue/queued_channel.go +++ b/async/queued_channel.go @@ -1,4 +1,4 @@ -package queue +package async import ( "context" @@ -18,14 +18,6 @@ type QueuedChannel[T any] struct { closed atomicBool // Should use atomic.Bool once we use Go 1.19! } -type PanicHandler interface { - HandlePanic() -} - -type NoopPanicHandler struct{} - -func (n NoopPanicHandler) HandlePanic() {} - func NewQueuedChannel[T any](chanBufferSize, queueCapacity int, panicHandler PanicHandler) *QueuedChannel[T] { queue := &QueuedChannel[T]{ ch: make(chan T, chanBufferSize), diff --git a/queue/queued_channel_test.go b/async/queued_channel_test.go similarity index 89% rename from queue/queued_channel_test.go rename to async/queued_channel_test.go index b6dfdd8a..e4fb99f6 100644 --- a/queue/queued_channel_test.go +++ b/async/queued_channel_test.go @@ -1,4 +1,4 @@ -package queue +package async import ( "testing" @@ -11,7 +11,7 @@ func TestQueuedChannel(t *testing.T) { defer goleak.VerifyNone(t) // Create a new queued channel. - queue := NewQueuedChannel[int](3, 3, nil) + queue := NewQueuedChannel[int](3, 3, NoopPanicHandler{}) // Push some items to the queue. require.True(t, queue.Enqueue(1, 2, 3)) @@ -43,7 +43,7 @@ func TestQueuedChannelDoesNotLeakIfThereAreNoReadersOnCloseAndDiscard(t *testing defer goleak.VerifyNone(t) // Create a new queued channel. - queue := NewQueuedChannel[int](1, 3, nil) + queue := NewQueuedChannel[int](1, 3, NoopPanicHandler{}) // Push some items to the queue. require.True(t, queue.Enqueue(1, 2, 3)) diff --git a/async/wait_group.go b/async/wait_group.go new file mode 100644 index 00000000..df5d722c --- /dev/null +++ b/async/wait_group.go @@ -0,0 +1,27 @@ +package async + +import "sync" + +type WaitGroup struct { + wg sync.WaitGroup + panicHandler PanicHandler +} + +func MakeWaitGroup(panicHandler PanicHandler) WaitGroup { + return WaitGroup{panicHandler: panicHandler} +} + +func (wg *WaitGroup) Go(f func()) { + wg.wg.Add(1) + + go func() { + defer HandlePanic(wg.panicHandler) + + defer wg.wg.Done() + f() + }() +} + +func (wg *WaitGroup) Wait() { + wg.wg.Wait() +} diff --git a/builder.go b/builder.go index 4a6296a2..864388fd 100644 --- a/builder.go +++ b/builder.go @@ -6,13 +6,13 @@ import ( "os" "time" + "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/backend" "github.com/ProtonMail/gluon/internal/db" "github.com/ProtonMail/gluon/internal/session" "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/profiling" - "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/store" "github.com/ProtonMail/gluon/version" @@ -35,7 +35,7 @@ type serverBuilder struct { disableParallelism bool imapLimits limits.IMAP uidValidityGenerator imap.UIDValidityGenerator - panicHandler queue.PanicHandler + panicHandler async.PanicHandler } func newBuilder() (*serverBuilder, error) { @@ -47,7 +47,7 @@ func newBuilder() (*serverBuilder, error) { idleBulkTime: 500 * time.Millisecond, imapLimits: limits.DefaultLimits(), uidValidityGenerator: imap.DefaultEpochUIDValidityGenerator(), - panicHandler: queue.NoopPanicHandler{}, + panicHandler: async.NoopPanicHandler{}, }, nil } @@ -102,8 +102,9 @@ func (builder *serverBuilder) build() (*Server, error) { databaseDir: builder.databaseDir, backend: backend, sessions: make(map[int]*session.Session), - serveErrCh: queue.NewQueuedChannel[error](1, 1, builder.panicHandler), + serveErrCh: async.NewQueuedChannel[error](1, 1, builder.panicHandler), serveDoneCh: make(chan struct{}), + serveWG: async.MakeWaitGroup(builder.panicHandler), inLogger: builder.inLogger, outLogger: builder.outLogger, tlsConfig: builder.tlsConfig, @@ -117,7 +118,5 @@ func (builder *serverBuilder) build() (*Server, error) { panicHandler: builder.panicHandler, } - s.serveWG.SetPanicHandler(builder.panicHandler) - return s, nil } diff --git a/internal/backend/backend.go b/internal/backend/backend.go index 6a8bd4fa..5666919b 100644 --- a/internal/backend/backend.go +++ b/internal/backend/backend.go @@ -7,6 +7,7 @@ import ( "sync/atomic" "time" + "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/db" @@ -14,7 +15,6 @@ import ( "github.com/ProtonMail/gluon/internal/db/ent/mailbox" "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/limits" - "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/store" "github.com/google/uuid" @@ -51,10 +51,10 @@ type Backend struct { imapLimits limits.IMAP - panicHandler queue.PanicHandler + panicHandler async.PanicHandler } -func New(dataDir, databaseDir string, storeBuilder store.Builder, delim string, loginJailTime time.Duration, imapLimits limits.IMAP, panicHandler queue.PanicHandler) (*Backend, error) { +func New(dataDir, databaseDir string, storeBuilder store.Builder, delim string, loginJailTime time.Duration, imapLimits limits.IMAP, panicHandler async.PanicHandler) (*Backend, error) { return &Backend{ dataDir: dataDir, databaseDir: databaseDir, diff --git a/internal/backend/update_injector.go b/internal/backend/update_injector.go index 5e6851be..ddc65dcf 100644 --- a/internal/backend/update_injector.go +++ b/internal/backend/update_injector.go @@ -4,10 +4,10 @@ import ( "context" "sync" + "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/logging" - "github.com/ProtonMail/gluon/queue" ) // updateInjector allows anyone to publish custom imap updates alongside the updates that are generated from the @@ -24,7 +24,7 @@ type updateInjector struct { // newUpdateInjector creates a new updateInjector. // // nolint:contextcheck -func newUpdateInjector(connector connector.Connector, userID string, panicHandler queue.PanicHandler) *updateInjector { +func newUpdateInjector(connector connector.Connector, userID string, panicHandler async.PanicHandler) *updateInjector { injector := &updateInjector{ updatesCh: make(chan imap.Update), forwardQuitCh: make(chan struct{}), diff --git a/internal/backend/user.go b/internal/backend/user.go index 0bd3fa2a..e97e7925 100644 --- a/internal/backend/user.go +++ b/internal/backend/user.go @@ -5,6 +5,7 @@ import ( "fmt" "sync" + "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/db" @@ -13,7 +14,6 @@ import ( "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/logging" - "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/store" "github.com/bradenaw/juniper/xslices" @@ -44,7 +44,7 @@ type user struct { uidValidityGenerator imap.UIDValidityGenerator - panicHandler queue.PanicHandler + panicHandler async.PanicHandler } func newUser( @@ -56,7 +56,7 @@ func newUser( delimiter string, imapLimits limits.IMAP, uidValidityGenerator imap.UIDValidityGenerator, - panicHandler queue.PanicHandler, + panicHandler async.PanicHandler, ) (*user, error) { if err := database.Init(ctx); err != nil { return nil, err diff --git a/internal/session/handle.go b/internal/session/handle.go index b115802a..6935fc03 100644 --- a/internal/session/handle.go +++ b/internal/session/handle.go @@ -10,12 +10,6 @@ import ( "github.com/ProtonMail/gluon/logging" ) -func (s *Session) handlePanic() { - if s.panicHandler != nil { - s.panicHandler.HandlePanic() - } -} - func (s *Session) handleOther( ctx context.Context, tag string, diff --git a/internal/session/session.go b/internal/session/session.go index 512b267a..d29297b1 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -14,6 +14,7 @@ import ( "sync" "time" + "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/events" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" @@ -22,11 +23,9 @@ import ( "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/profiling" - "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/rfcparser" "github.com/ProtonMail/gluon/version" - "github.com/ProtonMail/gluon/wait" "github.com/emersion/go-imap/utf7" "github.com/sirupsen/logrus" "golang.org/x/exp/slices" @@ -83,14 +82,14 @@ type Session struct { cmdProfilerBuilder profiling.CmdProfilerBuilder // handleWG is used to wait for all commands to finish before closing the session. - handleWG wait.Group + handleWG async.WaitGroup /// errorCount error counter errorCount int imapLimits limits.IMAP - panicHandler queue.PanicHandler + panicHandler async.PanicHandler } func New( @@ -101,12 +100,12 @@ func New( profiler profiling.CmdProfilerBuilder, eventCh chan<- events.Event, idleBulkTime time.Duration, - panicHandler queue.PanicHandler, + panicHandler async.PanicHandler, ) *Session { inputCollector := command.NewInputCollector(bufio.NewReader(conn)) scanner := rfcparser.NewScannerWithReader(inputCollector) - s := &Session{ + return &Session{ conn: conn, inputCollector: inputCollector, scanner: scanner, @@ -117,12 +116,9 @@ func New( idleBulkTime: idleBulkTime, version: version, cmdProfilerBuilder: profiler, + handleWG: async.MakeWaitGroup(panicHandler), panicHandler: panicHandler, } - - s.handleWG.SetPanicHandler(panicHandler) - - return s } func (s *Session) SetIncomingLogger(w io.Writer) { @@ -226,7 +222,7 @@ func (s *Session) serve(ctx context.Context) error { for res := range respCh { if err := res.Send(s); err != nil { go func() { - s.handlePanic() + defer async.HandlePanic(s.panicHandler) for range respCh { // Consume all invalid input on error that is still being produced by the ongoing diff --git a/internal/state/mailbox_fetch.go b/internal/state/mailbox_fetch.go index 5548988d..0fa379e0 100644 --- a/internal/state/mailbox_fetch.go +++ b/internal/state/mailbox_fetch.go @@ -8,6 +8,7 @@ import ( "strings" "sync/atomic" + "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" "github.com/ProtonMail/gluon/internal/contexts" @@ -114,7 +115,7 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons } if err := parallel.DoContext(ctx, parallelism, len(snapMessages), func(ctx context.Context, i int) error { - defer m.state.handlePanic() + defer async.HandlePanic(m.state.panicHandler) msg := snapMessages[i] message, err := db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (*ent.Message, error) { diff --git a/internal/state/mailbox_search.go b/internal/state/mailbox_search.go index 3b9847a8..6a4cce41 100644 --- a/internal/state/mailbox_search.go +++ b/internal/state/mailbox_search.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "time" + "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" "github.com/ProtonMail/gluon/internal/contexts" @@ -57,7 +58,7 @@ func (m *Mailbox) Search(ctx context.Context, keys []command.SearchKey, decoder } if err := parallel.DoContext(ctx, parallelism, msgCount, func(ctx context.Context, i int) error { - defer m.state.handlePanic() + defer async.HandlePanic(m.state.panicHandler) msg, ok := m.snap.messages.getWithSeqID(imap.SeqID(i + 1)) if !ok { diff --git a/internal/state/state.go b/internal/state/state.go index cf871994..d113fde0 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -7,13 +7,13 @@ import ( "strings" "sync/atomic" + "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/db" "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/internal/response" "github.com/ProtonMail/gluon/limits" - "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/rfc822" "github.com/bradenaw/juniper/sets" @@ -41,7 +41,7 @@ type State struct { doneCh chan struct{} - updatesQueue *queue.QueuedChannel[Update] + updatesQueue *async.QueuedChannel[Update] delimiter string @@ -50,7 +50,7 @@ type State struct { imapLimits limits.IMAP - panicHandler queue.PanicHandler + panicHandler async.PanicHandler } var stateIDGenerator int64 @@ -59,25 +59,19 @@ func nextStateID() StateID { return StateID(atomic.AddInt64(&stateIDGenerator, 1)) } -func NewState(user UserInterface, delimiter string, imapLimits limits.IMAP, panicHandler queue.PanicHandler) *State { +func NewState(user UserInterface, delimiter string, imapLimits limits.IMAP, panicHandler async.PanicHandler) *State { return &State{ user: user, StateID: nextStateID(), doneCh: make(chan struct{}), snap: nil, delimiter: delimiter, - updatesQueue: queue.NewQueuedChannel[Update](32, 128, panicHandler), + updatesQueue: async.NewQueuedChannel[Update](32, 128, panicHandler), imapLimits: imapLimits, panicHandler: panicHandler, } } -func (state *State) handlePanic() { - if state.panicHandler != nil { - state.panicHandler.HandlePanic() - } -} - func (state *State) UserID() string { return state.user.GetUserID() } diff --git a/option.go b/option.go index 787689b1..cd160bc9 100644 --- a/option.go +++ b/option.go @@ -5,10 +5,10 @@ import ( "io" "time" + "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/imap" limits2 "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/profiling" - "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/store" "github.com/ProtonMail/gluon/version" @@ -194,14 +194,14 @@ func WithDisableParallelism() Option { } type withPanicHandler struct { - panicHandler queue.PanicHandler + panicHandler async.PanicHandler } func (opt *withPanicHandler) config(builder *serverBuilder) { builder.panicHandler = opt.panicHandler } -func WithPanicHandler(panicHandler queue.PanicHandler) Option { +func WithPanicHandler(panicHandler async.PanicHandler) Option { return &withPanicHandler{panicHandler} } diff --git a/server.go b/server.go index afd984a0..4d68a9fd 100644 --- a/server.go +++ b/server.go @@ -11,6 +11,7 @@ import ( "sync" "time" + "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/events" "github.com/ProtonMail/gluon/imap" @@ -19,11 +20,9 @@ import ( "github.com/ProtonMail/gluon/internal/session" "github.com/ProtonMail/gluon/logging" "github.com/ProtonMail/gluon/profiling" - "github.com/ProtonMail/gluon/queue" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/store" "github.com/ProtonMail/gluon/version" - "github.com/ProtonMail/gluon/wait" "github.com/ProtonMail/gluon/watcher" _ "github.com/mattn/go-sqlite3" "github.com/sirupsen/logrus" @@ -45,13 +44,13 @@ type Server struct { sessionsLock sync.RWMutex // serveErrCh collects errors encountered while serving. - serveErrCh *queue.QueuedChannel[error] + serveErrCh *async.QueuedChannel[error] // serveDoneCh is used to stop the server. serveDoneCh chan struct{} // serveWG keeps track of serving goroutines. - serveWG wait.Group + serveWG async.WaitGroup // nextID holds the ID that will be given to the next session. nextID int @@ -88,7 +87,7 @@ type Server struct { uidValidityGenerator imap.UIDValidityGenerator - panicHandler queue.PanicHandler + panicHandler async.PanicHandler } // New creates a new server with the given options. @@ -193,10 +192,7 @@ func (s *Server) Serve(ctx context.Context, l net.Listener) error { // serve handles incoming connections and starts a new goroutine for each. func (s *Server) serve(ctx context.Context, connCh <-chan net.Conn) { - var connWG wait.Group - defer connWG.Wait() - - connWG.SetPanicHandler(s.panicHandler) + connWG := async.MakeWaitGroup(s.panicHandler) for { select { @@ -364,7 +360,7 @@ func (s *Server) publish(event events.Event) { // newConnCh accepts connections from the given listener. // It returns a channel of all accepted connections which is closed when the listener is closed. -func newConnCh(l net.Listener, panicHandler queue.PanicHandler) <-chan net.Conn { +func newConnCh(l net.Listener, panicHandler async.PanicHandler) <-chan net.Conn { connCh := make(chan net.Conn) go func() { diff --git a/store/semaphore.go b/store/semaphore.go index 30260d84..d8992e81 100644 --- a/store/semaphore.go +++ b/store/semaphore.go @@ -1,10 +1,10 @@ package store -import "sync" +import ( + "sync" -type PanicHandler interface { - HandlePanic() -} + "github.com/ProtonMail/gluon/async" +) // Semaphore implements a type used to limit concurrent operations. type Semaphore struct { @@ -12,12 +12,12 @@ type Semaphore struct { wg sync.WaitGroup rw sync.RWMutex - panicHandler PanicHandler + panicHandler async.PanicHandler } // NewSemaphore constructs a new semaphore with the given limit. -func NewSemaphore(max int) *Semaphore { - return &Semaphore{ch: make(chan struct{}, max)} +func NewSemaphore(max int, panicHandler async.PanicHandler) *Semaphore { + return &Semaphore{ch: make(chan struct{}, max), panicHandler: panicHandler} } // Lock locks the semaphore, waiting first until it is possible. @@ -54,19 +54,9 @@ func (sem *Semaphore) Do(fn func()) { fn() } -func (sem *Semaphore) SetPanicHandler(panicHandler PanicHandler) { - sem.panicHandler = panicHandler -} - -func (sem *Semaphore) handlePanic() { - if sem.panicHandler != nil { - sem.panicHandler.HandlePanic() - } -} - // Go executes the given function asynchronously. func (sem *Semaphore) Go(fn func()) { - defer sem.handlePanic() + defer async.HandlePanic(sem.panicHandler) sem.Lock() sem.wg.Add(1) diff --git a/store/semaphore_test.go b/store/semaphore_test.go index 1726b35e..4c04c9ad 100644 --- a/store/semaphore_test.go +++ b/store/semaphore_test.go @@ -4,11 +4,12 @@ import ( "sync" "testing" + "github.com/ProtonMail/gluon/async" "github.com/stretchr/testify/assert" ) func TestSemaphore(t *testing.T) { - sem := NewSemaphore(4) + sem := NewSemaphore(4, async.NoopPanicHandler{}) // Block the semaphore so that tasks wait until we unblock it. sem.Block() diff --git a/store/store_test.go b/store/store_test.go index a0d51294..0ddc6413 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -8,6 +8,7 @@ import ( "runtime" "testing" + "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/store" "github.com/ProtonMail/gluon/store/fallback_v0" @@ -18,7 +19,7 @@ func TestStore_DecryptFailedOnFilesBiggerThanBlockSize(t *testing.T) { store, err := store.NewOnDiskStore( t.TempDir(), []byte("pass"), - store.WithSemaphore(store.NewSemaphore(runtime.NumCPU())), + store.WithSemaphore(store.NewSemaphore(runtime.NumCPU(), async.NoopPanicHandler{})), ) require.NoError(t, err) @@ -40,7 +41,7 @@ func BenchmarkStoreRead(t *testing.B) { store, err := store.NewOnDiskStore( t.TempDir(), []byte("pass"), - store.WithSemaphore(store.NewSemaphore(runtime.NumCPU())), + store.WithSemaphore(store.NewSemaphore(runtime.NumCPU(), async.NoopPanicHandler{})), ) require.NoError(t, err) @@ -66,7 +67,7 @@ func TestStoreReadFailsIfHeaderDoesNotMatch(t *testing.T) { store, err := store.NewOnDiskStore( storeDir, []byte("pass"), - store.WithSemaphore(store.NewSemaphore(runtime.NumCPU())), + store.WithSemaphore(store.NewSemaphore(runtime.NumCPU(), async.NoopPanicHandler{})), ) require.NoError(t, err) @@ -109,7 +110,7 @@ func TestStoreFallbackRead(t *testing.T) { store, err := store.NewOnDiskStore( storeDir, []byte("pass"), - store.WithSemaphore(store.NewSemaphore(runtime.NumCPU())), + store.WithSemaphore(store.NewSemaphore(runtime.NumCPU(), async.NoopPanicHandler{})), ) require.NoError(t, err) defer func() { @@ -125,7 +126,7 @@ func TestStoreFallbackRead(t *testing.T) { store, err := store.NewOnDiskStore( storeDir, []byte("pass"), - store.WithSemaphore(store.NewSemaphore(runtime.NumCPU())), + store.WithSemaphore(store.NewSemaphore(runtime.NumCPU(), async.NoopPanicHandler{})), store.WithFallback(fallbackStore), ) require.NoError(t, err) @@ -143,7 +144,7 @@ func TestOnDiskStore(t *testing.T) { store, err := store.NewOnDiskStore( t.TempDir(), []byte("pass"), - store.WithSemaphore(store.NewSemaphore(runtime.NumCPU())), + store.WithSemaphore(store.NewSemaphore(runtime.NumCPU(), async.NoopPanicHandler{})), ) require.NoError(t, err) diff --git a/tests/full_state_test.go b/tests/full_state_test.go index 14622bcf..c52ef981 100644 --- a/tests/full_state_test.go +++ b/tests/full_state_test.go @@ -7,9 +7,9 @@ import ( "testing" "time" + "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/logging" - "github.com/ProtonMail/gluon/queue" goimap "github.com/emersion/go-imap" "github.com/emersion/go-imap/client" "github.com/stretchr/testify/require" @@ -112,7 +112,7 @@ func TestReceptionOnIdle(t *testing.T) { wg.Add(2) // idling. - logging.GoAnnotated(context.Background(), queue.NoopPanicHandler{}, func(ctx context.Context) { + logging.GoAnnotated(context.Background(), async.NoopPanicHandler{}, func(ctx context.Context) { defer wg.Done() done <- c.Idle(stop, nil) }, logging.Labels{ @@ -121,7 +121,7 @@ func TestReceptionOnIdle(t *testing.T) { }) // receiving messages from another client. - logging.GoAnnotated(context.Background(), queue.NoopPanicHandler{}, func(ctx context.Context) { + logging.GoAnnotated(context.Background(), async.NoopPanicHandler{}, func(ctx context.Context) { defer wg.Done() cli := sess.newClient() diff --git a/tests/login_test.go b/tests/login_test.go index 7aa8cb82..a1a2b7c6 100644 --- a/tests/login_test.go +++ b/tests/login_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" + "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/events" - "github.com/ProtonMail/gluon/wait" "github.com/stretchr/testify/require" ) @@ -120,7 +120,7 @@ func TestLoginTooManyAttemptsMany(t *testing.T) { c[2].C("A002 login user badpass").NO("A002") c[3].C("A003 login user badpass").NO("A003") - var wg wait.Group + wg := async.MakeWaitGroup(async.NoopPanicHandler{}) // All clients should be jailed for 1 sec. for _, i := range []int{1, 2, 3} { diff --git a/wait/wg.go b/wait/wg.go deleted file mode 100644 index fe7599cb..00000000 --- a/wait/wg.go +++ /dev/null @@ -1,37 +0,0 @@ -package wait - -import "sync" - -type PanicHandler interface { - HandlePanic() -} - -type Group struct { - wg sync.WaitGroup - panicHandler PanicHandler -} - -func (wg *Group) SetPanicHandler(panicHandler PanicHandler) { - wg.panicHandler = panicHandler -} - -func (wg *Group) handlePanic() { - if wg.panicHandler != nil { - wg.panicHandler.HandlePanic() - } -} - -func (wg *Group) Go(f func()) { - wg.wg.Add(1) - - go func() { - defer wg.handlePanic() - - defer wg.wg.Done() - f() - }() -} - -func (wg *Group) Wait() { - wg.wg.Wait() -} diff --git a/watcher/watcher.go b/watcher/watcher.go index a70f7310..2070335c 100644 --- a/watcher/watcher.go +++ b/watcher/watcher.go @@ -3,15 +3,15 @@ package watcher import ( "reflect" - "github.com/ProtonMail/gluon/queue" + "github.com/ProtonMail/gluon/async" ) type Watcher[T any] struct { types map[reflect.Type]struct{} - eventCh *queue.QueuedChannel[T] + eventCh *async.QueuedChannel[T] } -func New[T any](panicHandler queue.PanicHandler, ofType ...T) *Watcher[T] { +func New[T any](panicHandler async.PanicHandler, ofType ...T) *Watcher[T] { types := make(map[reflect.Type]struct{}, len(ofType)) for _, t := range ofType { @@ -20,7 +20,7 @@ func New[T any](panicHandler queue.PanicHandler, ofType ...T) *Watcher[T] { return &Watcher[T]{ types: types, - eventCh: queue.NewQueuedChannel[T](1, 1, panicHandler), + eventCh: async.NewQueuedChannel[T](1, 1, panicHandler), } } diff --git a/watcher/watcher_test.go b/watcher/watcher_test.go index 842b24ea..734a920c 100644 --- a/watcher/watcher_test.go +++ b/watcher/watcher_test.go @@ -3,14 +3,14 @@ package watcher import ( "testing" + "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/events" - "github.com/ProtonMail/gluon/queue" "github.com/stretchr/testify/require" ) func TestWatcher(t *testing.T) { watcher := New[events.Event]( - queue.NoopPanicHandler{}, + async.NoopPanicHandler{}, events.ListenerAdded{}, events.ListenerRemoved{}, )