Skip to content

Commit

Permalink
fix(GODT-1585): Fix race conditions in snapshot access
Browse files Browse the repository at this point in the history
Ensure snapshot access is protected by RWLock as there are cases where
we are closing the snapshot while it is being read when processing
Connector updates.

All access to the snapshot should go through the `snapshotWrapper` type.
Due to limitations of the Go generic system, standalone functions
`snapshotRead` and `snapshotWrite` need to be used.
  • Loading branch information
LBeernaertProton authored and jameshoulahan committed Aug 26, 2022
1 parent e5c7257 commit 5f11d7e
Show file tree
Hide file tree
Showing 9 changed files with 258 additions and 107 deletions.
71 changes: 52 additions & 19 deletions internal/backend/mailbox.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,25 @@ type Mailbox struct {
mbox *ent.Mailbox

state *State
snap *snapshot
snap *snapshotWrapper

selected bool
readOnly bool
}

func newMailbox(mbox *ent.Mailbox, state *State, snap *snapshot) *Mailbox {
func newMailbox(mbox *ent.Mailbox, state *State, wrapper *snapshotWrapper) *Mailbox {
selected := snapshotRead(wrapper, func(s *snapshot) bool {
return s != nil
})

return &Mailbox{
mbox: mbox,

state: state,
snap: snap,

selected: state.snap != nil,
selected: selected,
readOnly: state.ro,
snap: wrapper,
}
}

Expand Down Expand Up @@ -62,7 +66,9 @@ func (m *Mailbox) ExpungeIssued() bool {
}

func (m *Mailbox) Count() int {
return len(m.snap.getAllMessages())
return snapshotRead(m.snap, func(s *snapshot) int {
return len(s.getAllMessages())
})
}

func (m *Mailbox) Flags(ctx context.Context) (imap.FlagSet, error) {
Expand Down Expand Up @@ -111,14 +117,18 @@ func (m *Mailbox) Subscribed() bool {
}

func (m *Mailbox) GetMessagesWithFlag(flag string) []int {
return xslices.Map(m.snap.getMessagesWithFlag(flag), func(msg *snapMsg) int {
return msg.Seq
return snapshotRead(m.snap, func(s *snapshot) []int {
return xslices.Map(s.getMessagesWithFlag(flag), func(msg *snapMsg) int {
return msg.Seq
})
})
}

func (m *Mailbox) GetMessagesWithoutFlag(flag string) []int {
return xslices.Map(m.snap.getMessagesWithoutFlag(flag), func(msg *snapMsg) int {
return msg.Seq
return snapshotRead(m.snap, func(s *snapshot) []int {
return xslices.Map(s.getMessagesWithoutFlag(flag), func(msg *snapMsg) int {
return msg.Seq
})
})
}

Expand All @@ -144,8 +154,12 @@ func (m *Mailbox) Append(ctx context.Context, literal []byte, flags imap.FlagSet
}
}

snapMBoxID := snapshotRead(m.snap, func(s *snapshot) MailboxIDPair {
return s.mboxID
})

return DBWriteResult(ctx, m.state.db, func(ctx context.Context, tx *ent.Tx) (int, error) {
return m.state.actionCreateMessage(ctx, tx, m.snap.mboxID, literal, flags, date)
return m.state.actionCreateMessage(ctx, tx, snapMBoxID, literal, flags, date)
})
}

Expand All @@ -160,7 +174,9 @@ func (m *Mailbox) Copy(ctx context.Context, seq *proto.SequenceSet, name string)
return nil, ErrNoSuchMailbox
}

messages, err := m.snap.getMessagesInRange(ctx, seq)
messages, err := snapshotReadErr(m.snap, func(s *snapshot) ([]*snapMsg, error) {
return s.getMessagesInRange(ctx, seq)
})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -198,11 +214,17 @@ func (m *Mailbox) Move(ctx context.Context, seq *proto.SequenceSet, name string)
mbox, err := DBReadResult(ctx, m.state.db, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) {
return DBGetMailboxByName(ctx, client, name)
})

if err != nil {
return nil, ErrNoSuchMailbox
}

messages, err := m.snap.getMessagesInRange(ctx, seq)
var snapMBoxID MailboxIDPair

messages, err := snapshotReadErr(m.snap, func(s *snapshot) ([]*snapMsg, error) {
snapMBoxID = s.mboxID
return s.getMessagesInRange(ctx, seq)
})
if err != nil {
return nil, err
}
Expand All @@ -216,7 +238,7 @@ func (m *Mailbox) Move(ctx context.Context, seq *proto.SequenceSet, name string)
})

destUIDs, err := DBWriteResult(ctx, m.state.db, func(ctx context.Context, tx *ent.Tx) (map[imap.InternalMessageID]int, error) {
return m.state.actionMoveMessages(ctx, tx, msgIDs, m.snap.mboxID, NewMailboxIDPair(mbox))
return m.state.actionMoveMessages(ctx, tx, msgIDs, snapMBoxID, NewMailboxIDPair(mbox))
})
if err != nil {
return nil, err
Expand All @@ -234,7 +256,9 @@ func (m *Mailbox) Move(ctx context.Context, seq *proto.SequenceSet, name string)
}

func (m *Mailbox) Store(ctx context.Context, seq *proto.SequenceSet, operation proto.Operation, flags imap.FlagSet) error {
messages, err := m.snap.getMessagesInRange(ctx, seq)
messages, err := snapshotReadErr(m.snap, func(s *snapshot) ([]*snapMsg, error) {
return s.getMessagesInRange(ctx, seq)
})
if err != nil {
return err
}
Expand Down Expand Up @@ -269,13 +293,18 @@ func (m *Mailbox) Expunge(ctx context.Context, seq *proto.SequenceSet) error {
var msg []*snapMsg

if seq != nil {
var err error

if msg, err = m.snap.getMessagesInRange(ctx, seq); err != nil {
snapMsgs, err := snapshotReadErr(m.snap, func(s *snapshot) ([]*snapMsg, error) {
return s.getMessagesInRange(ctx, seq)
})
if err != nil {
return err
}

msg = snapMsgs
} else {
msg = m.snap.getAllMessages()
msg = snapshotRead(m.snap, func(s *snapshot) []*snapMsg {
return s.getAllMessages()
})
}

return m.expunge(ctx, msg)
Expand All @@ -290,8 +319,12 @@ func (m *Mailbox) expunge(ctx context.Context, messages []*snapMsg) error {
return msg.ID
})

mboxID := snapshotRead(m.snap, func(s *snapshot) MailboxIDPair {
return s.mboxID
})

return m.state.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error {
return m.state.actionRemoveMessagesFromMailbox(ctx, tx, msgIDs, m.snap.mboxID)
return m.state.actionRemoveMessagesFromMailbox(ctx, tx, msgIDs, mboxID)
})
}

Expand Down
4 changes: 3 additions & 1 deletion internal/backend/mailbox_fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ import (
)

func (m *Mailbox) Fetch(ctx context.Context, seq *proto.SequenceSet, attributes []*proto.FetchAttribute, ch chan response.Response) error {
msg, err := m.snap.getMessagesInRange(ctx, seq)
msg, err := snapshotReadErr(m.snap, func(s *snapshot) ([]*snapMsg, error) {
return s.getMessagesInRange(ctx, seq)
})
if err != nil {
return err
}
Expand Down
14 changes: 11 additions & 3 deletions internal/backend/mailbox_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ import (
)

func (m *Mailbox) Search(ctx context.Context, keys []*proto.SearchKey, decoder *encoding.Decoder) ([]int, error) {
messages, err := doSearch(ctx, m, m.snap.getAllMessages(), keys, decoder)
snapMessages := snapshotRead(m.snap, func(s *snapshot) []*snapMsg {
return s.getAllMessages()
})

messages, err := doSearch(ctx, m, snapMessages, keys, decoder)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -581,7 +585,9 @@ func (m *Mailbox) matchSearchKeyTo(ctx context.Context, candidates []*snapMsg, k
}

func (m *Mailbox) matchSearchKeyUID(ctx context.Context, candidates []*snapMsg, key *proto.SearchKey) ([]*snapMsg, error) {
left, err := m.snap.getMessagesInUIDRange(key.GetSequenceSet())
left, err := snapshotReadErr(m.snap, func(s *snapshot) ([]*snapMsg, error) {
return s.getMessagesInUIDRange(key.GetSequenceSet())
})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -628,7 +634,9 @@ func (m *Mailbox) matchSearchKeyUnseen(ctx context.Context, candidates []*snapMs
}

func (m *Mailbox) matchSearchKeySeqSet(ctx context.Context, candidates []*snapMsg, key *proto.SearchKey) ([]*snapMsg, error) {
left, err := m.snap.getMessagesInSeqRange(key.GetSequenceSet())
left, err := snapshotReadErr(m.snap, func(s *snapshot) ([]*snapMsg, error) {
return s.getMessagesInSeqRange(key.GetSequenceSet())
})
if err != nil {
return nil, err
}
Expand Down
47 changes: 47 additions & 0 deletions internal/backend/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"strconv"
"sync"

"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/internal/backend/ent"
Expand All @@ -18,6 +19,52 @@ type snapshot struct {
messages *snapMsgList
}

type snapshotWrapper struct {
lock sync.RWMutex
snap *snapshot
}

func newSnapshotWrapper(snapshot *snapshot) *snapshotWrapper {
return &snapshotWrapper{
snap: snapshot,
}
}

func (s *snapshotWrapper) Replace(snap *snapshot) {
s.lock.Lock()
defer s.lock.Unlock()

s.snap = snap
}

func snapshotWrite[T any](wrapper *snapshotWrapper, fn func(*snapshot) T) T {
wrapper.lock.Lock()
defer wrapper.lock.Unlock()

return fn(wrapper.snap)
}

func snapshotRead[T any](wrapper *snapshotWrapper, fn func(*snapshot) T) T {
wrapper.lock.RLock()
defer wrapper.lock.RUnlock()

return fn(wrapper.snap)
}

func snapshotWriteErr[T any](wrapper *snapshotWrapper, fn func(*snapshot) (T, error)) (T, error) {
wrapper.lock.Lock()
defer wrapper.lock.Unlock()

return fn(wrapper.snap)
}

func snapshotReadErr[T any](wrapper *snapshotWrapper, fn func(*snapshot) (T, error)) (T, error) {
wrapper.lock.RLock()
defer wrapper.lock.RUnlock()

return fn(wrapper.snap)
}

func newSnapshot(ctx context.Context, state *State, client *ent.Client, mbox *ent.Mailbox) (*snapshot, error) {
msgUIDs, err := DBGetMailboxMessagesForNewSnapshot(ctx, client, mbox)
if err != nil {
Expand Down
Loading

0 comments on commit 5f11d7e

Please sign in to comment.