Skip to content

Commit

Permalink
fix: improve token OIDC logging (supabase#1606)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?
* Currently, when the "Unacceptable audience in id_token" error is
returned, it doesn't log the audience claim from the id token, which
makes it hard to debug. The audience claim from the id token is now
logged as well when this error is returned.
* Adds a basic test for the generic id token oidc `getProvider()`
method, since we currently have 0 coverage for this file
* The test also uncovered a possible nil pointer panic in the case of
the generic OIDC provider being returned since in the generic case, the
`oauthConfig` will be nil. Rather than returning the `oauthConfig`, we
only need to return the `skipNonceCheck` property since we only check
for that.
  • Loading branch information
kangmingtay authored Jun 3, 2024
1 parent 24cf102 commit d324df7
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 10 deletions.
25 changes: 15 additions & 10 deletions internal/api/token_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type IdTokenGrantParams struct {
Issuer string `json:"issuer"`
}

func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.GlobalConfiguration, r *http.Request) (*oidc.Provider, *conf.OAuthProviderConfiguration, string, []string, error) {
func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.GlobalConfiguration, r *http.Request) (*oidc.Provider, bool, string, []string, error) {
log := observability.GetLogEntry(r).Entry

var cfg *conf.OAuthProviderConfiguration
Expand Down Expand Up @@ -54,7 +54,7 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa
if issuer == "" || !provider.IsAzureIssuer(issuer) {
detectedIssuer, err := provider.DetectAzureIDTokenIssuer(ctx, p.IdToken)
if err != nil {
return nil, nil, "", nil, badRequestError(ErrorCodeValidationFailed, "Unable to detect issuer in ID token for Azure provider").WithInternalError(err)
return nil, false, "", nil, badRequestError(ErrorCodeValidationFailed, "Unable to detect issuer in ID token for Azure provider").WithInternalError(err)
}
issuer = detectedIssuer
}
Expand Down Expand Up @@ -95,20 +95,25 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa
}

if !allowed {
return nil, nil, "", nil, badRequestError(ErrorCodeValidationFailed, fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider))
return nil, false, "", nil, badRequestError(ErrorCodeValidationFailed, fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider))
}

cfg = &conf.OAuthProviderConfiguration{
Enabled: true,
SkipNonceCheck: false,
}
}

if cfg != nil && !cfg.Enabled {
return nil, nil, "", nil, badRequestError(ErrorCodeProviderDisabled, fmt.Sprintf("Provider (issuer %q) is not enabled", issuer))
if !cfg.Enabled {
return nil, false, "", nil, badRequestError(ErrorCodeProviderDisabled, fmt.Sprintf("Provider (issuer %q) is not enabled", issuer))
}

oidcProvider, err := oidc.NewProvider(ctx, issuer)
if err != nil {
return nil, nil, "", nil, err
return nil, false, "", nil, err
}

return oidcProvider, cfg, providerType, acceptableClientIDs, nil
return oidcProvider, cfg.SkipNonceCheck, providerType, acceptableClientIDs, nil
}

// IdTokenGrant implements the id_token grant type flow
Expand All @@ -131,7 +136,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R
return oauthError("invalid request", "provider or client_id and issuer required")
}

oidcProvider, oauthConfig, providerType, acceptableClientIDs, err := params.getProvider(ctx, config, r)
oidcProvider, skipNonceCheck, providerType, acceptableClientIDs, err := params.getProvider(ctx, config, r)
if err != nil {
return err
}
Expand Down Expand Up @@ -179,10 +184,10 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R
}

if !correctAudience {
return oauthError("invalid request", "Unacceptable audience in id_token")
return oauthError("invalid request", fmt.Sprintf("Unacceptable audience in id_token: %v", idToken.Audience))
}

if !oauthConfig.SkipNonceCheck {
if !skipNonceCheck {
tokenHasNonce := idToken.Nonce != ""
paramsHasNonce := params.Nonce != ""

Expand Down
69 changes: 69 additions & 0 deletions internal/api/token_oidc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package api

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/supabase/auth/internal/conf"
)

type TokenOIDCTestSuite struct {
suite.Suite
API *API
Config *conf.GlobalConfiguration
}

func TestTokenOIDC(t *testing.T) {
api, config, err := setupAPIForTest()
require.NoError(t, err)

ts := &TokenOIDCTestSuite{
API: api,
Config: config,
}
defer api.db.Close()

suite.Run(t, ts)
}

func SetupTestOIDCProvider(ts *TokenOIDCTestSuite) *httptest.Server {
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"issuer":"` + server.URL + `","authorization_endpoint":"` + server.URL + `/authorize","token_endpoint":"` + server.URL + `/token","jwks_uri":"` + server.URL + `/jwks"}`))
default:
w.WriteHeader(http.StatusNotFound)
}
}))
return server
}

func (ts *TokenOIDCTestSuite) TestGetProvider() {
server := SetupTestOIDCProvider(ts)
defer server.Close()

params := &IdTokenGrantParams{
IdToken: "test-id-token",
AccessToken: "test-access-token",
Nonce: "test-nonce",
Provider: server.URL,
ClientID: "test-client-id",
Issuer: server.URL,
}

ts.Config.External.AllowedIdTokenIssuers = []string{server.URL}

req := httptest.NewRequest(http.MethodPost, "http://localhost", nil)
oidcProvider, skipNonceCheck, providerType, acceptableClientIds, err := params.getProvider(context.Background(), ts.Config, req)
require.NoError(ts.T(), err)
require.NotNil(ts.T(), oidcProvider)
require.False(ts.T(), skipNonceCheck)
require.Equal(ts.T(), params.Provider, providerType)
require.NotEmpty(ts.T(), acceptableClientIds)
}

0 comments on commit d324df7

Please sign in to comment.