Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DBPW - Update MSSQL to adhere to v5 Database interface #10128

Merged
merged 3 commits into from
Oct 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 76 additions & 133 deletions plugins/database/mssql/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,32 @@ import (
"errors"
"fmt"
"strings"
"time"

_ "github.com/denisenkom/go-mssqldb"
"github.com/hashicorp/errwrap"
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/credsutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/hashicorp/vault/sdk/helper/dbtxn"
"github.com/hashicorp/vault/sdk/helper/strutil"
)

const msSQLTypeName = "mssql"

var _ dbplugin.Database = &MSSQL{}
var _ newdbplugin.Database = &MSSQL{}

// MSSQL is an implementation of Database interface
type MSSQL struct {
*connutil.SQLConnectionProducer
credsutil.CredentialsProducer
}

func New() (interface{}, error) {
db := new()
// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
dbType := newdbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)

return dbType, nil
}
Expand All @@ -42,16 +40,8 @@ func new() *MSSQL {
connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = msSQLTypeName

credsProducer := &credsutil.SQLCredentialsProducer{
DisplayNameLen: 20,
RoleNameLen: 20,
UsernameLen: 128,
Separator: "-",
}

return &MSSQL{
SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
}

Expand All @@ -62,7 +52,7 @@ func Run(apiTLSConfig *api.TLSConfig) error {
return err
}

dbplugin.Serve(dbType.(dbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig))
newdbplugin.Serve(dbType.(newdbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig))

return nil
}
Expand All @@ -72,6 +62,12 @@ func (m *MSSQL) Type() (string, error) {
return msSQLTypeName, nil
}

func (m *MSSQL) secretValues() map[string]string {
return map[string]string{
m.Password: "[password]",
}
}

func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := m.Connection(ctx)
if err != nil {
Expand All @@ -81,49 +77,51 @@ func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) {
return db.(*sql.DB), nil
}

// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by
// the CreationStatement provided.
func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock
func (m *MSSQL) Initialize(ctx context.Context, req newdbplugin.InitializeRequest) (newdbplugin.InitializeResponse, error) {
newConf, err := m.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection)
if err != nil {
return newdbplugin.InitializeResponse{}, err
}
resp := newdbplugin.InitializeResponse{
Config: newConf,
}
return resp, nil
}

// NewUser generates the username/password on the underlying MSSQL secret backend as instructed by
// the statements provided.
func (m *MSSQL) NewUser(ctx context.Context, req newdbplugin.NewUserRequest) (newdbplugin.NewUserResponse, error) {
m.Lock()
defer m.Unlock()

statements = dbutil.StatementCompatibilityHelper(statements)

// Get the connection
db, err := m.getConnection(ctx)
if err != nil {
return "", "", err
return newdbplugin.NewUserResponse{}, fmt.Errorf("unable to get connection: %w", err)
}

if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement
if len(req.Statements.Commands) == 0 {
return newdbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement
}

username, err = m.GenerateUsername(usernameConfig)
username, err := credsutil.GenerateUsername(
credsutil.DisplayName(req.UsernameConfig.DisplayName, 20),
credsutil.RoleName(req.UsernameConfig.RoleName, 20),
credsutil.MaxLength(128),
credsutil.Separator("-"),
)
if err != nil {
return "", "", err
return newdbplugin.NewUserResponse{}, err
}

password, err = m.GeneratePassword()
if err != nil {
return "", "", err
}
expirationStr := req.Expiration.Format("2006-01-02 15:04:05-0700")

expirationStr, err := m.GenerateExpiration(expiration)
if err != nil {
return "", "", err
}

// Start a transaction
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return "", "", err
return newdbplugin.NewUserResponse{}, err
}
defer tx.Rollback()

// Execute each query
for _, stmt := range statements.Creation {
for _, stmt := range req.Statements.Commands {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
Expand All @@ -132,66 +130,61 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements,

m := map[string]string{
"name": username,
"password": password,
"password": req.Password,
"expiration": expirationStr,
}

if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return "", "", err
return newdbplugin.NewUserResponse{}, err
}
}
}

// Commit the transaction
if err := tx.Commit(); err != nil {
return "", "", err
return newdbplugin.NewUserResponse{}, err
}

return username, password, nil
}
resp := newdbplugin.NewUserResponse{
Username: username,
}

// RenewUser is not supported on MSSQL, so this is a no-op.
func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
// NOOP
return nil
return resp, nil
}

// RevokeUser attempts to drop the specified user. It will first attempt to disable login,
// DeleteUser attempts to drop the specified user. It will first attempt to disable login,
// then kill pending connections from that user, and finally drop the user and login from the
// database instance.
func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
statements = dbutil.StatementCompatibilityHelper(statements)

if len(statements.Revocation) == 0 {
return m.revokeUserDefault(ctx, username)
func (m *MSSQL) DeleteUser(ctx context.Context, req newdbplugin.DeleteUserRequest) (newdbplugin.DeleteUserResponse, error) {
if len(req.Statements.Commands) == 0 {
err := m.revokeUserDefault(ctx, req.Username)
return newdbplugin.DeleteUserResponse{}, err
}

// Get connection
db, err := m.getConnection(ctx)
if err != nil {
return err
return newdbplugin.DeleteUserResponse{}, fmt.Errorf("unable to get connection: %w", err)
}

var result error
merr := &multierror.Error{}

// Execute each query
for _, stmt := range statements.Revocation {
for _, stmt := range req.Statements.Commands {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}

m := map[string]string{
"name": username,
"name": req.Username,
}
if err := dbtxn.ExecuteDBQuery(ctx, db, m, query); err != nil {
result = multierror.Append(result, err)
merr = multierror.Append(merr, err)
}
}
}

return result
return newdbplugin.DeleteUserResponse{}, merr.ErrorOrNil()
}

func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
Expand Down Expand Up @@ -297,101 +290,51 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
return nil
}

func (m *MSSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
m.Lock()
defer m.Unlock()

if len(m.Username) == 0 || len(m.Password) == 0 {
return nil, errors.New("username and password are required to rotate")
}

rotateStatents := statements
if len(rotateStatents) == 0 {
rotateStatents = []string{alterLoginSQL}
func (m *MSSQL) UpdateUser(ctx context.Context, req newdbplugin.UpdateUserRequest) (newdbplugin.UpdateUserResponse, error) {
if req.Password == nil && req.Expiration == nil {
return newdbplugin.UpdateUserResponse{}, fmt.Errorf("no changes requested")
}

db, err := m.getConnection(ctx)
if err != nil {
return nil, err
if req.Password != nil {
err := m.updateUserPass(ctx, req.Username, req.Password)
return newdbplugin.UpdateUserResponse{}, err
}

tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()

password, err := m.GeneratePassword()
if err != nil {
return nil, err
}

for _, stmt := range rotateStatents {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}

m := map[string]string{
"username": m.Username,
"password": password,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return nil, err
}
}
}

if err := tx.Commit(); err != nil {
return nil, err
}

if err := db.Close(); err != nil {
return nil, err
}

m.RawConfig["password"] = password
return m.RawConfig, nil
// Expiration is a no-op
return newdbplugin.UpdateUserResponse{}, nil
}

func (m *MSSQL) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
if len(statements.Rotation) == 0 {
statements.Rotation = []string{alterLoginSQL}
func (m *MSSQL) updateUserPass(ctx context.Context, username string, changePass *newdbplugin.ChangePassword) error {
stmts := changePass.Statements.Commands
if len(stmts) == 0 {
stmts = []string{alterLoginSQL}
}

username = staticUser.Username
password = staticUser.Password
password := changePass.NewPassword

if username == "" || password == "" {
return "", "", errors.New("must provide both username and password")
return errors.New("must provide both username and password")
}

m.Lock()
defer m.Unlock()

db, err := m.getConnection(ctx)
if err != nil {
return "", "", err
return err
}

var exists bool

err = db.QueryRowContext(ctx, "SELECT 1 FROM master.sys.server_principals where name = N'$1'", username).Scan(&exists)

if err != nil && err != sql.ErrNoRows {
return "", "", err
return err
}

stmts := statements.Rotation

// Start a transaction
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return "", "", err
return err
}

defer func() {
_ = tx.Rollback()
}()
Expand All @@ -409,16 +352,16 @@ func (m *MSSQL) SetCredentials(ctx context.Context, statements dbplugin.Statemen
"password": password,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return "", "", err
return fmt.Errorf("failed to execute query: %w", err)
}
}
}

if err := tx.Commit(); err != nil {
return "", "", err
return fmt.Errorf("failed to commit transaction: %w", err)
}

return username, password, nil
return nil
}

const dropUserSQL = `
Expand Down
Loading