Skip to content

Commit

Permalink
Fix to send challenge only if provider supports PKCE (#283)
Browse files Browse the repository at this point in the history
  • Loading branch information
int128 authored May 8, 2020
1 parent 4ffc914 commit 175275b
Show file tree
Hide file tree
Showing 12 changed files with 245 additions and 116 deletions.
18 changes: 17 additions & 1 deletion pkg/adaptors/oidcclient/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package oidcclient
import (
"context"
"crypto/tls"
"fmt"
"net/http"

"github.com/coreos/go-oidc"
Expand Down Expand Up @@ -49,6 +50,10 @@ func New(ctx context.Context, config Config) (Interface, error) {
if err != nil {
return nil, xerrors.Errorf("oidc discovery error: %w", err)
}
supportedPKCEMethods, err := extractSupportedPKCEMethods(provider)
if err != nil {
return nil, xerrors.Errorf("could not determine supported PKCE methods: %w", err)
}
return &client{
httpClient: httpClient,
provider: provider,
Expand All @@ -58,6 +63,17 @@ func New(ctx context.Context, config Config) (Interface, error) {
ClientSecret: config.ClientSecret,
Scopes: append(config.ExtraScopes, oidc.ScopeOpenID),
},
logger: config.Logger,
logger: config.Logger,
supportedPKCEMethods: supportedPKCEMethods,
}, nil
}

func extractSupportedPKCEMethods(provider *oidc.Provider) ([]string, error) {
var d struct {
CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"`
}
if err := provider.Claims(&d); err != nil {
return nil, fmt.Errorf("invalid discovery document: %w", err)
}
return d.CodeChallengeMethodsSupported, nil
}
42 changes: 28 additions & 14 deletions pkg/adaptors/oidcclient/mock_oidcclient/mock_oidcclient.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

88 changes: 50 additions & 38 deletions pkg/adaptors/oidcclient/oidcclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/google/wire"
"github.com/int128/kubelogin/pkg/adaptors/logger"
"github.com/int128/kubelogin/pkg/domain/jwt"
"github.com/int128/kubelogin/pkg/domain/pkce"
"github.com/int128/oauth2cli"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
Expand All @@ -26,31 +27,29 @@ type Interface interface {
GetTokenByAuthCode(ctx context.Context, in GetTokenByAuthCodeInput, localServerReadyChan chan<- string) (*TokenSet, error)
GetTokenByROPC(ctx context.Context, username, password string) (*TokenSet, error)
Refresh(ctx context.Context, refreshToken string) (*TokenSet, error)
SupportedPKCEMethods() []string
}

type AuthCodeURLInput struct {
State string
Nonce string
CodeChallenge string
CodeChallengeMethod string
PKCEParams pkce.Params
RedirectURI string
AuthRequestExtraParams map[string]string
}

type ExchangeAuthCodeInput struct {
Code string
CodeVerifier string
Nonce string
RedirectURI string
Code string
PKCEParams pkce.Params
Nonce string
RedirectURI string
}

type GetTokenByAuthCodeInput struct {
BindAddress []string
State string
Nonce string
CodeChallenge string
CodeChallengeMethod string
CodeVerifier string
PKCEParams pkce.Params
RedirectURLHostname string
AuthRequestExtraParams map[string]string
}
Expand All @@ -64,10 +63,11 @@ type TokenSet struct {
}

type client struct {
httpClient *http.Client
provider *oidc.Provider
oauth2Config oauth2.Config
logger logger.Interface
httpClient *http.Client
provider *oidc.Provider
oauth2Config oauth2.Config
logger logger.Interface
supportedPKCEMethods []string
}

func (c *client) wrapContext(ctx context.Context) context.Context {
Expand All @@ -81,25 +81,14 @@ func (c *client) wrapContext(ctx context.Context) context.Context {
func (c *client) GetTokenByAuthCode(ctx context.Context, in GetTokenByAuthCodeInput, localServerReadyChan chan<- string) (*TokenSet, error) {
ctx = c.wrapContext(ctx)
config := oauth2cli.Config{
OAuth2Config: c.oauth2Config,
State: in.State,
AuthCodeOptions: []oauth2.AuthCodeOption{
oauth2.AccessTypeOffline,
oidc.Nonce(in.Nonce),
oauth2.SetAuthURLParam("code_challenge", in.CodeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", in.CodeChallengeMethod),
},
TokenRequestOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", in.CodeVerifier),
},
OAuth2Config: c.oauth2Config,
State: in.State,
AuthCodeOptions: authorizationRequestOptions(in.Nonce, in.PKCEParams, in.AuthRequestExtraParams),
TokenRequestOptions: tokenRequestOptions(in.PKCEParams),
LocalServerBindAddress: in.BindAddress,
LocalServerReadyChan: localServerReadyChan,
RedirectURLHostname: in.RedirectURLHostname,
}
for key, value := range in.AuthRequestExtraParams {
config.AuthCodeOptions = append(config.AuthCodeOptions, oauth2.SetAuthURLParam(key, value))
}

token, err := oauth2cli.GetToken(ctx, config)
if err != nil {
return nil, xerrors.Errorf("oauth2 error: %w", err)
Expand All @@ -111,15 +100,7 @@ func (c *client) GetTokenByAuthCode(ctx context.Context, in GetTokenByAuthCodeIn
func (c *client) GetAuthCodeURL(in AuthCodeURLInput) string {
cfg := c.oauth2Config
cfg.RedirectURL = in.RedirectURI
opts := []oauth2.AuthCodeOption{
oauth2.AccessTypeOffline,
oidc.Nonce(in.Nonce),
oauth2.SetAuthURLParam("code_challenge", in.CodeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", in.CodeChallengeMethod),
}
for key, value := range in.AuthRequestExtraParams {
opts = append(opts, oauth2.SetAuthURLParam(key, value))
}
opts := authorizationRequestOptions(in.Nonce, in.PKCEParams, in.AuthRequestExtraParams)
return cfg.AuthCodeURL(in.State, opts...)
}

Expand All @@ -128,13 +109,44 @@ func (c *client) ExchangeAuthCode(ctx context.Context, in ExchangeAuthCodeInput)
ctx = c.wrapContext(ctx)
cfg := c.oauth2Config
cfg.RedirectURL = in.RedirectURI
token, err := cfg.Exchange(ctx, in.Code, oauth2.SetAuthURLParam("code_verifier", in.CodeVerifier))
opts := tokenRequestOptions(in.PKCEParams)
token, err := cfg.Exchange(ctx, in.Code, opts...)
if err != nil {
return nil, xerrors.Errorf("exchange error: %w", err)
}
return c.verifyToken(ctx, token, in.Nonce)
}

func authorizationRequestOptions(n string, p pkce.Params, e map[string]string) []oauth2.AuthCodeOption {
o := []oauth2.AuthCodeOption{
oauth2.AccessTypeOffline,
oidc.Nonce(n),
}
if !p.IsPlain() {
o = append(o,
oauth2.SetAuthURLParam("code_challenge", p.CodeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", p.CodeChallengeMethod),
)
}
for key, value := range e {
o = append(o, oauth2.SetAuthURLParam(key, value))
}
return o
}

func tokenRequestOptions(p pkce.Params) (o []oauth2.AuthCodeOption) {
if !p.IsPlain() {
o = append(o, oauth2.SetAuthURLParam("code_verifier", p.CodeVerifier))
}
return
}

// SupportedPKCEMethods returns the PKCE methods supported by the provider.
// This may return nil if PKCE is not supported.
func (c *client) SupportedPKCEMethods() []string {
return c.supportedPKCEMethods
}

// GetTokenByROPC performs the resource owner password credentials flow.
func (c *client) GetTokenByROPC(ctx context.Context, username, password string) (*TokenSet, error) {
ctx = c.wrapContext(ctx)
Expand Down
27 changes: 0 additions & 27 deletions pkg/domain/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package oidc

import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/binary"

Expand All @@ -25,21 +24,6 @@ func NewNonce() (string, error) {
return base64URLEncode(b), nil
}

type PKCEParams struct {
CodeChallenge string
CodeChallengeMethod string
CodeVerifier string
}

func NewPKCEParams() (*PKCEParams, error) {
b, err := random32()
if err != nil {
return nil, xerrors.Errorf("could not generate a random: %w", err)
}
s := computeS256(b)
return &s, nil
}

func random32() ([]byte, error) {
b := make([]byte, 32)
if err := binary.Read(rand.Reader, binary.LittleEndian, b); err != nil {
Expand All @@ -48,17 +32,6 @@ func random32() ([]byte, error) {
return b, nil
}

func computeS256(b []byte) PKCEParams {
v := base64URLEncode(b)
s := sha256.New()
_, _ = s.Write([]byte(v))
return PKCEParams{
CodeChallenge: base64URLEncode(s.Sum(nil)),
CodeChallengeMethod: "S256",
CodeVerifier: v,
}
}

func base64URLEncode(b []byte) string {
return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(b)
}
25 changes: 0 additions & 25 deletions pkg/domain/oidc/oidc_test.go

This file was deleted.

Loading

0 comments on commit 175275b

Please sign in to comment.