diff --git a/driver/config/config.go b/driver/config/config.go index 92df3ca5630c..b03b6a247d2f 100644 --- a/driver/config/config.go +++ b/driver/config/config.go @@ -10,6 +10,7 @@ import ( "net/url" "os" "runtime" + "strings" "time" "github.com/markbates/pkger" @@ -561,13 +562,29 @@ func (p *Config) CourierTemplatesRoot() string { return p.p.StringF(ViperKeyCourierTemplatesPath, "/courier/template/templates") } +func splitUrlAndFragment(s string) (string, string) { + i := strings.IndexByte(s, '#') + if i < 0 { + return s, "" + } + return s[:i], s[i+1:] +} + func (p *Config) parseURIOrFail(key string) *url.URL { - u, err := url.ParseRequestURI(p.p.String(key)) + u, frag := splitUrlAndFragment(p.p.String(key)) + url, err := url.ParseRequestURI(u) if err != nil { p.l.WithError(errors.WithStack(err)). Fatalf("Configuration value from key %s is not a valid URL: %s", key, p.p.String(key)) } - return u + if url.Host == "" || url.Scheme == "" { + p.l.Fatalf("Configuration value from key %s is not a valid URL: %s", key, p.p.String(key)) + } + + if frag != "" { + url.Fragment = frag + } + return url } func (p *Config) Tracing() *tracing.Config { diff --git a/driver/config/config_test.go b/driver/config/config_test.go index 755b46e25d92..b87465073820 100644 --- a/driver/config/config_test.go +++ b/driver/config/config_test.go @@ -11,6 +11,7 @@ import ( "github.com/ory/x/logrusx" "github.com/ory/x/urlx" + "github.com/sirupsen/logrus/hooks/test" _ "github.com/ory/jsonschema/v3/fileloader" @@ -44,6 +45,47 @@ func TestViperProvider(t *testing.T) { "http://return-to-1-test.ory.sh/", "http://return-to-2-test.ory.sh/", }, ds) + + pWithFragments := config.MustNew(logrusx.New("", ""), + configx.WithValues(map[string]interface{}{ + config.ViperKeySelfServiceLoginUI: "http://test.kratos.ory.sh/#/login", + config.ViperKeySelfServiceSettingsURL: "http://test.kratos.ory.sh/#/settings", + config.ViperKeySelfServiceRegistrationUI: "http://test.kratos.ory.sh/#/register", + config.ViperKeySelfServiceErrorUI: "http://test.kratos.ory.sh/#/error", + }), + configx.SkipValidation(), + ) + + assert.Equal(t, "http://test.kratos.ory.sh/#/login", pWithFragments.SelfServiceFlowLoginUI().String()) + assert.Equal(t, "http://test.kratos.ory.sh/#/settings", pWithFragments.SelfServiceFlowSettingsUI().String()) + assert.Equal(t, "http://test.kratos.ory.sh/#/register", pWithFragments.SelfServiceFlowRegistrationUI().String()) + assert.Equal(t, "http://test.kratos.ory.sh/#/error", pWithFragments.SelfServiceFlowErrorURL().String()) + + for _, v := range []string{ + "#/login", + "/login", + "/", + "test.kratos.ory.sh/login", + } { + + logger := logrusx.New("", "") + logger.Logger.ExitFunc = func(code int) { panic("") } + hook := new(test.Hook) + logger.Logger.Hooks.Add(hook) + + pWithIncorrectUrls := config.MustNew(logger, + configx.WithValues(map[string]interface{}{ + config.ViperKeySelfServiceLoginUI: v, + }), + configx.SkipValidation(), + ) + + assert.Panics(t, func() { pWithIncorrectUrls.SelfServiceFlowLoginUI() }) + + assert.Equal(t, logrus.FatalLevel, hook.LastEntry().Level) + assert.Equal(t, "Configuration value from key selfservice.flows.login.ui_url is not a valid URL: "+v, hook.LastEntry().Message) + assert.Equal(t, 1, len(hook.Entries)) + } }) t.Run("group=default_return_to", func(t *testing.T) {