diff --git a/common/persistence/sql/sqlplugin/interfaces.go b/common/persistence/sql/sqlplugin/interfaces.go index e32993f824c..7fb29c5cf7b 100644 --- a/common/persistence/sql/sqlplugin/interfaces.go +++ b/common/persistence/sql/sqlplugin/interfaces.go @@ -146,3 +146,14 @@ type ( PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) } ) + +func (k DbKind) String() string { + switch k { + case DbKindMain: + return "main" + case DbKindVisibility: + return "visibility" + default: + return "unknown" + } +} diff --git a/common/persistence/sql/sqlplugin/mysql/plugin.go b/common/persistence/sql/sqlplugin/mysql/plugin.go index 5e083b05de3..44493e8b81c 100644 --- a/common/persistence/sql/sqlplugin/mysql/plugin.go +++ b/common/persistence/sql/sqlplugin/mysql/plugin.go @@ -53,7 +53,7 @@ func (p *plugin) CreateDB( cfg *config.SQL, r resolver.ServiceResolver, ) (sqlplugin.DB, error) { - conn, err := p.createDBConnection(cfg, r) + conn, err := p.createDBConnection(dbKind, cfg, r) if err != nil { return nil, err } @@ -67,7 +67,7 @@ func (p *plugin) CreateAdminDB( cfg *config.SQL, r resolver.ServiceResolver, ) (sqlplugin.AdminDB, error) { - conn, err := p.createDBConnection(cfg, r) + conn, err := p.createDBConnection(dbKind, cfg, r) if err != nil { return nil, err } @@ -80,10 +80,11 @@ func (p *plugin) CreateAdminDB( // SQL database and the object can be used to perform CRUD operations on // the tables in the database func (p *plugin) createDBConnection( + dbKind sqlplugin.DbKind, cfg *config.SQL, resolver resolver.ServiceResolver, ) (*sqlx.DB, error) { - mysqlSession, err := session.NewSession(cfg, resolver) + mysqlSession, err := session.NewSession(dbKind, cfg, resolver) if err != nil { return nil, err } diff --git a/common/persistence/sql/sqlplugin/mysql/session/session.go b/common/persistence/sql/sqlplugin/mysql/session/session.go index b9490d5043c..535d3973f24 100644 --- a/common/persistence/sql/sqlplugin/mysql/session/session.go +++ b/common/persistence/sql/sqlplugin/mysql/session/session.go @@ -27,6 +27,7 @@ package session import ( "crypto/tls" "crypto/x509" + "errors" "fmt" "os" "strings" @@ -37,9 +38,14 @@ import ( "go.temporal.io/server/common/auth" "go.temporal.io/server/common/config" + "go.temporal.io/server/common/persistence/sql/sqlplugin" "go.temporal.io/server/common/resolver" ) +type Session struct { + *sqlx.DB +} + const ( driverName = "mysql" @@ -48,22 +54,24 @@ const ( defaultIsolationLevel = "'READ-COMMITTED'" // customTLSName is the name used if a custom tls configuration is created customTLSName = "tls-custom" -) -var dsnAttrOverrides = map[string]string{ - "parseTime": "true", - "clientFoundRows": "true", -} + interpolateParamsAttr = "interpolateParams" +) -type Session struct { - *sqlx.DB -} +var ( + errVisInterpolateParamsNotSupported = errors.New("interpolateParams is not supported for mysql visibility stores") + dsnAttrOverrides = map[string]string{ + "parseTime": "true", + "clientFoundRows": "true", + } +) func NewSession( + dbKind sqlplugin.DbKind, cfg *config.SQL, resolver resolver.ServiceResolver, ) (*Session, error) { - db, err := createConnection(cfg, resolver) + db, err := createConnection(dbKind, cfg, resolver) if err != nil { return nil, err } @@ -77,6 +85,7 @@ func (s *Session) Close() { } func createConnection( + dbKind sqlplugin.DbKind, cfg *config.SQL, resolver resolver.ServiceResolver, ) (*sqlx.DB, error) { @@ -85,7 +94,12 @@ func createConnection( return nil, err } - db, err := sqlx.Connect(driverName, buildDSN(cfg, resolver)) + dsn, err := buildDSN(dbKind, cfg, resolver) + if err != nil { + return nil, err + } + + db, err := sqlx.Connect(driverName, dsn) if err != nil { return nil, err } @@ -104,7 +118,11 @@ func createConnection( return db, nil } -func buildDSN(cfg *config.SQL, r resolver.ServiceResolver) string { +func buildDSN( + dbKind sqlplugin.DbKind, + cfg *config.SQL, + r resolver.ServiceResolver, +) (string, error) { mysqlConfig := mysql.NewConfig() mysqlConfig.User = cfg.User @@ -112,7 +130,11 @@ func buildDSN(cfg *config.SQL, r resolver.ServiceResolver) string { mysqlConfig.Addr = r.Resolve(cfg.ConnectAddr)[0] mysqlConfig.DBName = cfg.DatabaseName mysqlConfig.Net = cfg.ConnectProtocol - mysqlConfig.Params = buildDSNAttrs(cfg) + var err error + mysqlConfig.Params, err = buildDSNAttrs(dbKind, cfg) + if err != nil { + return "", err + } // https://github.com/go-sql-driver/mysql/blob/v1.5.0/dsn.go#L104-L106 // https://github.com/go-sql-driver/mysql/blob/v1.5.0/dsn.go#L182-L189 @@ -124,10 +146,10 @@ func buildDSN(cfg *config.SQL, r resolver.ServiceResolver) string { // https://github.com/temporalio/temporal/issues/1703 mysqlConfig.RejectReadOnly = true - return mysqlConfig.FormatDSN() + return mysqlConfig.FormatDSN(), nil } -func buildDSNAttrs(cfg *config.SQL) map[string]string { +func buildDSNAttrs(dbKind sqlplugin.DbKind, cfg *config.SQL) (map[string]string, error) { attrs := make(map[string]string, len(dsnAttrOverrides)+len(cfg.ConnectAttributes)+1) for k, v := range cfg.ConnectAttributes { k1, v1 := sanitizeAttr(k, v) @@ -145,7 +167,13 @@ func buildDSNAttrs(cfg *config.SQL) map[string]string { attrs[k] = v } - return attrs + if dbKind == sqlplugin.DbKindVisibility { + if _, ok := attrs[interpolateParamsAttr]; ok { + return nil, errVisInterpolateParamsNotSupported + } + } + + return attrs, nil } func hasAttr(attrs map[string]string, key string) bool { diff --git a/common/persistence/sql/sqlplugin/mysql/session/session_test.go b/common/persistence/sql/sqlplugin/mysql/session/session_test.go index 6343daecba8..9efebb8bba2 100644 --- a/common/persistence/sql/sqlplugin/mysql/session/session_test.go +++ b/common/persistence/sql/sqlplugin/mysql/session/session_test.go @@ -25,6 +25,7 @@ package session import ( + "fmt" "net/url" "strings" "testing" @@ -33,6 +34,7 @@ import ( "github.com/stretchr/testify/suite" "go.temporal.io/server/common/config" + "go.temporal.io/server/common/persistence/sql/sqlplugin" "go.temporal.io/server/common/resolver" ) @@ -66,12 +68,15 @@ func (s *sessionTestSuite) TearDownTest() { func (s *sessionTestSuite) TestBuildDSN() { testCases := []struct { - in config.SQL - outURLPath string - outIsolationKey string - outIsolationVal string + name string + in config.SQL + outURLPath string + outIsolationKey string + outIsolationVal string + expectInvalidConfig bool }{ { + name: "no connect attributes", in: config.SQL{ User: "test", Password: "pass", @@ -84,6 +89,7 @@ func (s *sessionTestSuite) TestBuildDSN() { outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?", }, { + name: "with connect attributes", in: config.SQL{ User: "test", Password: "pass", @@ -97,6 +103,7 @@ func (s *sessionTestSuite) TestBuildDSN() { outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?", }, { + name: "override isolation level (quoted, shorthand)", in: config.SQL{ User: "test", Password: "pass", @@ -110,6 +117,7 @@ func (s *sessionTestSuite) TestBuildDSN() { outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?", }, { + name: "override isolation level (unquoted, shorthand)", in: config.SQL{ User: "test", Password: "pass", @@ -123,6 +131,7 @@ func (s *sessionTestSuite) TestBuildDSN() { outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?", }, { + name: "override isolation level (unquoted, full name)", in: config.SQL{ User: "test", Password: "pass", @@ -137,21 +146,45 @@ func (s *sessionTestSuite) TestBuildDSN() { }, } - for _, tc := range testCases { - r := resolver.NewMockServiceResolver(s.controller) - r.EXPECT().Resolve(tc.in.ConnectAddr).Return([]string{tc.in.ConnectAddr}) - - out := buildDSN(&tc.in, r) - s.True(strings.HasPrefix(out, tc.outURLPath), "invalid url path") - tokens := strings.Split(out, "?") - s.Equal(2, len(tokens), "invalid url") - qry, err := url.Parse("?" + tokens[1]) - s.NoError(err) - wantAttrs := buildExpectedURLParams(tc.in.ConnectAttributes, tc.outIsolationKey, tc.outIsolationVal) - s.Equal(wantAttrs, qry.Query(), "invalid dsn url params") + for _, dbKind := range []sqlplugin.DbKind{sqlplugin.DbKindMain, sqlplugin.DbKindVisibility} { + for _, tc := range testCases { + s.Run(fmt.Sprintf("%s: %s", dbKind.String(), tc.name), func() { + r := resolver.NewMockServiceResolver(s.controller) + r.EXPECT().Resolve(tc.in.ConnectAddr).Return([]string{tc.in.ConnectAddr}) + + out, err := buildDSN(dbKind, &tc.in, r) + if tc.expectInvalidConfig { + s.Error(err, "Expected an invalid configuration error") + } else { + s.NoError(err) + } + s.True(strings.HasPrefix(out, tc.outURLPath), "invalid url path") + tokens := strings.Split(out, "?") + s.Equal(2, len(tokens), "invalid url") + qry, err := url.Parse("?" + tokens[1]) + s.NoError(err) + wantAttrs := buildExpectedURLParams(tc.in.ConnectAttributes, tc.outIsolationKey, tc.outIsolationVal) + s.Equal(wantAttrs, qry.Query(), "invalid dsn url params") + }) + } } } +func (s *sessionTestSuite) Test_Visibility_DoesntSupport_interpolateParams() { + cfg := config.SQL{ + User: "test", + Password: "pass", + ConnectProtocol: "tcp", + ConnectAddr: "192.168.0.1:3306", + DatabaseName: "db1", + ConnectAttributes: map[string]string{"interpolateParams": "ignored"}, + } + r := resolver.NewMockServiceResolver(s.controller) + r.EXPECT().Resolve(cfg.ConnectAddr).Return([]string{cfg.ConnectAddr}) + _, err := buildDSN(sqlplugin.DbKindVisibility, &cfg, r) + s.Error(err, "We should return an error when a MySQL Visibility database is configured with interpolateParams") +} + func buildExpectedURLParams(attrs map[string]string, isolationKey string, isolationValue string) url.Values { result := make(map[string][]string, len(dsnAttrOverrides)+len(attrs)+1) for k, v := range attrs {