Skip to content

Commit

Permalink
DBPW - Update MSSQL to adhere to v5 Database interface (#10128)
Browse files Browse the repository at this point in the history
  • Loading branch information
pcman312 authored Oct 13, 2020
1 parent dfb0415 commit 21d13e4
Show file tree
Hide file tree
Showing 2 changed files with 285 additions and 353 deletions.
209 changes: 76 additions & 133 deletions plugins/database/mssql/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,32 @@ import (
"errors"
"fmt"
"strings"
"time"

_ "github.com/denisenkom/go-mssqldb"
"github.com/hashicorp/errwrap"
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/credsutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/hashicorp/vault/sdk/helper/dbtxn"
"github.com/hashicorp/vault/sdk/helper/strutil"
)

const msSQLTypeName = "mssql"

var _ dbplugin.Database = &MSSQL{}
var _ newdbplugin.Database = &MSSQL{}

// MSSQL is an implementation of Database interface
type MSSQL struct {
*connutil.SQLConnectionProducer
credsutil.CredentialsProducer
}

func New() (interface{}, error) {
db := new()
// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
dbType := newdbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)

return dbType, nil
}
Expand All @@ -42,16 +40,8 @@ func new() *MSSQL {
connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = msSQLTypeName

credsProducer := &credsutil.SQLCredentialsProducer{
DisplayNameLen: 20,
RoleNameLen: 20,
UsernameLen: 128,
Separator: "-",
}

return &MSSQL{
SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
}

Expand All @@ -62,7 +52,7 @@ func Run(apiTLSConfig *api.TLSConfig) error {
return err
}

dbplugin.Serve(dbType.(dbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig))
newdbplugin.Serve(dbType.(newdbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig))

return nil
}
Expand All @@ -72,6 +62,12 @@ func (m *MSSQL) Type() (string, error) {
return msSQLTypeName, nil
}

func (m *MSSQL) secretValues() map[string]string {
return map[string]string{
m.Password: "[password]",
}
}

func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := m.Connection(ctx)
if err != nil {
Expand All @@ -81,49 +77,51 @@ func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) {
return db.(*sql.DB), nil
}

// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by
// the CreationStatement provided.
func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock
func (m *MSSQL) Initialize(ctx context.Context, req newdbplugin.InitializeRequest) (newdbplugin.InitializeResponse, error) {
newConf, err := m.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection)
if err != nil {
return newdbplugin.InitializeResponse{}, err
}
resp := newdbplugin.InitializeResponse{
Config: newConf,
}
return resp, nil
}

// NewUser generates the username/password on the underlying MSSQL secret backend as instructed by
// the statements provided.
func (m *MSSQL) NewUser(ctx context.Context, req newdbplugin.NewUserRequest) (newdbplugin.NewUserResponse, error) {
m.Lock()
defer m.Unlock()

statements = dbutil.StatementCompatibilityHelper(statements)

// Get the connection
db, err := m.getConnection(ctx)
if err != nil {
return "", "", err
return newdbplugin.NewUserResponse{}, fmt.Errorf("unable to get connection: %w", err)
}

if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement
if len(req.Statements.Commands) == 0 {
return newdbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement
}

username, err = m.GenerateUsername(usernameConfig)
username, err := credsutil.GenerateUsername(
credsutil.DisplayName(req.UsernameConfig.DisplayName, 20),
credsutil.RoleName(req.UsernameConfig.RoleName, 20),
credsutil.MaxLength(128),
credsutil.Separator("-"),
)
if err != nil {
return "", "", err
return newdbplugin.NewUserResponse{}, err
}

password, err = m.GeneratePassword()
if err != nil {
return "", "", err
}
expirationStr := req.Expiration.Format("2006-01-02 15:04:05-0700")

expirationStr, err := m.GenerateExpiration(expiration)
if err != nil {
return "", "", err
}

// Start a transaction
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return "", "", err
return newdbplugin.NewUserResponse{}, err
}
defer tx.Rollback()

// Execute each query
for _, stmt := range statements.Creation {
for _, stmt := range req.Statements.Commands {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
Expand All @@ -132,66 +130,61 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements,

m := map[string]string{
"name": username,
"password": password,
"password": req.Password,
"expiration": expirationStr,
}

if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return "", "", err
return newdbplugin.NewUserResponse{}, err
}
}
}

// Commit the transaction
if err := tx.Commit(); err != nil {
return "", "", err
return newdbplugin.NewUserResponse{}, err
}

return username, password, nil
}
resp := newdbplugin.NewUserResponse{
Username: username,
}

// RenewUser is not supported on MSSQL, so this is a no-op.
func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
// NOOP
return nil
return resp, nil
}

// RevokeUser attempts to drop the specified user. It will first attempt to disable login,
// DeleteUser attempts to drop the specified user. It will first attempt to disable login,
// then kill pending connections from that user, and finally drop the user and login from the
// database instance.
func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
statements = dbutil.StatementCompatibilityHelper(statements)

if len(statements.Revocation) == 0 {
return m.revokeUserDefault(ctx, username)
func (m *MSSQL) DeleteUser(ctx context.Context, req newdbplugin.DeleteUserRequest) (newdbplugin.DeleteUserResponse, error) {
if len(req.Statements.Commands) == 0 {
err := m.revokeUserDefault(ctx, req.Username)
return newdbplugin.DeleteUserResponse{}, err
}

// Get connection
db, err := m.getConnection(ctx)
if err != nil {
return err
return newdbplugin.DeleteUserResponse{}, fmt.Errorf("unable to get connection: %w", err)
}

var result error
merr := &multierror.Error{}

// Execute each query
for _, stmt := range statements.Revocation {
for _, stmt := range req.Statements.Commands {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}

m := map[string]string{
"name": username,
"name": req.Username,
}
if err := dbtxn.ExecuteDBQuery(ctx, db, m, query); err != nil {
result = multierror.Append(result, err)
merr = multierror.Append(merr, err)
}
}
}

return result
return newdbplugin.DeleteUserResponse{}, merr.ErrorOrNil()
}

func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
Expand Down Expand Up @@ -297,101 +290,51 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
return nil
}

func (m *MSSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
m.Lock()
defer m.Unlock()

if len(m.Username) == 0 || len(m.Password) == 0 {
return nil, errors.New("username and password are required to rotate")
}

rotateStatents := statements
if len(rotateStatents) == 0 {
rotateStatents = []string{alterLoginSQL}
func (m *MSSQL) UpdateUser(ctx context.Context, req newdbplugin.UpdateUserRequest) (newdbplugin.UpdateUserResponse, error) {
if req.Password == nil && req.Expiration == nil {
return newdbplugin.UpdateUserResponse{}, fmt.Errorf("no changes requested")
}

db, err := m.getConnection(ctx)
if err != nil {
return nil, err
if req.Password != nil {
err := m.updateUserPass(ctx, req.Username, req.Password)
return newdbplugin.UpdateUserResponse{}, err
}

tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()

password, err := m.GeneratePassword()
if err != nil {
return nil, err
}

for _, stmt := range rotateStatents {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}

m := map[string]string{
"username": m.Username,
"password": password,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return nil, err
}
}
}

if err := tx.Commit(); err != nil {
return nil, err
}

if err := db.Close(); err != nil {
return nil, err
}

m.RawConfig["password"] = password
return m.RawConfig, nil
// Expiration is a no-op
return newdbplugin.UpdateUserResponse{}, nil
}

func (m *MSSQL) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
if len(statements.Rotation) == 0 {
statements.Rotation = []string{alterLoginSQL}
func (m *MSSQL) updateUserPass(ctx context.Context, username string, changePass *newdbplugin.ChangePassword) error {
stmts := changePass.Statements.Commands
if len(stmts) == 0 {
stmts = []string{alterLoginSQL}
}

username = staticUser.Username
password = staticUser.Password
password := changePass.NewPassword

if username == "" || password == "" {
return "", "", errors.New("must provide both username and password")
return errors.New("must provide both username and password")
}

m.Lock()
defer m.Unlock()

db, err := m.getConnection(ctx)
if err != nil {
return "", "", err
return err
}

var exists bool

err = db.QueryRowContext(ctx, "SELECT 1 FROM master.sys.server_principals where name = N'$1'", username).Scan(&exists)

if err != nil && err != sql.ErrNoRows {
return "", "", err
return err
}

stmts := statements.Rotation

// Start a transaction
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return "", "", err
return err
}

defer func() {
_ = tx.Rollback()
}()
Expand All @@ -409,16 +352,16 @@ func (m *MSSQL) SetCredentials(ctx context.Context, statements dbplugin.Statemen
"password": password,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return "", "", err
return fmt.Errorf("failed to execute query: %w", err)
}
}
}

if err := tx.Commit(); err != nil {
return "", "", err
return fmt.Errorf("failed to commit transaction: %w", err)
}

return username, password, nil
return nil
}

const dropUserSQL = `
Expand Down
Loading

0 comments on commit 21d13e4

Please sign in to comment.