Skip to content

Commit

Permalink
feat: add db.ignore_unknown_table_columns configuration property (#…
Browse files Browse the repository at this point in the history
…3192) (#3193)

The property allows to ignore scan errors when columns in the SQL result have no fields in the destination struct.
  • Loading branch information
mih-kopylov authored Oct 4, 2022
1 parent 3ba28f2 commit 5842946
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 28 deletions.
47 changes: 35 additions & 12 deletions driver/config/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ const (
KeyAdminURL = "urls.self.admin"
KeyIssuerURL = "urls.self.issuer"
KeyAccessTokenStrategy = "strategies.access_token"
KeyDbIgnoreUnknownTableColumns = "db.ignore_unknown_table_columns"
KeySubjectIdentifierAlgorithmSalt = "oidc.subject_identifiers.pairwise.salt"
KeyPublicAllowDynamicRegistration = "oidc.dynamic_client_registration.enabled"
KeyPKCEEnforced = "oauth2.pkce.enforced"
Expand Down Expand Up @@ -144,12 +145,14 @@ func (p *DefaultProvider) getProvider(ctx context.Context) *configx.Provider {
}

func New(ctx context.Context, l *logrusx.Logger, opts ...configx.OptionModifier) (*DefaultProvider, error) {
opts = append([]configx.OptionModifier{
configx.WithStderrValidationReporter(),
configx.OmitKeysFromTracing("dsn", "secrets.system", "secrets.cookie"),
configx.WithImmutables("log", "serve", "dsn", "profiling"),
configx.WithLogrusWatcher(l),
}, opts...)
opts = append(
[]configx.OptionModifier{
configx.WithStderrValidationReporter(),
configx.OmitKeysFromTracing("dsn", "secrets.system", "secrets.cookie"),
configx.WithImmutables("log", "serve", "dsn", "profiling"),
configx.WithLogrusWatcher(l),
}, opts...,
)

p, err := configx.New(ctx, spec.ConfigValidationSchema, opts...)
if err != nil {
Expand Down Expand Up @@ -217,13 +220,17 @@ func (p *DefaultProvider) SubjectTypesSupported(ctx context.Context) []string {
if stringslice.Has(types, "pairwise") {
if p.AccessTokenStrategy(ctx) == AccessTokenJWTStrategy {
p.l.Warn(`The pairwise subject identifier algorithm is not supported by the JWT OAuth 2.0 Access Token Strategy and is thus being disabled. Please remove "pairwise" from oidc.subject_identifiers.supported_types" (e.g. oidc.subject_identifiers.supported_types=public) or set strategies.access_token to "opaque".`)
types = stringslice.Filter(types,
types = stringslice.Filter(
types,
func(s string) bool {
return !(s == "public")
},
)
} else if len(p.SubjectIdentifierAlgorithmSalt(ctx)) < 8 {
p.l.Fatalf(`The pairwise subject identifier algorithm was set but length of oidc.subject_identifier.salt is too small (%d < 8), please set oidc.subject_identifiers.pairwise.salt to a random string with 8 characters or more.`, len(p.SubjectIdentifierAlgorithmSalt(ctx)))
p.l.Fatalf(
`The pairwise subject identifier algorithm was set but length of oidc.subject_identifier.salt is too small (%d < 8), please set oidc.subject_identifiers.pairwise.salt to a random string with 8 characters or more.`,
len(p.SubjectIdentifierAlgorithmSalt(ctx)),
)
}
}

Expand Down Expand Up @@ -317,7 +324,11 @@ func (p *DefaultProvider) GetCookieSecrets(ctx context.Context) [][]byte {
}

func (p *DefaultProvider) LogoutRedirectURL(ctx context.Context) *url.URL {
return urlRoot(p.getProvider(ctx).RequestURIF(KeyLogoutRedirectURL, p.publicFallbackURL(ctx, "oauth2/fallbacks/logout/callback")))
return urlRoot(
p.getProvider(ctx).RequestURIF(
KeyLogoutRedirectURL, p.publicFallbackURL(ctx, "oauth2/fallbacks/logout/callback"),
),
)
}

func (p *DefaultProvider) publicFallbackURL(ctx context.Context, path string) *url.URL {
Expand Down Expand Up @@ -361,11 +372,17 @@ func (p *DefaultProvider) PublicURL(ctx context.Context) *url.URL {
}

func (p *DefaultProvider) AdminURL(ctx context.Context) *url.URL {
return urlRoot(p.getProvider(ctx).RequestURIF(KeyAdminURL, p.fallbackURL(ctx, "/", p.host(AdminInterface), p.port(AdminInterface))))
return urlRoot(
p.getProvider(ctx).RequestURIF(
KeyAdminURL, p.fallbackURL(ctx, "/", p.host(AdminInterface), p.port(AdminInterface)),
),
)
}

func (p *DefaultProvider) IssuerURL(ctx context.Context) *url.URL {
return p.getProvider(ctx).RequestURIF(KeyIssuerURL, p.fallbackURL(ctx, "/", p.host(PublicInterface), p.port(PublicInterface)))
return p.getProvider(ctx).RequestURIF(
KeyIssuerURL, p.fallbackURL(ctx, "/", p.host(PublicInterface), p.port(PublicInterface)),
)
}

func (p *DefaultProvider) OAuth2ClientRegistrationURL(ctx context.Context) *url.URL {
Expand Down Expand Up @@ -402,6 +419,10 @@ func (p *DefaultProvider) TokenRefreshHookURL(ctx context.Context) *url.URL {
return p.getProvider(ctx).RequestURIF(KeyRefreshTokenHookURL, nil)
}

func (p *DefaultProvider) DbIgnoreUnknownTableColumns() bool {
return p.p.Bool(KeyDbIgnoreUnknownTableColumns)
}

func (p *DefaultProvider) SubjectIdentifierAlgorithmSalt(ctx context.Context) string {
return p.getProvider(ctx).String(KeySubjectIdentifierAlgorithmSalt)
}
Expand All @@ -425,7 +446,9 @@ func (p *DefaultProvider) OIDCDiscoverySupportedScope(ctx context.Context) []str
}

func (p *DefaultProvider) OIDCDiscoveryUserinfoEndpoint(ctx context.Context) *url.URL {
return p.getProvider(ctx).RequestURIF(KeyOIDCDiscoveryUserinfoEndpoint, urlx.AppendPaths(p.PublicURL(ctx), "/userinfo"))
return p.getProvider(ctx).RequestURIF(
KeyOIDCDiscoveryUserinfoEndpoint, urlx.AppendPaths(p.PublicURL(ctx), "/userinfo"),
)
}

func (p *DefaultProvider) GetSendDebugMessagesToClients(ctx context.Context) bool {
Expand Down
37 changes: 23 additions & 14 deletions driver/registry_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@ var defaultInitialPing = func(m *RegistrySQL) error {
}

func init() {
dbal.RegisterDriver(func() dbal.Driver {
return NewRegistrySQL()
})
dbal.RegisterDriver(
func() dbal.Driver {
return NewRegistrySQL()
},
)
}

func NewRegistrySQL() *RegistrySQL {
Expand Down Expand Up @@ -88,7 +90,9 @@ func (m *RegistrySQL) determineNetwork(c *pop.Connection, ctx context.Context) (
return networkx.NewManager(c, m.Logger(), m.Tracer(ctx)).Determine(ctx)
}

func (m *RegistrySQL) Init(ctx context.Context, skipNetworkInit bool, migrate bool, ctxer contextx.Contextualizer) error {
func (m *RegistrySQL) Init(
ctx context.Context, skipNetworkInit bool, migrate bool, ctxer contextx.Contextualizer,
) error {
if m.persister == nil {
m.WithContextualizer(ctxer)
var opts []instrumentedsql.Opt
Expand All @@ -99,16 +103,21 @@ func (m *RegistrySQL) Init(ctx context.Context, skipNetworkInit bool, migrate bo
}

// new db connection
pool, idlePool, connMaxLifetime, connMaxIdleTime, cleanedDSN := sqlcon.ParseConnectionOptions(m.l, m.Config().DSN())
c, err := pop.NewConnection(&pop.ConnectionDetails{
URL: sqlcon.FinalizeDSN(m.l, cleanedDSN),
IdlePool: idlePool,
ConnMaxLifetime: connMaxLifetime,
ConnMaxIdleTime: connMaxIdleTime,
Pool: pool,
UseInstrumentedDriver: m.Tracer(ctx).IsLoaded(),
InstrumentedDriverOptions: opts,
})
pool, idlePool, connMaxLifetime, connMaxIdleTime, cleanedDSN := sqlcon.ParseConnectionOptions(
m.l, m.Config().DSN(),
)
c, err := pop.NewConnection(
&pop.ConnectionDetails{
URL: sqlcon.FinalizeDSN(m.l, cleanedDSN),
IdlePool: idlePool,
ConnMaxLifetime: connMaxLifetime,
ConnMaxIdleTime: connMaxIdleTime,
Pool: pool,
UseInstrumentedDriver: m.Tracer(ctx).IsLoaded(),
InstrumentedDriverOptions: opts,
Unsafe: m.Config().DbIgnoreUnknownTableColumns(),
},
)
if err != nil {
return errorsx.WithStack(err)
}
Expand Down
55 changes: 53 additions & 2 deletions driver/registry_sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@ package driver

import (
"context"
"math/rand"
"strconv"
"testing"

"github.com/stretchr/testify/assert"

"github.com/ory/x/errorsx"

"github.com/ory/hydra/client"
"github.com/ory/hydra/driver/config"
"github.com/ory/hydra/persistence/sql"
"github.com/ory/x/configx"
"github.com/ory/x/contextx"
"github.com/ory/x/errorsx"
"github.com/ory/x/logrusx"
"github.com/ory/x/sqlcon/dockertest"
)

func TestDefaultKeyManager_HsmDisabled(t *testing.T) {
Expand All @@ -31,6 +34,54 @@ func TestDefaultKeyManager_HsmDisabled(t *testing.T) {
assert.IsType(t, &sql.Persister{}, reg.SoftwareKeyManager())
}

func TestDbUnknownTableColumns(t *testing.T) {
tests := []struct {
name string
flagValue string
expectError bool
expectedSize int
}{
{name: "with unsafe", flagValue: "true", expectError: false, expectedSize: 1},
{name: "without unsafe", flagValue: "false", expectError: true, expectedSize: 0},
}

for _, test := range tests {
t.Run(
test.name, func(t *testing.T) {
ctx := context.Background()
l := logrusx.New("", "")
c := config.MustNew(ctx, l, configx.SkipValidation())
postgresDsn := dockertest.RunTestPostgreSQL(t)
c.MustSet(ctx, config.KeyDSN, postgresDsn)
c.MustSet(ctx, config.KeyDbIgnoreUnknownTableColumns, test.flagValue)
reg, err := NewRegistryFromDSN(ctx, c, l, false, true, &contextx.Default{})
assert.NoError(t, err)

statement := "ALTER TABLE \"hydra_client\" ADD COLUMN \"temp_column\" VARCHAR(128) NOT NULL DEFAULT '';"
err = reg.Persister().Connection(ctx).RawQuery(statement).Exec()
assert.NoError(t, err)

cl := &client.Client{
LegacyClientID: strconv.Itoa(rand.Int()),
}

err = reg.Persister().CreateClient(ctx, cl)
assert.NoError(t, err)

readClients := make([]client.Client, 0)
err = reg.Persister().Connection(ctx).RawQuery("SELECT * FROM \"hydra_client\"").All(&readClients)
if test.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), "missing destination name temp_column")
} else {
assert.NoError(t, err)
}
assert.Len(t, readClients, test.expectedSize)
},
)
}
}

func sussessfulPing() func(r *RegistrySQL) error {
return func(r *RegistrySQL) error {
// fake that ping is successful
Expand Down
12 changes: 12 additions & 0 deletions spec/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,18 @@
}
},
"properties": {
"db": {
"type": "object",
"additionalProperties": false,
"description": "Configures the database connection",
"properties": {
"ignore_unknown_table_columns": {
"type": "boolean",
"description": "Ignore scan errors when columns in the SQL result have no fields in the destination struct",
"default": false
}
}
},
"log": {
"type": "object",
"additionalProperties": false,
Expand Down

0 comments on commit 5842946

Please sign in to comment.