Skip to content

Commit

Permalink
feature: #227 Added tests for the IAM functionality
Browse files Browse the repository at this point in the history
* Updated the signature for SQLConnectionFromConfig to return the connection string to make it more testable
* Added two variable to the DB config, one for specifying that IAM is to be used and the other for the AWS Region
* Added InvalidAWSDriver error to be used when an incompatible driver is specified
  • Loading branch information
akeemphilbert committed Jan 13, 2023
1 parent fde552b commit 29b888d
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 26 deletions.
24 changes: 13 additions & 11 deletions controllers/rest/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import (
"github.com/wepala/weos/projections"
)

var InvalidAWSDriver = errors.New("invalid aws driver specified, must be postgres or mysql")

//RESTAPI is used to manage the API
type RESTAPI struct {
Application model.Service
Expand Down Expand Up @@ -562,7 +564,7 @@ func (p *RESTAPI) Initialize(ctxt context.Context) error {
}

//SQLConnectionFromConfig get db connection based on a Config
func (p *RESTAPI) SQLConnectionFromConfig(config *model.DBConfig) (*sql.DB, *gorm.DB, error) {
func (p *RESTAPI) SQLConnectionFromConfig(config *model.DBConfig) (*sql.DB, *gorm.DB, string, error) {
var connStr string
var err error

Expand All @@ -574,7 +576,7 @@ func (p *RESTAPI) SQLConnectionFromConfig(config *model.DBConfig) (*sql.DB, *gor
if _, err = os.Stat(config.Database); os.IsNotExist(err) {
_, err = os.Create(strings.Replace(config.Database, ":memory:", "", -1))
if err != nil {
return nil, nil, model.NewError(fmt.Sprintf("error creating sqlite database '%s'", config.Database), err)
return nil, nil, "", model.NewError(fmt.Sprintf("error creating sqlite database '%s'", config.Database), err)
}
}
}
Expand Down Expand Up @@ -606,12 +608,12 @@ func (p *RESTAPI) SQLConnectionFromConfig(config *model.DBConfig) (*sql.DB, *gor
connStr = fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
config.Host, strconv.Itoa(config.Port), config.User, config.Password, config.Database)
default:
return nil, nil, errors.New(fmt.Sprintf("db driver '%s' is not supported ", config.Driver))
return nil, nil, connStr, errors.New(fmt.Sprintf("db driver '%s' is not supported ", config.Driver))
}

db, err := sql.Open(config.Driver, connStr)
if err != nil {
return nil, nil, errors.New(fmt.Sprintf("error setting up connection to database '%s' with connection '%s'", err, connStr))
return nil, nil, connStr, errors.New(fmt.Sprintf("error setting up connection to database '%s' with connection '%s'", err, connStr))
}

db.SetMaxOpenConns(config.MaxOpen)
Expand All @@ -625,7 +627,7 @@ func (p *RESTAPI) SQLConnectionFromConfig(config *model.DBConfig) (*sql.DB, *gor
Conn: db,
}), nil)
if err != nil {
return nil, nil, err
return nil, nil, connStr, err
}
case "sqlite3":
gormDB, err = gorm.Open(&dialects.SQLite{
Expand All @@ -634,14 +636,14 @@ func (p *RESTAPI) SQLConnectionFromConfig(config *model.DBConfig) (*sql.DB, *gor
},
}, nil)
if err != nil {
return nil, nil, err
return nil, nil, connStr, err
}
case "mysql":
gormDB, err = gorm.Open(dialects.NewMySQL(mysql.Config{
Conn: db,
}), &gorm.Config{DisableForeignKeyConstraintWhenMigrating: true})
if err != nil {
return nil, nil, err
return nil, nil, connStr, err
}
case "ramsql": //this is for testing
gormDB = &gorm.DB{}
Expand All @@ -650,17 +652,17 @@ func (p *RESTAPI) SQLConnectionFromConfig(config *model.DBConfig) (*sql.DB, *gor
Conn: db,
}), nil)
if err != nil {
return nil, nil, err
return nil, nil, connStr, err
}
case "clickhouse":
gormDB, err = gorm.Open(clickhouse.New(clickhouse.Config{
Conn: db,
}), nil)
if err != nil {
return nil, nil, err
return nil, nil, connStr, err
}
default:
return nil, nil, errors.New(fmt.Sprintf("we don't support database driver '%s'", config.Driver))
return nil, nil, connStr, errors.New(fmt.Sprintf("we don't support database driver '%s'", config.Driver))
}
return db, gormDB, err
return db, gormDB, connStr, err
}
63 changes: 62 additions & 1 deletion controllers/rest/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package rest_test
import (
"bytes"
"encoding/json"
"errors"
"github.com/labstack/echo/v4"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -294,7 +295,7 @@ func TestRESTAPI_DefaultProjectionRegisteredBefore(t *testing.T) {
if err != nil {
t.Fatalf("un expected error loading spec '%s'", err)
}
_, gormDB, err := tapi.SQLConnectionFromConfig(tapi.Config.Database)
_, gormDB, _, err := tapi.SQLConnectionFromConfig(tapi.Config.Database)
gormProjection, err := projections.NewProjection(context.TODO(), gormDB, tapi.EchoInstance().Logger)
if err != nil {
t.Fatalf("error setting up gorm projection")
Expand Down Expand Up @@ -325,6 +326,66 @@ func TestRESTAPI_DefaultProjectionRegisteredBefore(t *testing.T) {
}
}

func TestRESTAPI_SQLConnectionFromConfig(t *testing.T) {
t.Run("test with valid config", func(t *testing.T) {
apiYaml := `openapi: 3.0.3
info:
title: Blog
description: Blog example
version: 1.0.0
servers:
- url: https://prod1.weos.sh/blog/dev
description: WeOS Dev
- url: https://prod1.weos.sh/blog/v1
- url: http://localhost:8681
x-weos-config:
databases:
- name: Default
driver: postgres
password: test-password
aws-iam: true
aws-region: us-east-1
`
tapi, err := api.New(apiYaml)
if err != nil {
t.Fatalf("un expected error loading spec '%s'", err)
}
var connectionString string
_, _, connectionString, err = tapi.SQLConnectionFromConfig(tapi.Config.Databases[0])
if strings.Contains(connectionString, "test-password") {
t.Errorf("expected the connection string to not contain password '%s', '%s'", "test-password", connectionString)
}
})
t.Run("unsupported driver", func(t *testing.T) {
apiYaml := `openapi: 3.0.3
info:
title: Blog
description: Blog example
version: 1.0.0
servers:
- url: https://prod1.weos.sh/blog/dev
description: WeOS Dev
- url: https://prod1.weos.sh/blog/v1
- url: http://localhost:8681
x-weos-config:
databases:
- name: Default
driver: sqlite3
password: test-password
aws-iam: true
aws-region: us-east-1
`
tapi, err := api.New(apiYaml)
if err != nil {
t.Fatalf("un expected error loading spec '%s'", err)
}
_, _, _, err = tapi.SQLConnectionFromConfig(tapi.Config.Databases[0])
if !errors.Is(err, api.InvalidAWSDriver) {
t.Errorf("expected the error to be '%s', got '%s'", api.InvalidAWSDriver, err)
}
})
}

func TestRESTAPI_Initialize_DiscoveryAddedToGet(t *testing.T) {
os.Remove("test.db")
tapi, err := api.New("./fixtures/blog.yaml")
Expand Down
4 changes: 2 additions & 2 deletions controllers/rest/global_initializers.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func SQLDatabase(ctxt context.Context, tapi Container, swagger *openapi3.Swagger
if config.ServiceConfig != nil && config.ServiceConfig.Database != nil {
var connection *sql.DB
var gormDB *gorm.DB
if connection, gormDB, err = api.SQLConnectionFromConfig(config.Database); err == nil {
if connection, gormDB, _, err = api.SQLConnectionFromConfig(config.Database); err == nil {
api.RegisterDBConnection("Default", connection)
api.RegisterGORMDB("Default", gormDB)
}
Expand All @@ -66,7 +66,7 @@ func SQLDatabase(ctxt context.Context, tapi Container, swagger *openapi3.Swagger
for _, dbconfig := range config.ServiceConfig.Databases {
var connection *sql.DB
var gormDB *gorm.DB
if connection, gormDB, err = api.SQLConnectionFromConfig(dbconfig); err == nil {
if connection, gormDB, _, err = api.SQLConnectionFromConfig(dbconfig); err == nil {
api.RegisterDBConnection(dbconfig.Name, connection)
api.RegisterGORMDB(dbconfig.Name, gormDB)
}
Expand Down
6 changes: 3 additions & 3 deletions model/module_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func TestNewApplicationFromConfig_SQLite(t *testing.T) {
}

api := &rest.RESTAPI{}
db, _, err := api.SQLConnectionFromConfig(sqliteConfig.Database)
db, _, _, err := api.SQLConnectionFromConfig(sqliteConfig.Database)
if err != nil {
t.Fatalf("unexpected error getting connection '%s'", err)
}
Expand Down Expand Up @@ -148,7 +148,7 @@ func TestNewApplicationFromConfig_SQLite(t *testing.T) {
}

api := &rest.RESTAPI{}
db, _, err := api.SQLConnectionFromConfig(sqliteConfig.Database)
db, _, _, err := api.SQLConnectionFromConfig(sqliteConfig.Database)
if err != nil {
t.Fatalf("unexpected error getting connection '%s'", err)
}
Expand Down Expand Up @@ -186,7 +186,7 @@ func TestNewApplicationFromConfig_SQLite(t *testing.T) {
}

api := &rest.RESTAPI{}
db, _, err := api.SQLConnectionFromConfig(sqliteConfig.Database)
db, _, _, err := api.SQLConnectionFromConfig(sqliteConfig.Database)
if err != nil {
t.Fatalf("unexpected error getting connection '%s'", err)
}
Expand Down
20 changes: 11 additions & 9 deletions model/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,17 @@ type ServiceConfig struct {
}

type DBConfig struct {
Name string `json:"name"`
Host string `json:"host"`
User string `json:"username"`
Password string `json:"password"`
Port int `json:"port"`
Database string `json:"database"`
Driver string `json:"driver"`
MaxOpen int `json:"max-open"`
MaxIdle int `json:"max-idle"`
Name string `json:"name"`
Host string `json:"host"`
User string `json:"username"`
Password string `json:"password"`
Port int `json:"port"`
Database string `json:"database"`
Driver string `json:"driver"`
MaxOpen int `json:"max-open"`
MaxIdle int `json:"max-idle"`
AwsIam bool `json:"aws-iam"`
AwsRegion string `json:"aws-region"`
}

type LogConfig struct {
Expand Down

0 comments on commit 29b888d

Please sign in to comment.