diff --git a/admin/auth.go b/admin/auth.go index 3cc53240..52129a96 100644 --- a/admin/auth.go +++ b/admin/auth.go @@ -47,62 +47,61 @@ func handlerAuthCheck(h http.Handler) http.Handler { case settings.AuthSAML: _, err := samlMiddleware.Session.GetSession(r) if err != nil { - cookiev, err := r.Cookie(samlConfig.TokenName) - if err != nil { - log.Printf("error extracting JWT data: %v", err) - http.Redirect(w, r, samlConfig.LoginURL, http.StatusFound) + log.Printf("GetSession %v", err) + } + cookiev, err := r.Cookie(samlConfig.TokenName) + if err != nil { + log.Printf("error extracting JWT data: %v", err) + http.Redirect(w, r, samlConfig.LoginURL, http.StatusFound) + return + } + jwtdata, err := parseJWTFromCookie(samlData.KeyPair, cookiev.Value) + if err != nil { + log.Printf("error parsing JWT: %v", err) + http.Redirect(w, r, samlConfig.LoginURL, http.StatusFound) + return + } + // Check if user is already authenticated + authenticated, session := sessionsmgr.CheckAuth(r) + if !authenticated { + // Create user if it does not exist + if !adminUsers.Exists(jwtdata.Username) { + log.Printf("user not found: %s", jwtdata.Username) + http.Redirect(w, r, forbiddenPath, http.StatusFound) return } - jwtdata, err := parseJWTFromCookie(samlData.KeyPair, cookiev.Value) + u, err := adminUsers.Get(jwtdata.Username) if err != nil { - log.Printf("error parsing JWT: %v", err) - http.Redirect(w, r, samlConfig.LoginURL, http.StatusFound) + log.Printf("error getting user %s: %v", jwtdata.Username, err) + http.Redirect(w, r, forbiddenPath, http.StatusFound) return } - // Check if user is already authenticated - authenticated, session := sessionsmgr.CheckAuth(r) - if !authenticated { - // Create user if it does not exist - if !adminUsers.Exists(jwtdata.Username) { - log.Printf("user not found: %s", jwtdata.Username) - http.Redirect(w, r, forbiddenPath, http.StatusFound) - return - } - u, err := adminUsers.Get(jwtdata.Username) - if err != nil { - log.Printf("error getting user %s: %v", jwtdata.Username, err) - http.Redirect(w, r, forbiddenPath, http.StatusFound) - return - } - access, err := adminUsers.GetEnvAccess(u.Username, u.DefaultEnv) - if err != nil { - log.Printf("error getting access for %s: %v", jwtdata.Username, err) - http.Redirect(w, r, forbiddenPath, http.StatusFound) - return - } - // Create new session - session, err = sessionsmgr.Save(r, w, u, access) - if err != nil { - log.Printf("session error: %v", err) - http.Redirect(w, r, samlConfig.LoginURL, http.StatusFound) - return - } + access, err := adminUsers.GetEnvAccess(u.Username, u.DefaultEnv) + if err != nil { + log.Printf("error getting access for %s: %v", jwtdata.Username, err) + http.Redirect(w, r, forbiddenPath, http.StatusFound) + return } - // Set middleware values - s := make(sessions.ContextValue) - s[ctxUser] = session.Username - s[ctxCSRF] = session.Values[ctxCSRF].(string) - ctx := context.WithValue(r.Context(), sessions.ContextKey("session"), s) - // Update metadata for the user - err = adminUsers.UpdateMetadata(session.IPAddress, session.UserAgent, session.Username, s["csrftoken"]) + // Create new session + session, err = sessionsmgr.Save(r, w, u, access) if err != nil { - log.Printf("error updating metadata for user %s: %v", session.Username, err) + log.Printf("session error: %v", err) + http.Redirect(w, r, samlConfig.LoginURL, http.StatusFound) + return } - // Access granted - samlMiddleware.RequireAccount(h).ServeHTTP(w, r.WithContext(ctx)) - } else { - samlMiddleware.RequireAccount(h).ServeHTTP(w, r) } + // Set middleware values + s := make(sessions.ContextValue) + s[ctxUser] = session.Username + s[ctxCSRF] = session.Values[ctxCSRF].(string) + ctx := context.WithValue(r.Context(), sessions.ContextKey("session"), s) + // Update metadata for the user + err = adminUsers.UpdateMetadata(session.IPAddress, session.UserAgent, session.Username, s["csrftoken"]) + if err != nil { + log.Printf("error updating metadata for user %s: %v", session.Username, err) + } + // Access granted + samlMiddleware.RequireAccount(h).ServeHTTP(w, r.WithContext(ctx)) } }) } diff --git a/admin/jwt.go b/admin/jwt.go index 34bfd672..4a530da3 100644 --- a/admin/jwt.go +++ b/admin/jwt.go @@ -3,7 +3,6 @@ package main import ( "crypto/rsa" "crypto/tls" - "crypto/x509" "log" "github.com/golang-jwt/jwt/v4" @@ -38,16 +37,15 @@ func parseJWTFromCookie(keypair tls.Certificate, cookie string) (JWTData, error) } tokenClaims := TokenClaims{} token, err := jwt.ParseWithClaims(cookie, &tokenClaims, func(t *jwt.Token) (interface{}, error) { - secretBlock := x509.MarshalPKCS1PrivateKey(keypair.PrivateKey.(*rsa.PrivateKey)) - return secretBlock, nil + return keypair.PrivateKey.(*rsa.PrivateKey).Public(), nil }) + if err != nil || !token.Valid { return JWTData{}, err } return JWTData{ Subject: tokenClaims.Subject, - Email: tokenClaims.Attributes["mail"][0], - Display: tokenClaims.Attributes["displayName"][0], - Username: tokenClaims.Attributes["sAMAccountName"][0], + Email: tokenClaims.Subject, + Username: tokenClaims.Subject, }, nil } diff --git a/admin/main.go b/admin/main.go index 7756074c..568f000e 100644 --- a/admin/main.go +++ b/admin/main.go @@ -574,6 +574,7 @@ func osctrlAdminService() { URL: *samlData.RootURL, Key: samlData.KeyPair.PrivateKey.(*rsa.PrivateKey), Certificate: samlData.KeyPair.Leaf, + IDPMetadata: samlData.IdpMetadata, AllowIDPInitiated: true, }) if err != nil { diff --git a/admin/saml.go b/admin/saml.go index f0267fe2..7d18ba21 100644 --- a/admin/saml.go +++ b/admin/saml.go @@ -1,12 +1,16 @@ package main import ( + "context" "crypto/tls" "crypto/x509" "fmt" "log" + "net/http" "net/url" + "github.com/crewjam/saml" + "github.com/crewjam/saml/samlsp" "github.com/jmpsec/osctrl/settings" "github.com/spf13/viper" ) @@ -28,6 +32,7 @@ type JSONConfigurationSAML struct { type samlThings struct { RootURL *url.URL IdpMetadataURL *url.URL + IdpMetadata *saml.EntityDescriptor KeyPair tls.Certificate } @@ -65,6 +70,10 @@ func keypairSAML(config JSONConfigurationSAML) (samlThings, error) { if err != nil { return data, fmt.Errorf("Parse MetadataURL %v", err) } + data.IdpMetadata, err = samlsp.FetchMetadata(context.Background(), http.DefaultClient, *data.IdpMetadataURL) + if err != nil { + return data, fmt.Errorf("Fetch Metadata %v", err) + } data.RootURL, err = url.Parse(config.RootURL) if err != nil { return data, fmt.Errorf("Parse RootURL %v", err)