Skip to content

Commit

Permalink
implement sts token exchange
Browse files Browse the repository at this point in the history
Signed-off-by: Sean Liao <[email protected]>
  • Loading branch information
seankhliao committed Mar 24, 2023
1 parent 2bb4896 commit b25d49b
Show file tree
Hide file tree
Showing 13 changed files with 373 additions and 82 deletions.
4 changes: 4 additions & 0 deletions cmd/dex/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ func (p *password) UnmarshalJSON(b []byte) error {

// OAuth2 describes enabled OAuth2 extensions.
type OAuth2 struct {
// list of allowed grant types,
// defaults to all supported types
GrantTypes []string `json:"grantTypes"`

ResponseTypes []string `json:"responseTypes"`
// If specified, do not prompt the user to approve client authorization. The
// act of logging in implies authorization.
Expand Down
7 changes: 7 additions & 0 deletions cmd/dex/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ staticClients:
oauth2:
alwaysShowLoginScreen: true
grantTypes:
- refresh_token
- "urn:ietf:params:oauth:grant-type:token-exchange"
connectors:
- type: mockCallback
Expand Down Expand Up @@ -161,6 +164,10 @@ logger:
},
OAuth2: OAuth2{
AlwaysShowLoginScreen: true,
GrantTypes: []string{
"refresh_token",
"urn:ietf:params:oauth:grant-type:token-exchange",
},
},
StaticConnectors: []Connector{
{
Expand Down
10 changes: 10 additions & 0 deletions cmd/dex/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ func runServe(options serveOptions) error {
healthChecker := gosundheit.New()

serverConfig := server.Config{
AllowedGrantTypes: c.OAuth2.GrantTypes,
SupportedResponseTypes: c.OAuth2.ResponseTypes,
SkipApprovalScreen: c.OAuth2.SkipApprovalScreen,
AlwaysShowLoginScreen: c.OAuth2.AlwaysShowLoginScreen,
Expand Down Expand Up @@ -554,6 +555,15 @@ func applyConfigOverrides(options serveOptions, config *Config) {
if config.Frontend.Dir == "" {
config.Frontend.Dir = os.Getenv("DEX_FRONTEND_DIR")
}

if len(config.OAuth2.GrantTypes) == 0 {
config.OAuth2.GrantTypes = []string{
"authorization_code",
"refresh_token",
"urn:ietf:params:oauth:grant-type:device_code",
"urn:ietf:params:oauth:grant-type:token-exchange",
}
}
}

func pprofHandler(router *http.ServeMux) {
Expand Down
4 changes: 4 additions & 0 deletions connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,7 @@ type RefreshConnector interface {
// changes since the token was last refreshed.
Refresh(ctx context.Context, s Scopes, identity Identity) (Identity, error)
}

type TokenIdentityConnector interface {
TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (Identity, error)
}
4 changes: 4 additions & 0 deletions connector/mock/connectortest.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ func (m *Callback) Refresh(ctx context.Context, s connector.Scopes, identity con
return m.Identity, nil
}

func (m *Callback) TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (connector.Identity, error) {
return m.Identity, nil
}

// CallbackConfig holds the configuration parameters for a connector which requires no interaction.
type CallbackConfig struct{}

Expand Down
20 changes: 19 additions & 1 deletion connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ type caller uint
const (
createCaller caller = iota
refreshCaller
exchangeCaller
)

func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
Expand Down Expand Up @@ -284,12 +285,29 @@ func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identit
return c.createIdentity(ctx, identity, token, refreshCaller)
}

func (c *oidcConnector) TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (connector.Identity, error) {
var identity connector.Identity
token := &oauth2.Token{
AccessToken: subjectToken,
}
if subjectTokenType == "urn:ietf:params:oauth:token-type:access_token" {
token = token.WithExtra(map[string]any{
"id_token": subjectToken,
})
}
return c.createIdentity(ctx, identity, token, exchangeCaller)
}

func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token, caller caller) (connector.Identity, error) {
var claims map[string]interface{}

rawIDToken, ok := token.Extra("id_token").(string)
if ok {
idToken, err := c.verifier.Verify(ctx, rawIDToken)
verifier := c.verifier
if caller == exchangeCaller {
verifier = c.provider.Verifier(&oidc.Config{SkipClientIDCheck: true})
}
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
}
Expand Down
133 changes: 120 additions & 13 deletions server/handlers.go
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
implicitOrHybrid = true
var err error

accessToken, err = s.newAccessToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID)
accessToken, _, err = s.newAccessToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
Expand Down Expand Up @@ -830,6 +830,11 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
}

grantType := r.PostFormValue("grant_type")
if !contains(s.supportedGrantTypes, grantType) {
s.logger.Errorf("unsupported grant type: %v", grantType)
s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest)
return
}
switch grantType {
case grantTypeDeviceCode:
s.handleDeviceToken(w, r)
Expand All @@ -839,6 +844,8 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
s.withClientFromStorage(w, r, s.handleRefreshToken)
case grantTypePassword:
s.withClientFromStorage(w, r, s.handlePasswordGrant)
case grantTypeTokenExchange:
s.handleTokenExchange(w, r)
default:
s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest)
}
Expand Down Expand Up @@ -917,7 +924,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
}

func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) {
accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
accessToken, _, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
Expand Down Expand Up @@ -1180,7 +1187,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
Groups: identity.Groups,
}

accessToken, err := s.newAccessToken(client.ID, claims, scopes, nonce, connID)
accessToken, _, err := s.newAccessToken(client.ID, claims, scopes, nonce, connID)
if err != nil {
s.logger.Errorf("password grant failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
Expand Down Expand Up @@ -1319,21 +1326,121 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
s.writeAccessToken(w, resp)
}

func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// TODO: check global allowed grant types?

if err := r.ParseForm(); err != nil {
s.logger.Errorf("Could not parse request body: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
return
}
q := r.Form

connID := q.Get("audience") // OPTIONAL, use for connector ID
resource := q.Get("resource") // OPTIONAL, use for issued token audience
scopes := strings.Fields(q.Get("scope")) // OPTIONAL, map to issed token scope
requestedTokenType := q.Get("requested_token_type") // OPTIONAL, default to access token
if requestedTokenType == "" {
requestedTokenType = tokenTypeAccess
}
subjectToken := q.Get("subject_token") // REQUIRED
subjectTokenType := q.Get("subject_token_type") // REQUIRED

switch subjectTokenType {
case tokenTypeID, tokenTypeAccess: // ok, continue
default:
s.tokenErrHelper(w, errRequestNotSupported, "Invalid subject_token_type.", http.StatusBadRequest)
return
}

if subjectToken == "" {
s.tokenErrHelper(w, errInvalidRequest, "Missing subject_token", http.StatusBadRequest)
return
}

conn, err := s.getConnector(connID)
if err != nil {
s.logger.Errorf("Failed to get connector: %v", err)
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist")
return
}
teConn, ok := conn.Connector.(connector.TokenIdentityConnector)
if !ok {
s.logger.Errorf("Connector doesn't implement token exchange: %v", connID)
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist")
return
}
identity, err := teConn.TokenIdentity(ctx, subjectTokenType, subjectToken)
if err != nil {
s.logger.Errorf("Failed to verify subject token: %v", err)
s.tokenErrHelper(w, errAccessDenied, "", http.StatusUnauthorized)
return
}

claims := storage.Claims{
UserID: identity.UserID,
Username: identity.Username,
PreferredUsername: identity.PreferredUsername,
Email: identity.Email,
EmailVerified: identity.EmailVerified,
Groups: identity.Groups,
}
resp := accessTokenResponse{
IssuedTokenType: requestedTokenType,
TokenType: "bearer",
}
var expiry time.Time
switch requestedTokenType {
case tokenTypeID:
resp.AccessToken, expiry, err = s.newIDToken(resource, claims, scopes, "", "", "", connID)
case tokenTypeAccess:
resp.AccessToken, expiry, err = s.newAccessToken(resource, claims, scopes, "", connID)
default:
s.tokenErrHelper(w, errRequestNotSupported, "Invalid requested_token_type.", http.StatusBadRequest)
return
}
if err != nil {
s.logger.Errorf("token exchange failed to create new %v token: %v", requestedTokenType, err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
resp.ExpiresIn = int(time.Until(expiry).Seconds())

data, err := json.Marshal(resp)
if err != nil {
s.logger.Errorf("failed to marshal token exchange response: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}

w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(data)))

// Token response must include cache headers https://tools.ietf.org/html/rfc6749#section-5.1
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
w.Write(data)
}

type accessTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
IssuedTokenType string `json:"issued_token_type,omitempty"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
IDToken string `json:"id_token,omitempty"`
Scope string `json:"scope,omitempty"`
}

func (s *Server) toAccessTokenResponse(idToken, accessToken, refreshToken string, expiry time.Time) *accessTokenResponse {
return &accessTokenResponse{
accessToken,
"bearer",
int(expiry.Sub(s.now()).Seconds()),
refreshToken,
idToken,
AccessToken: accessToken,
TokenType: "bearer",
ExpiresIn: int(expiry.Sub(s.now()).Seconds()),
RefreshToken: refreshToken,
IDToken: idToken,
}
}

Expand Down
Loading

0 comments on commit b25d49b

Please sign in to comment.