Skip to content

Commit

Permalink
feat(GODT-1720, GODT-1722): Implement Database Wrapper
Browse files Browse the repository at this point in the history
Implements database as discussed in the RFC. It makes distinctions
between read and write operations.

Currently, due to the use of SQLite we only support only one concurrent
writer. This is enforced by a RWLock around the `Read` and `Write`
functions.

Updates to rest of the codebase will follow in future patches.
  • Loading branch information
LBeernaertProton committed Aug 11, 2022
1 parent d17a5f2 commit cc387f0
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 49 deletions.
1 change: 1 addition & 0 deletions benchmarks/gluon_bench/store_benchmarks/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package store_benchmarks

import (
"context"

"github.com/ProtonMail/gluon/benchmarks/gluon_bench/benchmark"
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/flags"
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/reporter"
Expand Down
5 changes: 2 additions & 3 deletions internal/backend/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"sync"

"github.com/ProtonMail/gluon/connector"
"github.com/ProtonMail/gluon/internal/backend/ent"
"github.com/ProtonMail/gluon/internal/remote"
"github.com/ProtonMail/gluon/store"
"github.com/google/uuid"
Expand Down Expand Up @@ -46,7 +45,7 @@ func (b *Backend) SetDelimiter(delim string) {
b.delim = delim
}

func (b *Backend) AddUser(ctx context.Context, userID string, conn connector.Connector, store store.Store, client *ent.Client) error {
func (b *Backend) AddUser(ctx context.Context, userID string, conn connector.Connector, store store.Store, db *DB) error {
b.usersLock.Lock()
defer b.usersLock.Unlock()

Expand All @@ -55,7 +54,7 @@ func (b *Backend) AddUser(ctx context.Context, userID string, conn connector.Con
return err
}

user, err := newUser(ctx, userID, client, remote, store, b.delim)
user, err := newUser(ctx, userID, db, remote, store, b.delim)
if err != nil {
return err
}
Expand Down
83 changes: 83 additions & 0 deletions internal/backend/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package backend

import (
"context"
"fmt"
"path/filepath"
"sync"

"entgo.io/ent/dialect"
"github.com/ProtonMail/gluon/internal/backend/ent"
)

type DB struct {
db *ent.Client
lock sync.RWMutex
}

func (d *DB) Init(ctx context.Context) error {
d.lock.Lock()
defer d.lock.Unlock()

return d.db.Schema.Create(ctx)
}

func (d *DB) Read(ctx context.Context, fn func(context.Context, *ent.Client) error) error {
d.lock.RLock()
defer d.lock.Unlock()

return fn(ctx, d.db)
}

func (d *DB) Write(ctx context.Context, fn func(context.Context, *ent.Tx) error) error {
d.lock.Lock()
defer d.lock.Unlock()

tx, err := d.db.Tx(ctx)
if err != nil {
return err
}

defer func() {
if v := recover(); v != nil {
if err := tx.Rollback(); err != nil {
panic(fmt.Errorf("rolling back while recovering (%v): %w", v, err))
}

panic(v)
}
}()

if err := fn(ctx, tx); err != nil {
if rerr := tx.Rollback(); rerr != nil {
return fmt.Errorf("rolling back transaction: %w", rerr)
}

return err
}

if err := tx.Commit(); err != nil {
return fmt.Errorf("committing transaction: %w", err)
}

return nil
}

func (d *DB) Close() error {
return d.db.Close()
}

func getDatabasePath(dataPath, userID string) string {
return fmt.Sprintf("file:%v?cache=shared&_fk=1", filepath.Join(dataPath, fmt.Sprintf("%v.db", userID)))
}

func NewDB(dataPath, userID string) (*DB, error) {
dbPath := getDatabasePath(dataPath, userID)
client, err := ent.Open(dialect.SQLite, dbPath)

if err != nil {
return nil, err
}

return &DB{db: client}, nil
}
47 changes: 9 additions & 38 deletions internal/backend/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ type user struct {
store store.Store
delimiter string

client *ent.Client
txLock sync.Mutex
db *DB

states map[int]*State
statesLock sync.RWMutex
Expand All @@ -29,8 +28,8 @@ type user struct {
updateWG sync.WaitGroup
}

func newUser(ctx context.Context, userID string, client *ent.Client, remote *remote.User, store store.Store, delimiter string) (*user, error) {
if err := client.Schema.Create(context.Background()); err != nil {
func newUser(ctx context.Context, userID string, db *DB, remote *remote.User, store store.Store, delimiter string) (*user, error) {
if err := db.Init(ctx); err != nil {
return nil, err
}

Expand All @@ -39,7 +38,7 @@ func newUser(ctx context.Context, userID string, client *ent.Client, remote *rem
remote: remote,
store: store,
delimiter: delimiter,
client: client,
db: db,
states: make(map[int]*State),
}

Expand Down Expand Up @@ -70,37 +69,9 @@ func newUser(ctx context.Context, userID string, client *ent.Client, remote *rem

// tx is a helper function that runs a sequence of ent client calls in a transaction.
func (user *user) tx(ctx context.Context, fn func(tx *ent.Tx) error) error {
user.txLock.Lock()
defer user.txLock.Unlock()

tx, err := user.client.Tx(ctx)
if err != nil {
return err
}

defer func() {
if v := recover(); v != nil {
if err := tx.Rollback(); err != nil {
panic(fmt.Errorf("rolling back while recovering (%v): %w", v, err))
}

panic(v)
}
}()

if err := fn(tx); err != nil {
if rerr := tx.Rollback(); rerr != nil {
return fmt.Errorf("rolling back transaction: %w", rerr)
}

return err
}

if err := tx.Commit(); err != nil {
return fmt.Errorf("committing transaction: %w", err)
}

return nil
return user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error {
return fn(tx)
})
}

// close closes the backend user.
Expand All @@ -118,8 +89,8 @@ func (user *user) close(ctx context.Context) error {
return fmt.Errorf("failed to close user client storage: %w", err)
}

if err := user.client.Close(); err != nil {
return fmt.Errorf("failed to close user client: %w", err)
if err := user.db.Close(); err != nil {
return fmt.Errorf("failed to close user db: %w", err)
}

return nil
Expand Down
9 changes: 1 addition & 8 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@ import (
"strings"
"sync"

"entgo.io/ent/dialect"
"github.com/ProtonMail/gluon/connector"
"github.com/ProtonMail/gluon/events"
"github.com/ProtonMail/gluon/internal"
"github.com/ProtonMail/gluon/internal/backend"
"github.com/ProtonMail/gluon/internal/backend/ent"
"github.com/ProtonMail/gluon/internal/session"
"github.com/ProtonMail/gluon/profiling"
"github.com/ProtonMail/gluon/store"
Expand Down Expand Up @@ -90,10 +88,6 @@ func New(dir string, withOpt ...Option) (*Server, error) {
return server, nil
}

func getDatabasePath(userPath, userID string) string {
return fmt.Sprintf("file:%v?cache=shared&_fk=1", filepath.Join(userPath, fmt.Sprintf("%v.db", userID)))
}

// AddUser creates a new user and generates new unique ID for this user. If you have an existing userID, please use
// LoadUser instead.
func (s *Server) AddUser(ctx context.Context, conn connector.Connector, encryptionPassphrase []byte) (string, error) {
Expand Down Expand Up @@ -123,8 +117,7 @@ func (s *Server) LoadUser(ctx context.Context, conn connector.Connector, userID
return err
}

source := getDatabasePath(s.dataPath, userID)
client, err := ent.Open(dialect.SQLite, source)
client, err := backend.NewDB(s.dataPath, userID)

if err != nil {
if err := store.Close(); err != nil {
Expand Down

0 comments on commit cc387f0

Please sign in to comment.