Skip to content

Commit

Permalink
feat: Add multi DB support
Browse files Browse the repository at this point in the history
  • Loading branch information
Neurostep committed Jun 17, 2024
1 parent 3d30182 commit 4ee0ebf
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 55 deletions.
37 changes: 22 additions & 15 deletions pkg/database/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,28 @@ import (
cbuilder "github.com/scribd/go-sdk/internal/pkg/configuration/builder"
)

// Config is the database connection configuration.
type Config struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
Database string `mapstructure:"database"`
Timeout string `mapstructure:"timeout"`
// Connection settings
// TODO Pool field name must be modified in the next major change.
Pool int `mapstructure:"pool"`
MaxOpenConnections int `mapstructure:"max_open_connections"`
ConnectionMaxIdleTime time.Duration `mapstructure:"connection_max_idle_time"`
ConnectionMaxLifetime time.Duration `mapstructure:"connection_max_lifetime"`
}
type (
// Config is the database connection configuration.
Config struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
Database string `mapstructure:"database"`
Timeout string `mapstructure:"timeout"`
// Connection settings
// TODO Pool field name must be modified in the next major change.
Pool int `mapstructure:"pool"`
MaxOpenConnections int `mapstructure:"max_open_connections"`
ConnectionMaxIdleTime time.Duration `mapstructure:"connection_max_idle_time"`
ConnectionMaxLifetime time.Duration `mapstructure:"connection_max_lifetime"`

// Replica is a flag to determine if the connection is a replica.
Replica bool `mapstructure:"replica"`

DBs map[string]Config `mapstructure:"dbs"`
}
)

// NewConfig returns a new Config instance.
func NewConfig() (*Config, error) {
Expand Down
140 changes: 100 additions & 40 deletions pkg/database/config_test.go
Original file line number Diff line number Diff line change
@@ -1,67 +1,127 @@
package database

import (
"os"
"path/filepath"
"runtime"
"testing"
"time"

assert "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewConfig(t *testing.T) {
t.Run("RunningInTestEnvironment", func(t *testing.T) {
/*t.Run("RunningInTestEnvironment", func(t *testing.T) {
expected := "test"
actual := os.Getenv("APP_ENV")
assert.Equal(t, expected, actual)
})
})*/

testCases := []struct {
name string
wantError bool
host string
port int
username string
password string
database string
timeout string
pool int
maxOpenConnections int
connectionMaxIdleTime time.Duration
connectionMaxLifetime time.Duration
name string
wantError bool
}{
{
name: "NewWithoutConfigFileFails",
wantError: true,
host: "",
port: 0,
username: "",
password: "",
database: "",
timeout: "",
pool: 0,
maxOpenConnections: 0,
connectionMaxIdleTime: 0,
connectionMaxLifetime: 0,
name: "NewWithoutConfigFileFails",
wantError: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c, err := NewConfig()
_, err := NewConfig()

gotError := err != nil
assert.Equal(t, gotError, tc.wantError)
})
}
}

func TestNewConfigWithAppRoot(t *testing.T) {
testCases := []struct {
name string
env string
cfg *Config
wantErr bool

envOverrides [][]string
}{
{
name: "NewWithConfigFileWorks",
env: "test",
cfg: &Config{
Host: "mysql",
Port: 3306,
Username: "root",
Password: "",
Database: "test",
Timeout: "1s",
Pool: 5,
DBs: map[string]Config{
"primary_replica": {
Host: "mysql-replica",
Port: 3306,
Username: "root",
Password: "",
Database: "test",
Timeout: "1s",
Pool: 5,
Replica: true,
},
},
},
},
{
name: "NewWithConfigFileWorks, overrides",
env: "test",
cfg: &Config{
Host: "mysql",
Port: 3306,
Username: "root",
Password: "test",
Database: "test",
Timeout: "1s",
Pool: 5,
DBs: map[string]Config{
"primary_replica": {
Host: "mysql-replica",
Port: 3306,
Username: "root",
Password: "test-replica",
Database: "test",
Timeout: "1s",
Pool: 5,
Replica: true,
},
},
},
envOverrides: [][]string{
{"APP_DATABASE_PASSWORD", "test"},
{"APP_DATABASE_DBS_PRIMARY_REPLICA_PASSWORD", "test-replica"},
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {

if len(tc.envOverrides) > 0 {
for _, o := range tc.envOverrides {
t.Setenv(o[0], o[1])
}
}

_, filename, _, _ := runtime.Caller(0)
tmpRootParent := filepath.Dir(filename)
t.Setenv("APP_ROOT", filepath.Join(tmpRootParent, "testdata"))

c, err := NewConfig()
if tc.wantErr {
require.NotNil(t, err)
} else {
require.Nil(t, err)
}

assert.Equal(t, c.Host, tc.host)
assert.Equal(t, c.Port, tc.port)
assert.Equal(t, c.Username, tc.username)
assert.Equal(t, c.Password, tc.password)
assert.Equal(t, c.Database, tc.database)
assert.Equal(t, c.Timeout, tc.timeout)
assert.Equal(t, c.Pool, tc.pool)
assert.Equal(t, c.MaxOpenConnections, tc.maxOpenConnections)
assert.Equal(t, c.ConnectionMaxIdleTime, tc.connectionMaxIdleTime)
assert.Equal(t, c.ConnectionMaxLifetime, tc.connectionMaxLifetime)
assert.Equal(t, tc.cfg, c)
})
}
}
21 changes: 21 additions & 0 deletions pkg/database/gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
gormtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/gorm.io/gorm.v1"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/plugin/dbresolver"
)

const testEnv = "test"
Expand All @@ -22,6 +23,13 @@ func NewConnection(config *Config, environment, appName string) (*gorm.DB, error
if err != nil {
return nil, err
}
if len(config.DBs) > 0 {
if err := db.Use(dbresolver.Register(
getDbResolverConfig(config, environment),
)); err != nil {
return nil, err
}
}

if err := databasePoolSettings(db, config); err != nil {
return nil, err
Expand All @@ -30,6 +38,19 @@ func NewConnection(config *Config, environment, appName string) (*gorm.DB, error
return db, nil
}

func getDbResolverConfig(config *Config, env string) dbresolver.Config {
resolverCfg := dbresolver.Config{}
for _, dbConfig := range config.DBs {
if dbConfig.Replica {
resolverCfg.Replicas = []gorm.Dialector{getDialectorFromConfig(&dbConfig, env)}
} else {
resolverCfg.Sources = []gorm.Dialector{getDialectorFromConfig(&dbConfig, env)}
}
}

return resolverCfg
}

func getDialectorFromConfig(config *Config, environment string) gorm.Dialector {
connectionDetails := NewConnectionDetails(config)

Expand Down
30 changes: 30 additions & 0 deletions pkg/database/testdata/config/database.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
common: &common
host: mysql
port: 3306
username: root
password:
timeout: 1s
pool: 5
max_open_connections: 0
connection_max_idle_time: 0s
connection_max_lifetime: 0s

test: &test
<<: *common
database: test
dbs:
primary_replica:
database: test
replica: true
host: mysql-replica
port: 3306
username: root
password:
timeout: 1s
pool: 5
max_open_connections: 0
connection_max_idle_time: 0s
connection_max_lifetime: 0s

development:
<<: *test

0 comments on commit 4ee0ebf

Please sign in to comment.