Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(selfservice): Login self service flow with TOTP does not pass on return_to URL #2175

Merged
merged 4 commits into from
Feb 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions selfservice/flow/login/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,18 @@ func NewHookExecutor(d executorDependencies) *HookExecutor {
return &HookExecutor{d: d}
}

func (e *HookExecutor) requiresAAL2(r *http.Request, s *session.Session) (*session.ErrAALNotSatisfied, bool) {
func (e *HookExecutor) requiresAAL2(r *http.Request, s *session.Session, a *Flow) (*session.ErrAALNotSatisfied, bool) {
var aalErr *session.ErrAALNotSatisfied
err := e.d.SessionManager().DoesSessionSatisfy(r, s, e.d.Config(r.Context()).SessionWhoAmIAAL())
return aalErr, errors.As(err, &aalErr)
if ok := errors.As(err, &aalErr); !ok {
return nil, false
}

if err := aalErr.PassReturnToParameter(a.RequestURL); err != nil {
return nil, false
}

return aalErr, true
}

func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, a *Flow, i *identity.Identity, s *session.Session) error {
Expand Down Expand Up @@ -127,7 +135,7 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, a *
Info("Identity authenticated successfully and was issued an Ory Kratos Session Token.")

response := &APIFlowResponse{Session: s, Token: s.Token}
if _, required := e.requiresAAL2(r, s); required {
if _, required := e.requiresAAL2(r, s, a); required {
// If AAL is not satisfied, we omit the identity to preserve the user's privacy in case of a phishing attack.
response.Session.Identity = nil
}
Expand All @@ -151,7 +159,7 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, a *
s.Token = ""

response := &APIFlowResponse{Session: s}
if _, required := e.requiresAAL2(r, s); required {
if _, required := e.requiresAAL2(r, s, a); required {
// If AAL is not satisfied, we omit the identity to preserve the user's privacy in case of a phishing attack.
response.Session.Identity = nil
}
Expand All @@ -160,7 +168,7 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, a *
}

// If we detect that whoami would require a higher AAL, we redirect!
if aalErr, required := e.requiresAAL2(r, s); required {
if aalErr, required := e.requiresAAL2(r, s, a); required {
http.Redirect(w, r, aalErr.RedirectTo, http.StatusSeeOther)
return nil
}
Expand Down
90 changes: 62 additions & 28 deletions selfservice/strategy/totp/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package totp_test
import (
"bytes"
"context"
_ "embed"
"fmt"
"net/http"
"net/url"
Expand All @@ -22,6 +23,8 @@ import (

stdtotp "github.com/pquerna/otp/totp"

"github.com/ory/x/sqlxx"

"github.com/ory/kratos/driver"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
Expand All @@ -30,21 +33,18 @@ import (
"github.com/ory/kratos/selfservice/flow/login"
"github.com/ory/kratos/selfservice/strategy/totp"
"github.com/ory/kratos/x"
"github.com/ory/x/sqlxx"

_ "embed"
)

const totpCodeGJSONQuery = "ui.nodes.#(attributes.name==totp_code)"

func createIdentityWithoutTOTP(t *testing.T, reg driver.Registry) *identity.Identity {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
delete(id.Credentials, identity.CredentialsTypeTOTP)
require.NoError(t, reg.PrivilegedIdentityPool().UpdateIdentity(context.Background(), id))
return id
}

func createIdentity(t *testing.T, reg driver.Registry) (*identity.Identity, *otp.Key) {
func createIdentity(t *testing.T, reg driver.Registry) (*identity.Identity, string, *otp.Key) {
identifier := x.NewUUID().String() + "@ory.sh"
password := x.NewUUID().String()
key, err := totp.NewKey(context.Background(), "foo", reg)
Expand Down Expand Up @@ -75,13 +75,14 @@ func createIdentity(t *testing.T, reg driver.Registry) (*identity.Identity, *otp
},
}
require.NoError(t, reg.PrivilegedIdentityPool().UpdateIdentity(context.Background(), i))
return i, key
return i, password, key
}

func TestCompleteLogin(t *testing.T) {
conf, reg := internal.NewFastRegistryWithMocks(t)
conf.MustSet(config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypePassword), map[string]interface{}{"enabled": true})
conf.MustSet(config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypeTOTP), map[string]interface{}{"enabled": true})
conf.MustSet(config.ViperKeyURLsWhitelistedReturnToDomains, []string{"https://www.ory.sh"})

router := x.NewRouterPublic()
publicTS, _ := testhelpers.NewKratosServerWithRouters(t, reg, router, x.NewRouterAdmin())
Expand All @@ -98,7 +99,7 @@ func TestCompleteLogin(t *testing.T) {
conf.MustSet(config.ViperKeySecretsDefault, []string{"not-a-secure-session-key"})

t.Run("case=totp payload is set when identity has totp", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)

apiClient := testhelpers.NewHTTPClientWithIdentitySessionToken(t, reg, id)
f := testhelpers.InitializeLoginFlowViaAPI(t, apiClient, publicTS, false, testhelpers.InitFlowWithAAL(identity.AuthenticatorAssuranceLevel2))
Expand All @@ -116,7 +117,7 @@ func TestCompleteLogin(t *testing.T) {
})

t.Run("case=should show the error ui because the request payload is malformed", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)

t.Run("type=api", func(t *testing.T) {
apiClient := testhelpers.NewHTTPClientWithIdentitySessionToken(t, reg, id)
Expand Down Expand Up @@ -159,9 +160,15 @@ func TestCompleteLogin(t *testing.T) {
return testhelpers.LoginMakeRequest(t, true, false, f, apiClient, payload)
}

doBrowserFlow := func(t *testing.T, spa bool, v func(url.Values), id *identity.Identity) (string, *http.Response) {
doBrowserFlow := func(t *testing.T, spa bool, v func(url.Values), id *identity.Identity, returnTo string) (string, *http.Response) {
browserClient := testhelpers.NewHTTPClientWithIdentitySessionCookie(t, reg, id)
f := testhelpers.InitializeLoginFlowViaBrowser(t, browserClient, publicTS, false, spa, testhelpers.InitFlowWithAAL(identity.AuthenticatorAssuranceLevel2))

opts := []testhelpers.InitFlowWithOption{testhelpers.InitFlowWithAAL(identity.AuthenticatorAssuranceLevel2)}
if len(returnTo) > 0 {
opts = append(opts, testhelpers.InitFlowWithReturnTo(returnTo))
}

f := testhelpers.InitializeLoginFlowViaBrowser(t, browserClient, publicTS, false, spa, opts...)
values := testhelpers.SDKFormFieldsToURLValues(f.Ui.Nodes)
values.Set("method", "totp")
v(values)
Expand All @@ -177,7 +184,7 @@ func TestCompleteLogin(t *testing.T) {
}

t.Run("case=should fail if code is empty", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
payload := func(v url.Values) {
v.Set("totp_code", "")
}
Expand All @@ -194,18 +201,18 @@ func TestCompleteLogin(t *testing.T) {
})

t.Run("type=browser", func(t *testing.T) {
body, res := doBrowserFlow(t, false, payload, id)
body, res := doBrowserFlow(t, false, payload, id, "")
check(t, true, body, res)
})

t.Run("type=spa", func(t *testing.T) {
body, res := doBrowserFlow(t, true, payload, id)
body, res := doBrowserFlow(t, true, payload, id, "")
check(t, false, body, res)
})
})

t.Run("case=should fail if code is invalid", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
payload := func(v url.Values) {
v.Set("totp_code", "111111")
}
Expand All @@ -222,18 +229,18 @@ func TestCompleteLogin(t *testing.T) {
})

t.Run("type=browser", func(t *testing.T) {
body, res := doBrowserFlow(t, false, payload, id)
body, res := doBrowserFlow(t, false, payload, id, "")
check(t, true, body, res)
})

t.Run("type=spa", func(t *testing.T) {
body, res := doBrowserFlow(t, true, payload, id)
body, res := doBrowserFlow(t, true, payload, id, "")
check(t, false, body, res)
})
})

t.Run("case=should fail if code is too long", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
payload := func(v url.Values) {
v.Set("totp_code", "1111111111")
}
Expand All @@ -250,12 +257,12 @@ func TestCompleteLogin(t *testing.T) {
})

t.Run("type=browser", func(t *testing.T) {
body, res := doBrowserFlow(t, false, payload, id)
body, res := doBrowserFlow(t, false, payload, id, "")
check(t, true, body, res)
})

t.Run("type=spa", func(t *testing.T) {
body, res := doBrowserFlow(t, true, payload, id)
body, res := doBrowserFlow(t, true, payload, id, "")
check(t, false, body, res)
})
})
Expand All @@ -279,18 +286,18 @@ func TestCompleteLogin(t *testing.T) {
})

t.Run("type=browser", func(t *testing.T) {
body, res := doBrowserFlow(t, false, payload, id)
body, res := doBrowserFlow(t, false, payload, id, "")
check(t, true, body, res)
})

t.Run("type=spa", func(t *testing.T) {
body, res := doBrowserFlow(t, true, payload, id)
body, res := doBrowserFlow(t, true, payload, id, "")
check(t, false, body, res)
})
})

t.Run("case=should pass when TOTP is supplied correctly", func(t *testing.T) {
id, key := createIdentity(t, reg)
id, _, key := createIdentity(t, reg)
code, err := stdtotp.GenerateCode(key.Secret(), time.Now())
require.NoError(t, err)
payload := func(v url.Values) {
Expand Down Expand Up @@ -323,12 +330,19 @@ func TestCompleteLogin(t *testing.T) {
})

t.Run("type=browser", func(t *testing.T) {
body, res := doBrowserFlow(t, false, payload, id)
body, res := doBrowserFlow(t, false, payload, id, "")
check(t, true, body, res)
})

t.Run("type=browser set return_to", func(t *testing.T) {
returnTo := "https://www.ory.sh"
_, res := doBrowserFlow(t, false, payload, id, returnTo)
t.Log(res.Request.URL.String())
assert.Contains(t, res.Request.URL.String(), returnTo)
})

t.Run("type=spa", func(t *testing.T) {
body, res := doBrowserFlow(t, true, payload, id)
body, res := doBrowserFlow(t, true, payload, id, "")
check(t, false, body, res)
})
})
Expand Down Expand Up @@ -356,7 +370,7 @@ func TestCompleteLogin(t *testing.T) {
})

t.Run("case=should pass without csrf if API flow", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
body, res := doAPIFlow(t, func(v url.Values) {
v.Del("csrf_token")
v.Set("totp_code", "111111")
Expand All @@ -367,12 +381,12 @@ func TestCompleteLogin(t *testing.T) {
})

t.Run("case=should fail if CSRF token is invalid", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
t.Run("type=browser", func(t *testing.T) {
body, res := doBrowserFlow(t, false, func(v url.Values) {
v.Del("csrf_token")
v.Set("totp_code", "111111")
}, id)
}, id, "")

assert.Contains(t, res.Request.URL.String(), errTS.URL)
assert.Equal(t, x.ErrInvalidCSRFToken.Reason(), gjson.Get(body, "reason").String(), body)
Expand All @@ -382,10 +396,30 @@ func TestCompleteLogin(t *testing.T) {
body, res := doBrowserFlow(t, true, func(v url.Values) {
v.Del("csrf_token")
v.Set("totp_code", "111111")
}, id)
}, id, "")

assert.Contains(t, res.Request.URL.String(), publicTS.URL+login.RouteSubmitFlow)
assert.Equal(t, x.ErrInvalidCSRFToken.Reason(), gjson.Get(body, "error.reason").String(), body)
})
})

t.Run("case=should pass return_to URL after login", func(t *testing.T) {
id, pwd, _ := createIdentity(t, reg)

t.Run("type=browser", func(t *testing.T) {
returnTo := "https://www.ory.sh"
browserClient := testhelpers.NewClientWithCookies(t)
f := testhelpers.InitializeLoginFlowViaBrowser(t, browserClient, publicTS, false, false, testhelpers.InitFlowWithReturnTo(returnTo))

cred, ok := id.GetCredentials(identity.CredentialsTypePassword)
require.True(t, ok)
values := url.Values{"method": {"password"}, "password_identifier": {cred.Identifiers[0]},
"password": {pwd}, "csrf_token": {x.FakeCSRFToken}}.Encode()

body, res := testhelpers.LoginMakeRequest(t, false, false, f, browserClient, values)
require.Contains(t, res.Request.URL.Path, "login", "%s", res.Request.URL.String())
assert.Equal(t, gjson.Get(body, "requested_aal").String(), "aal2", "%s", body)
assert.Equal(t, gjson.Get(body, "return_to").String(), returnTo, "%s", body)
})
})
}
20 changes: 10 additions & 10 deletions selfservice/strategy/totp/settings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package totp_test

import (
"context"
_ "embed"
"encoding/json"
"net/http"
"net/url"
Expand All @@ -23,18 +24,17 @@ import (
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"

"github.com/ory/kratos/selfservice/flow/settings"
"github.com/ory/kratos/text"
"github.com/ory/x/assertx"
"github.com/ory/x/sqlcon"

"github.com/ory/kratos/selfservice/flow/settings"
"github.com/ory/kratos/text"

"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/internal"
"github.com/ory/kratos/internal/testhelpers"
"github.com/ory/kratos/x"

_ "embed"
)

func TestCompleteSettings(t *testing.T) {
Expand All @@ -58,7 +58,7 @@ func TestCompleteSettings(t *testing.T) {
conf.MustSet(config.ViperKeySecretsDefault, []string{"not-a-secure-session-key"})

t.Run("case=device unlinking is available when identity has totp", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)

apiClient := testhelpers.NewHTTPClientWithIdentitySessionToken(t, reg, id)
f := testhelpers.InitializeSettingsFlowViaAPI(t, apiClient, publicTS)
Expand All @@ -68,7 +68,7 @@ func TestCompleteSettings(t *testing.T) {
})

t.Run("case=device setup is available when identity has no totp yet", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
id.Credentials = nil
require.NoError(t, reg.PrivilegedIdentityPool().UpdateIdentity(context.Background(), id))

Expand Down Expand Up @@ -143,7 +143,7 @@ func TestCompleteSettings(t *testing.T) {
conf.MustSet(config.ViperKeySelfServiceSettingsPrivilegedAuthenticationAfter, "5m")
})

id, key := createIdentity(t, reg)
id, _, key := createIdentity(t, reg)
var payload = func(v url.Values) {
v.Set("totp_unlink", "true")
}
Expand Down Expand Up @@ -231,7 +231,7 @@ func TestCompleteSettings(t *testing.T) {
}

t.Run("type=api", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
actual, res := doAPIFlow(t, payload, id)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Contains(t, res.Request.URL.String(), publicTS.URL+settings.RouteSubmitFlow)
Expand All @@ -240,7 +240,7 @@ func TestCompleteSettings(t *testing.T) {
})

t.Run("type=spa", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
actual, res := doBrowserFlow(t, true, payload, id)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Contains(t, res.Request.URL.String(), publicTS.URL+settings.RouteSubmitFlow)
Expand All @@ -249,7 +249,7 @@ func TestCompleteSettings(t *testing.T) {
})

t.Run("type=browser", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
actual, res := doBrowserFlow(t, false, payload, id)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Contains(t, res.Request.URL.String(), uiTS.URL)
Expand Down
Loading