Skip to content

Commit

Permalink
implement sts token exchange
Browse files Browse the repository at this point in the history
Signed-off-by: Sean Liao <[email protected]>
  • Loading branch information
seankhliao committed Mar 24, 2023
1 parent 2bb4896 commit 5aebff1
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 2 deletions.
4 changes: 4 additions & 0 deletions connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,7 @@ type RefreshConnector interface {
// changes since the token was last refreshed.
Refresh(ctx context.Context, s Scopes, identity Identity) (Identity, error)
}

type TokenIdentityConnector interface {
TokenIdentity(ctx context.Context, idToken string) (Identity, error)
}
16 changes: 15 additions & 1 deletion connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ type caller uint
const (
createCaller caller = iota
refreshCaller
exchangeCaller
)

func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
Expand Down Expand Up @@ -284,12 +285,25 @@ func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identit
return c.createIdentity(ctx, identity, token, refreshCaller)
}

func (c *oidcConnector) TokenIdentity(ctx context.Context, idToken string) (connector.Identity, error) {
var identity connector.Identity
token := &oauth2.Token{
AccessToken: idToken,
}
token = token.WithExtra(map[string]any{"id_token": idToken})
return c.createIdentity(ctx, identity, token, exchangeCaller)
}

func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token, caller caller) (connector.Identity, error) {
var claims map[string]interface{}

rawIDToken, ok := token.Extra("id_token").(string)
if ok {
idToken, err := c.verifier.Verify(ctx, rawIDToken)
verifier := c.verifier
if caller == exchangeCaller {
verifier = c.provider.Verifier(&oidc.Config{SkipClientIDCheck: true})
}
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
}
Expand Down
103 changes: 103 additions & 0 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,8 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
s.withClientFromStorage(w, r, s.handleRefreshToken)
case grantTypePassword:
s.withClientFromStorage(w, r, s.handlePasswordGrant)
case grantTypeTokenExchange:
s.handleTokenExchange(w, r)
default:
s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest)
}
Expand Down Expand Up @@ -1319,6 +1321,107 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
s.writeAccessToken(w, resp)
}

func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

if err := r.ParseForm(); err != nil {
s.logger.Errorf("Could not parse request body: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
return
}
q := r.Form

connID := q.Get("audience") // OPTIONAL, use for connector ID
resource := q.Get("resource") // OPTIONAL, use for issued token audience
scopes := strings.Fields(q.Get("scope")) // OPTIONAL
requestedTokenType := q.Get("requested_token_type") // OPTIONAL
subjectToken := q.Get("subject_token") // REQUIRED
subjectTokenType := q.Get("subject_token_type") // REQUIRED

if subjectTokenType != tokenTypeID {
s.tokenErrHelper(w, errRequestNotSupported, "Invalid subject_token_type.", http.StatusBadRequest)
return
}

conn, err := s.getConnector(connID)
if err != nil {
s.logger.Errorf("Failed to get connector: %v", err)
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist")
return
}
teConn, ok := conn.Connector.(connector.TokenIdentityConnector)
if !ok {
s.logger.Errorf("Connector doesn't implement token exchange: %v", connID)
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist")
return
}
identity, err := teConn.TokenIdentity(ctx, subjectToken)
if err != nil {
s.logger.Errorf("Failed to verify subject token: %v", err)
s.tokenErrHelper(w, errAccessDenied, "", http.StatusUnauthorized)
}

claims := storage.Claims{
UserID: identity.UserID,
Username: identity.Username,
PreferredUsername: identity.PreferredUsername,
Email: identity.Email,
EmailVerified: identity.EmailVerified,
Groups: identity.Groups,
}

resp := tokenExchangeResponse{
TokenType: "bearer",
}
switch requestedTokenType {
case tokenTypeID:
resp.IssuedTokenType = tokenTypeID
var expiry time.Time
resp.AccessToken, expiry, err = s.newIDToken(resource, claims, scopes, "", "", "", connID)
if err != nil {
s.logger.Errorf("token exchange failed to create new id token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
resp.ExpiresIn = int(time.Until(expiry).Seconds())
case tokenTypeAccess, "":
resp.IssuedTokenType = tokenTypeAccess
resp.AccessToken, err = s.newAccessToken(resource, claims, scopes, "", connID)
if err != nil {
s.logger.Errorf("token exchange failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
default:
s.tokenErrHelper(w, errRequestNotSupported, "Invalid requested_token_type.", http.StatusBadRequest)
return
}

data, err := json.Marshal(resp)
if err != nil {
s.logger.Errorf("failed to marshal token exchange response: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}

w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(data)))

// Token response must include cache headers https://tools.ietf.org/html/rfc6749#section-5.1
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
w.Write(data)
}

type tokenExchangeResponse struct {
AccessToken string `json:"access_token"`
IssuedTokenType string `json:"issued_token_type"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
}

type accessTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
Expand Down
13 changes: 12 additions & 1 deletion server/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func tokenErr(w http.ResponseWriter, typ, description string, statusCode int) er
return nil
}

//nolint
//nolint:all
const (
errInvalidRequest = "invalid_request"
errUnauthorizedClient = "unauthorized_client"
Expand Down Expand Up @@ -132,6 +132,17 @@ const (
grantTypeImplicit = "implicit"
grantTypePassword = "password"
grantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
grantTypeTokenExchange = "urn:ietf:params:oauth:grant-type:token-exchange"
)

const (
// https://www.rfc-editor.org/rfc/rfc8693.html#section-3
tokenTypeAccess = "urn:ietf:params:oauth:token-type:access_token"
tokenTypeRefresh = "urn:ietf:params:oauth:token-type:refresh_token"
tokenTypeID = "urn:ietf:params:oauth:token-type:id_token"
tokenTypeSAML1 = "urn:ietf:params:oauth:token-type:saml1"
tokenTypeSAML2 = "urn:ietf:params:oauth:token-type:saml2"
tokenTypeJWT = "urn:ietf:params:oauth:token-type:jwt"
)

const (
Expand Down

0 comments on commit 5aebff1

Please sign in to comment.