Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ban the connect attr interpolateParams for MySQL 8 Vis dbs #5441

Merged
merged 5 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions common/persistence/sql/sqlplugin/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,14 @@ type (
PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error)
}
)

func (k DbKind) String() string {
switch k {
case DbKindMain:
return "main"
case DbKindVisibility:
return "visibility"
default:
return "unknown"
}
}
7 changes: 4 additions & 3 deletions common/persistence/sql/sqlplugin/mysql/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (p *plugin) CreateDB(
cfg *config.SQL,
r resolver.ServiceResolver,
) (sqlplugin.DB, error) {
conn, err := p.createDBConnection(cfg, r)
conn, err := p.createDBConnection(dbKind, cfg, r)
if err != nil {
return nil, err
}
Expand All @@ -67,7 +67,7 @@ func (p *plugin) CreateAdminDB(
cfg *config.SQL,
r resolver.ServiceResolver,
) (sqlplugin.AdminDB, error) {
conn, err := p.createDBConnection(cfg, r)
conn, err := p.createDBConnection(dbKind, cfg, r)
if err != nil {
return nil, err
}
Expand All @@ -80,10 +80,11 @@ func (p *plugin) CreateAdminDB(
// SQL database and the object can be used to perform CRUD operations on
// the tables in the database
func (p *plugin) createDBConnection(
dbKind sqlplugin.DbKind,
cfg *config.SQL,
resolver resolver.ServiceResolver,
) (*sqlx.DB, error) {
mysqlSession, err := session.NewSession(cfg, resolver)
mysqlSession, err := session.NewSession(dbKind, cfg, resolver)
if err != nil {
return nil, err
}
Expand Down
58 changes: 43 additions & 15 deletions common/persistence/sql/sqlplugin/mysql/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ package session
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"os"
"strings"
Expand All @@ -37,9 +38,14 @@ import (

"go.temporal.io/server/common/auth"
"go.temporal.io/server/common/config"
"go.temporal.io/server/common/persistence/sql/sqlplugin"
"go.temporal.io/server/common/resolver"
)

type Session struct {
*sqlx.DB
}

const (
driverName = "mysql"

Expand All @@ -48,22 +54,24 @@ const (
defaultIsolationLevel = "'READ-COMMITTED'"
// customTLSName is the name used if a custom tls configuration is created
customTLSName = "tls-custom"
)

var dsnAttrOverrides = map[string]string{
"parseTime": "true",
"clientFoundRows": "true",
}
interpolateParamsAttr = "interpolateParams"
)

type Session struct {
*sqlx.DB
}
var (
errVisInterpolateParamsNotSupported = errors.New("interpolateParams is not supported for mysql visibility stores")
dsnAttrOverrides = map[string]string{
"parseTime": "true",
"clientFoundRows": "true",
}
)

func NewSession(
dbKind sqlplugin.DbKind,
cfg *config.SQL,
resolver resolver.ServiceResolver,
) (*Session, error) {
db, err := createConnection(cfg, resolver)
db, err := createConnection(dbKind, cfg, resolver)
if err != nil {
return nil, err
}
Expand All @@ -77,6 +85,7 @@ func (s *Session) Close() {
}

func createConnection(
dbKind sqlplugin.DbKind,
cfg *config.SQL,
resolver resolver.ServiceResolver,
) (*sqlx.DB, error) {
Expand All @@ -85,7 +94,12 @@ func createConnection(
return nil, err
}

db, err := sqlx.Connect(driverName, buildDSN(cfg, resolver))
dsn, err := buildDSN(dbKind, cfg, resolver)
if err != nil {
return nil, err
}

db, err := sqlx.Connect(driverName, dsn)
if err != nil {
return nil, err
}
Expand All @@ -104,15 +118,23 @@ func createConnection(
return db, nil
}

func buildDSN(cfg *config.SQL, r resolver.ServiceResolver) string {
func buildDSN(
dbKind sqlplugin.DbKind,
cfg *config.SQL,
r resolver.ServiceResolver,
) (string, error) {
mysqlConfig := mysql.NewConfig()

mysqlConfig.User = cfg.User
mysqlConfig.Passwd = cfg.Password
mysqlConfig.Addr = r.Resolve(cfg.ConnectAddr)[0]
mysqlConfig.DBName = cfg.DatabaseName
mysqlConfig.Net = cfg.ConnectProtocol
mysqlConfig.Params = buildDSNAttrs(cfg)
var err error
mysqlConfig.Params, err = buildDSNAttrs(dbKind, cfg)
if err != nil {
return "", err
}

// https://github.com/go-sql-driver/mysql/blob/v1.5.0/dsn.go#L104-L106
// https://github.com/go-sql-driver/mysql/blob/v1.5.0/dsn.go#L182-L189
Expand All @@ -124,10 +146,10 @@ func buildDSN(cfg *config.SQL, r resolver.ServiceResolver) string {
// https://github.com/temporalio/temporal/issues/1703
mysqlConfig.RejectReadOnly = true

return mysqlConfig.FormatDSN()
return mysqlConfig.FormatDSN(), nil
}

func buildDSNAttrs(cfg *config.SQL) map[string]string {
func buildDSNAttrs(dbKind sqlplugin.DbKind, cfg *config.SQL) (map[string]string, error) {
attrs := make(map[string]string, len(dsnAttrOverrides)+len(cfg.ConnectAttributes)+1)
for k, v := range cfg.ConnectAttributes {
k1, v1 := sanitizeAttr(k, v)
Expand All @@ -145,7 +167,13 @@ func buildDSNAttrs(cfg *config.SQL) map[string]string {
attrs[k] = v
}

return attrs
if dbKind == sqlplugin.DbKindVisibility {
if _, ok := attrs[interpolateParamsAttr]; ok {
return nil, errVisInterpolateParamsNotSupported
}
}

return attrs, nil
}

func hasAttr(attrs map[string]string, key string) bool {
Expand Down
65 changes: 49 additions & 16 deletions common/persistence/sql/sqlplugin/mysql/session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
package session

import (
"fmt"
"net/url"
"strings"
"testing"
Expand All @@ -33,6 +34,7 @@ import (
"github.com/stretchr/testify/suite"

"go.temporal.io/server/common/config"
"go.temporal.io/server/common/persistence/sql/sqlplugin"
"go.temporal.io/server/common/resolver"
)

Expand Down Expand Up @@ -66,12 +68,15 @@ func (s *sessionTestSuite) TearDownTest() {

func (s *sessionTestSuite) TestBuildDSN() {
testCases := []struct {
in config.SQL
outURLPath string
outIsolationKey string
outIsolationVal string
name string
in config.SQL
outURLPath string
outIsolationKey string
outIsolationVal string
expectInvalidConfig bool
}{
{
name: "no connect attributes",
in: config.SQL{
User: "test",
Password: "pass",
Expand All @@ -84,6 +89,7 @@ func (s *sessionTestSuite) TestBuildDSN() {
outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?",
},
{
name: "with connect attributes",
in: config.SQL{
User: "test",
Password: "pass",
Expand All @@ -97,6 +103,7 @@ func (s *sessionTestSuite) TestBuildDSN() {
outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?",
},
{
name: "override isolation level (quoted, shorthand)",
in: config.SQL{
User: "test",
Password: "pass",
Expand All @@ -110,6 +117,7 @@ func (s *sessionTestSuite) TestBuildDSN() {
outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?",
},
{
name: "override isolation level (unquoted, shorthand)",
in: config.SQL{
User: "test",
Password: "pass",
Expand All @@ -123,6 +131,7 @@ func (s *sessionTestSuite) TestBuildDSN() {
outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?",
},
{
name: "override isolation level (unquoted, full name)",
in: config.SQL{
User: "test",
Password: "pass",
Expand All @@ -137,21 +146,45 @@ func (s *sessionTestSuite) TestBuildDSN() {
},
}

for _, tc := range testCases {
r := resolver.NewMockServiceResolver(s.controller)
r.EXPECT().Resolve(tc.in.ConnectAddr).Return([]string{tc.in.ConnectAddr})

out := buildDSN(&tc.in, r)
s.True(strings.HasPrefix(out, tc.outURLPath), "invalid url path")
tokens := strings.Split(out, "?")
s.Equal(2, len(tokens), "invalid url")
qry, err := url.Parse("?" + tokens[1])
s.NoError(err)
wantAttrs := buildExpectedURLParams(tc.in.ConnectAttributes, tc.outIsolationKey, tc.outIsolationVal)
s.Equal(wantAttrs, qry.Query(), "invalid dsn url params")
for _, dbKind := range []sqlplugin.DbKind{sqlplugin.DbKindMain, sqlplugin.DbKindVisibility} {
for _, tc := range testCases {
s.Run(fmt.Sprintf("%s: %s", dbKind.String(), tc.name), func() {
r := resolver.NewMockServiceResolver(s.controller)
r.EXPECT().Resolve(tc.in.ConnectAddr).Return([]string{tc.in.ConnectAddr})

out, err := buildDSN(dbKind, &tc.in, r)
if tc.expectInvalidConfig {
s.Error(err, "Expected an invalid configuration error")
} else {
s.NoError(err)
}
s.True(strings.HasPrefix(out, tc.outURLPath), "invalid url path")
tokens := strings.Split(out, "?")
s.Equal(2, len(tokens), "invalid url")
qry, err := url.Parse("?" + tokens[1])
s.NoError(err)
wantAttrs := buildExpectedURLParams(tc.in.ConnectAttributes, tc.outIsolationKey, tc.outIsolationVal)
s.Equal(wantAttrs, qry.Query(), "invalid dsn url params")
})
}
}
}

func (s *sessionTestSuite) Test_Visibility_DoesntSupport_interpolateParams() {
cfg := config.SQL{
User: "test",
Password: "pass",
ConnectProtocol: "tcp",
ConnectAddr: "192.168.0.1:3306",
DatabaseName: "db1",
ConnectAttributes: map[string]string{"interpolateParams": "ignored"},
}
r := resolver.NewMockServiceResolver(s.controller)
r.EXPECT().Resolve(cfg.ConnectAddr).Return([]string{cfg.ConnectAddr})
_, err := buildDSN(sqlplugin.DbKindVisibility, &cfg, r)
s.Error(err, "We should return an error when a MySQL Visibility database is configured with interpolateParams")
}

func buildExpectedURLParams(attrs map[string]string, isolationKey string, isolationValue string) url.Values {
result := make(map[string][]string, len(dsnAttrOverrides)+len(attrs)+1)
for k, v := range attrs {
Expand Down
Loading