Skip to content

Commit

Permalink
fix(selfservice): login self service flow with TOTP does not pass on …
Browse files Browse the repository at this point in the history
…return_to URL (ory#2175)

Closes ory#2172
  • Loading branch information
sawadashota authored Feb 1, 2022
1 parent 8a0df62 commit edafab6
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 43 deletions.
18 changes: 13 additions & 5 deletions selfservice/flow/login/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,18 @@ func NewHookExecutor(d executorDependencies) *HookExecutor {
return &HookExecutor{d: d}
}

func (e *HookExecutor) requiresAAL2(r *http.Request, s *session.Session) (*session.ErrAALNotSatisfied, bool) {
func (e *HookExecutor) requiresAAL2(r *http.Request, s *session.Session, a *Flow) (*session.ErrAALNotSatisfied, bool) {
var aalErr *session.ErrAALNotSatisfied
err := e.d.SessionManager().DoesSessionSatisfy(r, s, e.d.Config(r.Context()).SessionWhoAmIAAL())
return aalErr, errors.As(err, &aalErr)
if ok := errors.As(err, &aalErr); !ok {
return nil, false
}

if err := aalErr.PassReturnToParameter(a.RequestURL); err != nil {
return nil, false
}

return aalErr, true
}

func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, a *Flow, i *identity.Identity, s *session.Session) error {
Expand Down Expand Up @@ -127,7 +135,7 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, a *
Info("Identity authenticated successfully and was issued an Ory Kratos Session Token.")

response := &APIFlowResponse{Session: s, Token: s.Token}
if _, required := e.requiresAAL2(r, s); required {
if _, required := e.requiresAAL2(r, s, a); required {
// If AAL is not satisfied, we omit the identity to preserve the user's privacy in case of a phishing attack.
response.Session.Identity = nil
}
Expand All @@ -151,7 +159,7 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, a *
s.Token = ""

response := &APIFlowResponse{Session: s}
if _, required := e.requiresAAL2(r, s); required {
if _, required := e.requiresAAL2(r, s, a); required {
// If AAL is not satisfied, we omit the identity to preserve the user's privacy in case of a phishing attack.
response.Session.Identity = nil
}
Expand All @@ -160,7 +168,7 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, a *
}

// If we detect that whoami would require a higher AAL, we redirect!
if aalErr, required := e.requiresAAL2(r, s); required {
if aalErr, required := e.requiresAAL2(r, s, a); required {
http.Redirect(w, r, aalErr.RedirectTo, http.StatusSeeOther)
return nil
}
Expand Down
90 changes: 62 additions & 28 deletions selfservice/strategy/totp/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package totp_test
import (
"bytes"
"context"
_ "embed"
"fmt"
"net/http"
"net/url"
Expand All @@ -22,6 +23,8 @@ import (

stdtotp "github.com/pquerna/otp/totp"

"github.com/ory/x/sqlxx"

"github.com/ory/kratos/driver"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
Expand All @@ -30,21 +33,18 @@ import (
"github.com/ory/kratos/selfservice/flow/login"
"github.com/ory/kratos/selfservice/strategy/totp"
"github.com/ory/kratos/x"
"github.com/ory/x/sqlxx"

_ "embed"
)

const totpCodeGJSONQuery = "ui.nodes.#(attributes.name==totp_code)"

func createIdentityWithoutTOTP(t *testing.T, reg driver.Registry) *identity.Identity {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
delete(id.Credentials, identity.CredentialsTypeTOTP)
require.NoError(t, reg.PrivilegedIdentityPool().UpdateIdentity(context.Background(), id))
return id
}

func createIdentity(t *testing.T, reg driver.Registry) (*identity.Identity, *otp.Key) {
func createIdentity(t *testing.T, reg driver.Registry) (*identity.Identity, string, *otp.Key) {
identifier := x.NewUUID().String() + "@ory.sh"
password := x.NewUUID().String()
key, err := totp.NewKey(context.Background(), "foo", reg)
Expand Down Expand Up @@ -75,13 +75,14 @@ func createIdentity(t *testing.T, reg driver.Registry) (*identity.Identity, *otp
},
}
require.NoError(t, reg.PrivilegedIdentityPool().UpdateIdentity(context.Background(), i))
return i, key
return i, password, key
}

func TestCompleteLogin(t *testing.T) {
conf, reg := internal.NewFastRegistryWithMocks(t)
conf.MustSet(config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypePassword), map[string]interface{}{"enabled": true})
conf.MustSet(config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypeTOTP), map[string]interface{}{"enabled": true})
conf.MustSet(config.ViperKeyURLsWhitelistedReturnToDomains, []string{"https://www.ory.sh"})

router := x.NewRouterPublic()
publicTS, _ := testhelpers.NewKratosServerWithRouters(t, reg, router, x.NewRouterAdmin())
Expand All @@ -98,7 +99,7 @@ func TestCompleteLogin(t *testing.T) {
conf.MustSet(config.ViperKeySecretsDefault, []string{"not-a-secure-session-key"})

t.Run("case=totp payload is set when identity has totp", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)

apiClient := testhelpers.NewHTTPClientWithIdentitySessionToken(t, reg, id)
f := testhelpers.InitializeLoginFlowViaAPI(t, apiClient, publicTS, false, testhelpers.InitFlowWithAAL(identity.AuthenticatorAssuranceLevel2))
Expand All @@ -116,7 +117,7 @@ func TestCompleteLogin(t *testing.T) {
})

t.Run("case=should show the error ui because the request payload is malformed", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)

t.Run("type=api", func(t *testing.T) {
apiClient := testhelpers.NewHTTPClientWithIdentitySessionToken(t, reg, id)
Expand Down Expand Up @@ -159,9 +160,15 @@ func TestCompleteLogin(t *testing.T) {
return testhelpers.LoginMakeRequest(t, true, false, f, apiClient, payload)
}

doBrowserFlow := func(t *testing.T, spa bool, v func(url.Values), id *identity.Identity) (string, *http.Response) {
doBrowserFlow := func(t *testing.T, spa bool, v func(url.Values), id *identity.Identity, returnTo string) (string, *http.Response) {
browserClient := testhelpers.NewHTTPClientWithIdentitySessionCookie(t, reg, id)
f := testhelpers.InitializeLoginFlowViaBrowser(t, browserClient, publicTS, false, spa, testhelpers.InitFlowWithAAL(identity.AuthenticatorAssuranceLevel2))

opts := []testhelpers.InitFlowWithOption{testhelpers.InitFlowWithAAL(identity.AuthenticatorAssuranceLevel2)}
if len(returnTo) > 0 {
opts = append(opts, testhelpers.InitFlowWithReturnTo(returnTo))
}

f := testhelpers.InitializeLoginFlowViaBrowser(t, browserClient, publicTS, false, spa, opts...)
values := testhelpers.SDKFormFieldsToURLValues(f.Ui.Nodes)
values.Set("method", "totp")
v(values)
Expand All @@ -177,7 +184,7 @@ func TestCompleteLogin(t *testing.T) {
}

t.Run("case=should fail if code is empty", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
payload := func(v url.Values) {
v.Set("totp_code", "")
}
Expand All @@ -194,18 +201,18 @@ func TestCompleteLogin(t *testing.T) {
})

t.Run("type=browser", func(t *testing.T) {
body, res := doBrowserFlow(t, false, payload, id)
body, res := doBrowserFlow(t, false, payload, id, "")
check(t, true, body, res)
})

t.Run("type=spa", func(t *testing.T) {
body, res := doBrowserFlow(t, true, payload, id)
body, res := doBrowserFlow(t, true, payload, id, "")
check(t, false, body, res)
})
})

t.Run("case=should fail if code is invalid", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
payload := func(v url.Values) {
v.Set("totp_code", "111111")
}
Expand All @@ -222,18 +229,18 @@ func TestCompleteLogin(t *testing.T) {
})

t.Run("type=browser", func(t *testing.T) {
body, res := doBrowserFlow(t, false, payload, id)
body, res := doBrowserFlow(t, false, payload, id, "")
check(t, true, body, res)
})

t.Run("type=spa", func(t *testing.T) {
body, res := doBrowserFlow(t, true, payload, id)
body, res := doBrowserFlow(t, true, payload, id, "")
check(t, false, body, res)
})
})

t.Run("case=should fail if code is too long", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
payload := func(v url.Values) {
v.Set("totp_code", "1111111111")
}
Expand All @@ -250,12 +257,12 @@ func TestCompleteLogin(t *testing.T) {
})

t.Run("type=browser", func(t *testing.T) {
body, res := doBrowserFlow(t, false, payload, id)
body, res := doBrowserFlow(t, false, payload, id, "")
check(t, true, body, res)
})

t.Run("type=spa", func(t *testing.T) {
body, res := doBrowserFlow(t, true, payload, id)
body, res := doBrowserFlow(t, true, payload, id, "")
check(t, false, body, res)
})
})
Expand All @@ -279,18 +286,18 @@ func TestCompleteLogin(t *testing.T) {
})

t.Run("type=browser", func(t *testing.T) {
body, res := doBrowserFlow(t, false, payload, id)
body, res := doBrowserFlow(t, false, payload, id, "")
check(t, true, body, res)
})

t.Run("type=spa", func(t *testing.T) {
body, res := doBrowserFlow(t, true, payload, id)
body, res := doBrowserFlow(t, true, payload, id, "")
check(t, false, body, res)
})
})

t.Run("case=should pass when TOTP is supplied correctly", func(t *testing.T) {
id, key := createIdentity(t, reg)
id, _, key := createIdentity(t, reg)
code, err := stdtotp.GenerateCode(key.Secret(), time.Now())
require.NoError(t, err)
payload := func(v url.Values) {
Expand Down Expand Up @@ -323,12 +330,19 @@ func TestCompleteLogin(t *testing.T) {
})

t.Run("type=browser", func(t *testing.T) {
body, res := doBrowserFlow(t, false, payload, id)
body, res := doBrowserFlow(t, false, payload, id, "")
check(t, true, body, res)
})

t.Run("type=browser set return_to", func(t *testing.T) {
returnTo := "https://www.ory.sh"
_, res := doBrowserFlow(t, false, payload, id, returnTo)
t.Log(res.Request.URL.String())
assert.Contains(t, res.Request.URL.String(), returnTo)
})

t.Run("type=spa", func(t *testing.T) {
body, res := doBrowserFlow(t, true, payload, id)
body, res := doBrowserFlow(t, true, payload, id, "")
check(t, false, body, res)
})
})
Expand Down Expand Up @@ -356,7 +370,7 @@ func TestCompleteLogin(t *testing.T) {
})

t.Run("case=should pass without csrf if API flow", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
body, res := doAPIFlow(t, func(v url.Values) {
v.Del("csrf_token")
v.Set("totp_code", "111111")
Expand All @@ -367,12 +381,12 @@ func TestCompleteLogin(t *testing.T) {
})

t.Run("case=should fail if CSRF token is invalid", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
t.Run("type=browser", func(t *testing.T) {
body, res := doBrowserFlow(t, false, func(v url.Values) {
v.Del("csrf_token")
v.Set("totp_code", "111111")
}, id)
}, id, "")

assert.Contains(t, res.Request.URL.String(), errTS.URL)
assert.Equal(t, x.ErrInvalidCSRFToken.Reason(), gjson.Get(body, "reason").String(), body)
Expand All @@ -382,10 +396,30 @@ func TestCompleteLogin(t *testing.T) {
body, res := doBrowserFlow(t, true, func(v url.Values) {
v.Del("csrf_token")
v.Set("totp_code", "111111")
}, id)
}, id, "")

assert.Contains(t, res.Request.URL.String(), publicTS.URL+login.RouteSubmitFlow)
assert.Equal(t, x.ErrInvalidCSRFToken.Reason(), gjson.Get(body, "error.reason").String(), body)
})
})

t.Run("case=should pass return_to URL after login", func(t *testing.T) {
id, pwd, _ := createIdentity(t, reg)

t.Run("type=browser", func(t *testing.T) {
returnTo := "https://www.ory.sh"
browserClient := testhelpers.NewClientWithCookies(t)
f := testhelpers.InitializeLoginFlowViaBrowser(t, browserClient, publicTS, false, false, testhelpers.InitFlowWithReturnTo(returnTo))

cred, ok := id.GetCredentials(identity.CredentialsTypePassword)
require.True(t, ok)
values := url.Values{"method": {"password"}, "password_identifier": {cred.Identifiers[0]},
"password": {pwd}, "csrf_token": {x.FakeCSRFToken}}.Encode()

body, res := testhelpers.LoginMakeRequest(t, false, false, f, browserClient, values)
require.Contains(t, res.Request.URL.Path, "login", "%s", res.Request.URL.String())
assert.Equal(t, gjson.Get(body, "requested_aal").String(), "aal2", "%s", body)
assert.Equal(t, gjson.Get(body, "return_to").String(), returnTo, "%s", body)
})
})
}
20 changes: 10 additions & 10 deletions selfservice/strategy/totp/settings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package totp_test

import (
"context"
_ "embed"
"encoding/json"
"net/http"
"net/url"
Expand All @@ -23,18 +24,17 @@ import (
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"

"github.com/ory/kratos/selfservice/flow/settings"
"github.com/ory/kratos/text"
"github.com/ory/x/assertx"
"github.com/ory/x/sqlcon"

"github.com/ory/kratos/selfservice/flow/settings"
"github.com/ory/kratos/text"

"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/internal"
"github.com/ory/kratos/internal/testhelpers"
"github.com/ory/kratos/x"

_ "embed"
)

func TestCompleteSettings(t *testing.T) {
Expand All @@ -58,7 +58,7 @@ func TestCompleteSettings(t *testing.T) {
conf.MustSet(config.ViperKeySecretsDefault, []string{"not-a-secure-session-key"})

t.Run("case=device unlinking is available when identity has totp", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)

apiClient := testhelpers.NewHTTPClientWithIdentitySessionToken(t, reg, id)
f := testhelpers.InitializeSettingsFlowViaAPI(t, apiClient, publicTS)
Expand All @@ -68,7 +68,7 @@ func TestCompleteSettings(t *testing.T) {
})

t.Run("case=device setup is available when identity has no totp yet", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
id.Credentials = nil
require.NoError(t, reg.PrivilegedIdentityPool().UpdateIdentity(context.Background(), id))

Expand Down Expand Up @@ -143,7 +143,7 @@ func TestCompleteSettings(t *testing.T) {
conf.MustSet(config.ViperKeySelfServiceSettingsPrivilegedAuthenticationAfter, "5m")
})

id, key := createIdentity(t, reg)
id, _, key := createIdentity(t, reg)
var payload = func(v url.Values) {
v.Set("totp_unlink", "true")
}
Expand Down Expand Up @@ -231,7 +231,7 @@ func TestCompleteSettings(t *testing.T) {
}

t.Run("type=api", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
actual, res := doAPIFlow(t, payload, id)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Contains(t, res.Request.URL.String(), publicTS.URL+settings.RouteSubmitFlow)
Expand All @@ -240,7 +240,7 @@ func TestCompleteSettings(t *testing.T) {
})

t.Run("type=spa", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
actual, res := doBrowserFlow(t, true, payload, id)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Contains(t, res.Request.URL.String(), publicTS.URL+settings.RouteSubmitFlow)
Expand All @@ -249,7 +249,7 @@ func TestCompleteSettings(t *testing.T) {
})

t.Run("type=browser", func(t *testing.T) {
id, _ := createIdentity(t, reg)
id, _, _ := createIdentity(t, reg)
actual, res := doBrowserFlow(t, false, payload, id)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Contains(t, res.Request.URL.String(), uiTS.URL)
Expand Down
Loading

0 comments on commit edafab6

Please sign in to comment.