From 29b888dc43e6ced5034633c7b2bf350210ff5262 Mon Sep 17 00:00:00 2001 From: akeemphilbert Date: Fri, 13 Jan 2023 04:09:12 -0400 Subject: [PATCH] feature: #227 Added tests for the IAM functionality * Updated the signature for SQLConnectionFromConfig to return the connection string to make it more testable * Added two variable to the DB config, one for specifying that IAM is to be used and the other for the AWS Region * Added InvalidAWSDriver error to be used when an incompatible driver is specified --- controllers/rest/api.go | 24 +++++----- controllers/rest/api_test.go | 63 ++++++++++++++++++++++++- controllers/rest/global_initializers.go | 4 +- model/module_test.go | 6 +-- model/service.go | 20 ++++---- 5 files changed, 91 insertions(+), 26 deletions(-) diff --git a/controllers/rest/api.go b/controllers/rest/api.go index 594f7452..3b6a17b6 100644 --- a/controllers/rest/api.go +++ b/controllers/rest/api.go @@ -31,6 +31,8 @@ import ( "github.com/wepala/weos/projections" ) +var InvalidAWSDriver = errors.New("invalid aws driver specified, must be postgres or mysql") + //RESTAPI is used to manage the API type RESTAPI struct { Application model.Service @@ -562,7 +564,7 @@ func (p *RESTAPI) Initialize(ctxt context.Context) error { } //SQLConnectionFromConfig get db connection based on a Config -func (p *RESTAPI) SQLConnectionFromConfig(config *model.DBConfig) (*sql.DB, *gorm.DB, error) { +func (p *RESTAPI) SQLConnectionFromConfig(config *model.DBConfig) (*sql.DB, *gorm.DB, string, error) { var connStr string var err error @@ -574,7 +576,7 @@ func (p *RESTAPI) SQLConnectionFromConfig(config *model.DBConfig) (*sql.DB, *gor if _, err = os.Stat(config.Database); os.IsNotExist(err) { _, err = os.Create(strings.Replace(config.Database, ":memory:", "", -1)) if err != nil { - return nil, nil, model.NewError(fmt.Sprintf("error creating sqlite database '%s'", config.Database), err) + return nil, nil, "", model.NewError(fmt.Sprintf("error creating sqlite database '%s'", config.Database), err) } } } @@ -606,12 +608,12 @@ func (p *RESTAPI) SQLConnectionFromConfig(config *model.DBConfig) (*sql.DB, *gor connStr = fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", config.Host, strconv.Itoa(config.Port), config.User, config.Password, config.Database) default: - return nil, nil, errors.New(fmt.Sprintf("db driver '%s' is not supported ", config.Driver)) + return nil, nil, connStr, errors.New(fmt.Sprintf("db driver '%s' is not supported ", config.Driver)) } db, err := sql.Open(config.Driver, connStr) if err != nil { - return nil, nil, errors.New(fmt.Sprintf("error setting up connection to database '%s' with connection '%s'", err, connStr)) + return nil, nil, connStr, errors.New(fmt.Sprintf("error setting up connection to database '%s' with connection '%s'", err, connStr)) } db.SetMaxOpenConns(config.MaxOpen) @@ -625,7 +627,7 @@ func (p *RESTAPI) SQLConnectionFromConfig(config *model.DBConfig) (*sql.DB, *gor Conn: db, }), nil) if err != nil { - return nil, nil, err + return nil, nil, connStr, err } case "sqlite3": gormDB, err = gorm.Open(&dialects.SQLite{ @@ -634,14 +636,14 @@ func (p *RESTAPI) SQLConnectionFromConfig(config *model.DBConfig) (*sql.DB, *gor }, }, nil) if err != nil { - return nil, nil, err + return nil, nil, connStr, err } case "mysql": gormDB, err = gorm.Open(dialects.NewMySQL(mysql.Config{ Conn: db, }), &gorm.Config{DisableForeignKeyConstraintWhenMigrating: true}) if err != nil { - return nil, nil, err + return nil, nil, connStr, err } case "ramsql": //this is for testing gormDB = &gorm.DB{} @@ -650,17 +652,17 @@ func (p *RESTAPI) SQLConnectionFromConfig(config *model.DBConfig) (*sql.DB, *gor Conn: db, }), nil) if err != nil { - return nil, nil, err + return nil, nil, connStr, err } case "clickhouse": gormDB, err = gorm.Open(clickhouse.New(clickhouse.Config{ Conn: db, }), nil) if err != nil { - return nil, nil, err + return nil, nil, connStr, err } default: - return nil, nil, errors.New(fmt.Sprintf("we don't support database driver '%s'", config.Driver)) + return nil, nil, connStr, errors.New(fmt.Sprintf("we don't support database driver '%s'", config.Driver)) } - return db, gormDB, err + return db, gormDB, connStr, err } diff --git a/controllers/rest/api_test.go b/controllers/rest/api_test.go index 3fb6ac64..244ebf5d 100644 --- a/controllers/rest/api_test.go +++ b/controllers/rest/api_test.go @@ -3,6 +3,7 @@ package rest_test import ( "bytes" "encoding/json" + "errors" "github.com/labstack/echo/v4" "io/ioutil" "net/http" @@ -294,7 +295,7 @@ func TestRESTAPI_DefaultProjectionRegisteredBefore(t *testing.T) { if err != nil { t.Fatalf("un expected error loading spec '%s'", err) } - _, gormDB, err := tapi.SQLConnectionFromConfig(tapi.Config.Database) + _, gormDB, _, err := tapi.SQLConnectionFromConfig(tapi.Config.Database) gormProjection, err := projections.NewProjection(context.TODO(), gormDB, tapi.EchoInstance().Logger) if err != nil { t.Fatalf("error setting up gorm projection") @@ -325,6 +326,66 @@ func TestRESTAPI_DefaultProjectionRegisteredBefore(t *testing.T) { } } +func TestRESTAPI_SQLConnectionFromConfig(t *testing.T) { + t.Run("test with valid config", func(t *testing.T) { + apiYaml := `openapi: 3.0.3 +info: + title: Blog + description: Blog example + version: 1.0.0 +servers: + - url: https://prod1.weos.sh/blog/dev + description: WeOS Dev + - url: https://prod1.weos.sh/blog/v1 + - url: http://localhost:8681 +x-weos-config: + databases: + - name: Default + driver: postgres + password: test-password + aws-iam: true + aws-region: us-east-1 +` + tapi, err := api.New(apiYaml) + if err != nil { + t.Fatalf("un expected error loading spec '%s'", err) + } + var connectionString string + _, _, connectionString, err = tapi.SQLConnectionFromConfig(tapi.Config.Databases[0]) + if strings.Contains(connectionString, "test-password") { + t.Errorf("expected the connection string to not contain password '%s', '%s'", "test-password", connectionString) + } + }) + t.Run("unsupported driver", func(t *testing.T) { + apiYaml := `openapi: 3.0.3 +info: + title: Blog + description: Blog example + version: 1.0.0 +servers: + - url: https://prod1.weos.sh/blog/dev + description: WeOS Dev + - url: https://prod1.weos.sh/blog/v1 + - url: http://localhost:8681 +x-weos-config: + databases: + - name: Default + driver: sqlite3 + password: test-password + aws-iam: true + aws-region: us-east-1 +` + tapi, err := api.New(apiYaml) + if err != nil { + t.Fatalf("un expected error loading spec '%s'", err) + } + _, _, _, err = tapi.SQLConnectionFromConfig(tapi.Config.Databases[0]) + if !errors.Is(err, api.InvalidAWSDriver) { + t.Errorf("expected the error to be '%s', got '%s'", api.InvalidAWSDriver, err) + } + }) +} + func TestRESTAPI_Initialize_DiscoveryAddedToGet(t *testing.T) { os.Remove("test.db") tapi, err := api.New("./fixtures/blog.yaml") diff --git a/controllers/rest/global_initializers.go b/controllers/rest/global_initializers.go index 74d6acd7..7552b6aa 100644 --- a/controllers/rest/global_initializers.go +++ b/controllers/rest/global_initializers.go @@ -56,7 +56,7 @@ func SQLDatabase(ctxt context.Context, tapi Container, swagger *openapi3.Swagger if config.ServiceConfig != nil && config.ServiceConfig.Database != nil { var connection *sql.DB var gormDB *gorm.DB - if connection, gormDB, err = api.SQLConnectionFromConfig(config.Database); err == nil { + if connection, gormDB, _, err = api.SQLConnectionFromConfig(config.Database); err == nil { api.RegisterDBConnection("Default", connection) api.RegisterGORMDB("Default", gormDB) } @@ -66,7 +66,7 @@ func SQLDatabase(ctxt context.Context, tapi Container, swagger *openapi3.Swagger for _, dbconfig := range config.ServiceConfig.Databases { var connection *sql.DB var gormDB *gorm.DB - if connection, gormDB, err = api.SQLConnectionFromConfig(dbconfig); err == nil { + if connection, gormDB, _, err = api.SQLConnectionFromConfig(dbconfig); err == nil { api.RegisterDBConnection(dbconfig.Name, connection) api.RegisterGORMDB(dbconfig.Name, gormDB) } diff --git a/model/module_test.go b/model/module_test.go index 5533a71d..c2ebf060 100644 --- a/model/module_test.go +++ b/model/module_test.go @@ -115,7 +115,7 @@ func TestNewApplicationFromConfig_SQLite(t *testing.T) { } api := &rest.RESTAPI{} - db, _, err := api.SQLConnectionFromConfig(sqliteConfig.Database) + db, _, _, err := api.SQLConnectionFromConfig(sqliteConfig.Database) if err != nil { t.Fatalf("unexpected error getting connection '%s'", err) } @@ -148,7 +148,7 @@ func TestNewApplicationFromConfig_SQLite(t *testing.T) { } api := &rest.RESTAPI{} - db, _, err := api.SQLConnectionFromConfig(sqliteConfig.Database) + db, _, _, err := api.SQLConnectionFromConfig(sqliteConfig.Database) if err != nil { t.Fatalf("unexpected error getting connection '%s'", err) } @@ -186,7 +186,7 @@ func TestNewApplicationFromConfig_SQLite(t *testing.T) { } api := &rest.RESTAPI{} - db, _, err := api.SQLConnectionFromConfig(sqliteConfig.Database) + db, _, _, err := api.SQLConnectionFromConfig(sqliteConfig.Database) if err != nil { t.Fatalf("unexpected error getting connection '%s'", err) } diff --git a/model/service.go b/model/service.go index 1f1227dc..01660631 100644 --- a/model/service.go +++ b/model/service.go @@ -34,15 +34,17 @@ type ServiceConfig struct { } type DBConfig struct { - Name string `json:"name"` - Host string `json:"host"` - User string `json:"username"` - Password string `json:"password"` - Port int `json:"port"` - Database string `json:"database"` - Driver string `json:"driver"` - MaxOpen int `json:"max-open"` - MaxIdle int `json:"max-idle"` + Name string `json:"name"` + Host string `json:"host"` + User string `json:"username"` + Password string `json:"password"` + Port int `json:"port"` + Database string `json:"database"` + Driver string `json:"driver"` + MaxOpen int `json:"max-open"` + MaxIdle int `json:"max-idle"` + AwsIam bool `json:"aws-iam"` + AwsRegion string `json:"aws-region"` } type LogConfig struct {