From f2c6bdac74ab74b5df9f5444a7f5dca667c8c7f6 Mon Sep 17 00:00:00 2001 From: dseabolt Date: Wed, 29 Apr 2020 11:50:53 -0600 Subject: [PATCH 01/12] Adding support for getting public keys for JWTs from OIDC --- pkg/config/env.go | 3 + pkg/config/middleware.go | 139 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 138 insertions(+), 4 deletions(-) diff --git a/pkg/config/env.go b/pkg/config/env.go index 58396e725..16f2237b0 100644 --- a/pkg/config/env.go +++ b/pkg/config/env.go @@ -185,6 +185,9 @@ var Config = struct { // "HS256" and "RS256" supported JWTAuthSigningMethod string `env:"FLAGR_JWT_AUTH_SIGNING_METHOD" envDefault:"HS256"` + // Verify the JWT through OIDC flows + JWTAuthOIDCWellKnownURL string `env:"FLAGR_JWT_AUTH_OIDC_WELL_KNOWN_URL" envDefault:""` + // Identify users through headers HeaderAuthEnabled bool `env:"FLAGR_HEADER_AUTH_ENABLED" envDefault:"false"` HeaderAuthUserField string `env:"FLAGR_HEADER_AUTH_USER_FIELD" envDefault:"X-Email"` diff --git a/pkg/config/middleware.go b/pkg/config/middleware.go index 0c4aea9b0..64ba10506 100644 --- a/pkg/config/middleware.go +++ b/pkg/config/middleware.go @@ -2,7 +2,10 @@ package config import ( "crypto/subtle" + "encoding/json" + "errors" "fmt" + "io/ioutil" "net/http" "strconv" "strings" @@ -143,14 +146,142 @@ func setupJWTAuthMiddleware() *jwtAuth { validationKey = []byte("") } + var validationKeyGetter = func(token *jwt.Token) (interface{}, error) { + return validationKey, errParsingKey + } + + if Config.JWTAuthOIDCWellKnownURL != "" { + validationKeyGetter = func(token *jwt.Token) (interface{}, error) { + + // If this is truly an OIDC token, it should have a "kid" header in the token. + var tokenKeyID = token.Header["kid"] + if tokenKeyID == nil { + return "", errors.New("Missing key id in the JWT") + } + + // First we need to access the well known url to find out what the jwk_uri is. + logrus.Printf("Fetching Oidc Configuration...") + resp, err := http.Get(Config.JWTAuthOIDCWellKnownURL) + if err != nil { + return "", err + } + + // Read in the oidc config information, and unmarshal it into our oidc config. + logrus.Printf("Reading Oidc Configuration...") + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", err + } + + logrus.Printf("Parsing Oidc Configuration...") + type OidcConfiguation struct { + JwkURI string `json:"jwks_uri"` // This is the only field we care about for now. + } + var oidcConfig OidcConfiguation + json.Unmarshal([]byte(body), &oidcConfig) + logrus.Printf("OIDC Config: " + string(body)) + + if oidcConfig.JwkURI == "" { + return "", errors.New("OIDC configuration didn't contain a valid jwks_uri") + } + + logrus.Printf("Fetching Oidc Jwks...") + // Now we need to read in the jwks info. + oidcJwksResp, err := http.Get(oidcConfig.JwkURI) + if err != nil { + return "", err + } + + if err != nil { + return "", err + } + + logrus.Printf("Reading Oidc Jwks...") + // Read in the jwks and unmarshall it. + defer oidcJwksResp.Body.Close() + body, err = ioutil.ReadAll(oidcJwksResp.Body) + if err != nil { + return "", err + } + + type OidcJwk struct { + JWKeyID string `json:"kid"` // JWTs will have a "kid" that will match one of these + SigningKeys []string `json:"x5c"` // We only support x5c at the moment. + } + + type OidcJwksConfiguration struct { + Jwks []OidcJwk `json:"keys"` + } + + var oidcJwks OidcJwksConfiguration + + logrus.Printf("Parsing Oidc Jwks...") + json.Unmarshal([]byte(body), &oidcJwks) + logrus.Printf("OIDC JWKS: " + string(body)) + + logrus.Printf("Finding matching key") + // Find the key that matches the one in the JWT + if oidcJwks.Jwks == nil { + return "", errors.New("Missing keys in the jwk config") + } + + var numKeys = len(oidcJwks.Jwks) + if numKeys == 0 { + return "", errors.New("No keys in the jwk config") + } + + matchingKey := -1 + for currentKey := 0; currentKey < numKeys; currentKey++ { + if oidcJwks.Jwks[currentKey].JWKeyID == tokenKeyID { + matchingKey = currentKey + break + } + } + + if matchingKey == -1 { + return "", errors.New("No matching key found in the jwk config") + } + + correctKey := oidcJwks.Jwks[matchingKey].SigningKeys[0] + logrus.Printf("Found key: " + correctKey) + + // Now we will take the first key and convert it into a cert + // that the library we user are familiar with. This cert requires + // a very specific format that is dependent on whitespace :/ + // There can only be 64 characters on a line or else it won't read it + // so we have to do that manually. Also the BEGIN and END lines need to + // be their own lines. + logrus.Printf("Building cert!") + numCharsInKey := len(correctKey) + var numLines int = numCharsInKey / 64 + if numLines%64 > 0 { + numLines++ + } + + var jwtCert string = "-----BEGIN PUBLIC KEY-----\n" + for currentLine := 0; currentLine < numLines; currentLine++ { + startCharacter := currentLine * 64 + endCharacter := startCharacter + 64 + if startCharacter+64 > numCharsInKey { + endCharacter = numCharsInKey + } + + jwtCert += correctKey[startCharacter:endCharacter] + "\n" + } + jwtCert += "-----END PUBLIC KEY-----" + logrus.Printf("Cert created: \n" + jwtCert) + + return jwt.ParseRSAPublicKeyFromPEM([]byte(jwtCert)) + } + } + return &jwtAuth{ PrefixWhitelistPaths: Config.JWTAuthPrefixWhitelistPaths, ExactWhitelistPaths: Config.JWTAuthExactWhitelistPaths, JWTMiddleware: jwtmiddleware.New(jwtmiddleware.Options{ - ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { - return validationKey, errParsingKey - }, - SigningMethod: signingMethod, + ValidationKeyGetter: validationKeyGetter, + SigningMethod: signingMethod, Extractor: jwtmiddleware.FromFirst( func(r *http.Request) (string, error) { c, err := r.Cookie(Config.JWTAuthCookieTokenName) From 3b4720f082597ae4a5bc6fea3bfd794d2b6188e2 Mon Sep 17 00:00:00 2001 From: dseabolt Date: Wed, 29 Apr 2020 13:58:53 -0600 Subject: [PATCH 02/12] Removing logging --- pkg/config/middleware.go | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/pkg/config/middleware.go b/pkg/config/middleware.go index 64ba10506..5e578e097 100644 --- a/pkg/config/middleware.go +++ b/pkg/config/middleware.go @@ -160,33 +160,28 @@ func setupJWTAuthMiddleware() *jwtAuth { } // First we need to access the well known url to find out what the jwk_uri is. - logrus.Printf("Fetching Oidc Configuration...") resp, err := http.Get(Config.JWTAuthOIDCWellKnownURL) if err != nil { return "", err } // Read in the oidc config information, and unmarshal it into our oidc config. - logrus.Printf("Reading Oidc Configuration...") defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { return "", err } - logrus.Printf("Parsing Oidc Configuration...") type OidcConfiguation struct { JwkURI string `json:"jwks_uri"` // This is the only field we care about for now. } var oidcConfig OidcConfiguation json.Unmarshal([]byte(body), &oidcConfig) - logrus.Printf("OIDC Config: " + string(body)) if oidcConfig.JwkURI == "" { return "", errors.New("OIDC configuration didn't contain a valid jwks_uri") } - logrus.Printf("Fetching Oidc Jwks...") // Now we need to read in the jwks info. oidcJwksResp, err := http.Get(oidcConfig.JwkURI) if err != nil { @@ -197,7 +192,6 @@ func setupJWTAuthMiddleware() *jwtAuth { return "", err } - logrus.Printf("Reading Oidc Jwks...") // Read in the jwks and unmarshall it. defer oidcJwksResp.Body.Close() body, err = ioutil.ReadAll(oidcJwksResp.Body) @@ -216,11 +210,8 @@ func setupJWTAuthMiddleware() *jwtAuth { var oidcJwks OidcJwksConfiguration - logrus.Printf("Parsing Oidc Jwks...") json.Unmarshal([]byte(body), &oidcJwks) - logrus.Printf("OIDC JWKS: " + string(body)) - logrus.Printf("Finding matching key") // Find the key that matches the one in the JWT if oidcJwks.Jwks == nil { return "", errors.New("Missing keys in the jwk config") @@ -244,7 +235,6 @@ func setupJWTAuthMiddleware() *jwtAuth { } correctKey := oidcJwks.Jwks[matchingKey].SigningKeys[0] - logrus.Printf("Found key: " + correctKey) // Now we will take the first key and convert it into a cert // that the library we user are familiar with. This cert requires @@ -252,7 +242,6 @@ func setupJWTAuthMiddleware() *jwtAuth { // There can only be 64 characters on a line or else it won't read it // so we have to do that manually. Also the BEGIN and END lines need to // be their own lines. - logrus.Printf("Building cert!") numCharsInKey := len(correctKey) var numLines int = numCharsInKey / 64 if numLines%64 > 0 { @@ -270,7 +259,6 @@ func setupJWTAuthMiddleware() *jwtAuth { jwtCert += correctKey[startCharacter:endCharacter] + "\n" } jwtCert += "-----END PUBLIC KEY-----" - logrus.Printf("Cert created: \n" + jwtCert) return jwt.ParseRSAPublicKeyFromPEM([]byte(jwtCert)) } From 26636cf3afbedfdd8f11e2cf1175052329e86867 Mon Sep 17 00:00:00 2001 From: dseabolt Date: Wed, 29 Apr 2020 16:48:49 -0600 Subject: [PATCH 03/12] Removing duplicated statement --- pkg/config/middleware.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pkg/config/middleware.go b/pkg/config/middleware.go index 5e578e097..025cc3fa4 100644 --- a/pkg/config/middleware.go +++ b/pkg/config/middleware.go @@ -188,10 +188,6 @@ func setupJWTAuthMiddleware() *jwtAuth { return "", err } - if err != nil { - return "", err - } - // Read in the jwks and unmarshall it. defer oidcJwksResp.Body.Close() body, err = ioutil.ReadAll(oidcJwksResp.Body) From 9c23d5a378200ed3d0b8d8b16c5f844af279684e Mon Sep 17 00:00:00 2001 From: dseabolt Date: Tue, 5 May 2020 16:09:50 -0600 Subject: [PATCH 04/12] Adding support for OIDC keys that only specify modulus and exponent of private key --- pkg/config/middleware.go | 268 +++++++++++++++++++++++++-------------- 1 file changed, 173 insertions(+), 95 deletions(-) diff --git a/pkg/config/middleware.go b/pkg/config/middleware.go index 025cc3fa4..538ec8685 100644 --- a/pkg/config/middleware.go +++ b/pkg/config/middleware.go @@ -1,11 +1,16 @@ package config import ( + "bytes" + "crypto/rsa" "crypto/subtle" + "encoding/base64" + "encoding/binary" "encoding/json" "errors" "fmt" "io/ioutil" + "math/big" "net/http" "strconv" "strings" @@ -123,6 +128,165 @@ func setupRecoveryMiddleware() *negroni.Recovery { return r } +type OidcConfiguation struct { + JwkURI string `json:"jwks_uri"` // This is the only field we care about for now. +} + +type OidcJwk struct { + JWKeyID string `json:"kid"` // JWTs will have a "kid" that will match one of these + SigningKeys []string `json:"x5c"` // We prefer the x5c since it's easier + Exponent string `json:"e"` // The least preferred way is recreating the key through exponent. + Modulus string `json:"n"` // The modulus +} + +type OidcJwksConfiguration struct { + Jwks []OidcJwk `json:"keys"` +} + +func discoverOidcJwk(tokenKeyID string) *OidcJwk { + // First we need to access the well known url to find out what the jwk_uri is. + resp, err := http.Get(Config.JWTAuthOIDCWellKnownURL) + if err != nil { + logrus.Errorln("Failed to contact OIDC well known URL") + return nil + } + + // Read in the oidc config information, and unmarshal it into our oidc config. + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + logrus.Errorln("Failed to read from OIDC well known URL") + return nil + } + + var oidcConfig OidcConfiguation + json.Unmarshal([]byte(body), &oidcConfig) + + if oidcConfig.JwkURI == "" { + logrus.Errorln("OIDC configuration didn't contain a valid jwks_uri") + return nil + } + + // Now we need to read in the jwks info. + oidcJwksResp, err := http.Get(oidcConfig.JwkURI) + if err != nil { + logrus.Errorln("Failed to contact the JWK URI") + return nil + } + + // Read in the jwks and unmarshall it. + defer oidcJwksResp.Body.Close() + body, err = ioutil.ReadAll(oidcJwksResp.Body) + if err != nil { + logrus.Errorln("Failed to read the JWKs body") + return nil + } + + var oidcJwks OidcJwksConfiguration + + json.Unmarshal([]byte(body), &oidcJwks) + + // Find the key that matches the one in the JWT + if oidcJwks.Jwks == nil { + logrus.Errorln("Missing keys in the jwk config") + return nil + } + + var numKeys = len(oidcJwks.Jwks) + if numKeys == 0 { + logrus.Errorln("No keys in the jwk config") + return nil + } + + matchingKey := -1 + for currentKey := 0; currentKey < numKeys; currentKey++ { + if oidcJwks.Jwks[currentKey].JWKeyID == tokenKeyID { + matchingKey = currentKey + break + } + } + + if matchingKey == -1 { + logrus.Errorln("No matching key found in the jwk config") + return nil + } + + return &oidcJwks.Jwks[matchingKey] +} + +func extractJWTSigningKeyFromX5C(oidcJwk *OidcJwk) (interface{}, error) { + correctKey := oidcJwk.SigningKeys[0] + + // Now we will take the first key and convert it into a cert + // that the library we user are familiar with. This cert requires + // a very specific format that is dependent on whitespace :/ + // There can only be 64 characters on a line or else it won't read it + // so we have to do that manually. Also the BEGIN and END lines need to + // be their own lines. + numCharsInKey := len(correctKey) + var numLines int = numCharsInKey / 64 + if numLines%64 > 0 { + numLines++ + } + + var jwtCert string = "-----BEGIN PUBLIC KEY-----\n" + for currentLine := 0; currentLine < numLines; currentLine++ { + startCharacter := currentLine * 64 + endCharacter := startCharacter + 64 + if startCharacter+64 > numCharsInKey { + endCharacter = numCharsInKey + } + + jwtCert += correctKey[startCharacter:endCharacter] + "\n" + } + jwtCert += "-----END PUBLIC KEY-----" + + return jwt.ParseRSAPublicKeyFromPEM([]byte(jwtCert)) +} + +func calculateJWTSigningKey(jwk *OidcJwk) (interface{}, error) { + if jwk.Modulus == "" || jwk.Exponent == "" { + return "", errors.New("Invalid Modulus or Exponent provided") + } + + // Decode the modulus and move it into a big int. + logrus.Printf("Modulus found: " + jwk.Modulus) + decN, err := base64.RawURLEncoding.DecodeString(jwk.Modulus) + if err != nil { + logrus.Errorf("Failed to decode modulus string.") + return "", err + } + n := big.NewInt(0) + n.SetBytes(decN) + + // Decode the exponent + logrus.Printf("Exponent found: " + jwk.Exponent) + eStr := jwk.Exponent + decE, err := base64.RawURLEncoding.DecodeString(eStr) + if err != nil { + logrus.Errorf("Failed to decode exponent string") + return "", err + } + + var eBytes []byte + if len(decE) < 8 { + eBytes = make([]byte, 8-len(decE), 8) + eBytes = append(eBytes, decE...) + } else { + eBytes = decE + } + eReader := bytes.NewReader(eBytes) + var e uint64 + err = binary.Read(eReader, binary.BigEndian, &e) + if err != nil { + logrus.Errorf("Failed to read exponent bytes") + return "", err + } + pKey := &rsa.PublicKey{N: n, E: int(e)} + logrus.Printf("Created public key.") + return pKey, nil +} + /** setupJWTAuthMiddleware setup an JWTMiddleware from the ENV config */ @@ -159,104 +323,18 @@ func setupJWTAuthMiddleware() *jwtAuth { return "", errors.New("Missing key id in the JWT") } - // First we need to access the well known url to find out what the jwk_uri is. - resp, err := http.Get(Config.JWTAuthOIDCWellKnownURL) - if err != nil { - return "", err - } - - // Read in the oidc config information, and unmarshal it into our oidc config. - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return "", err - } - - type OidcConfiguation struct { - JwkURI string `json:"jwks_uri"` // This is the only field we care about for now. - } - var oidcConfig OidcConfiguation - json.Unmarshal([]byte(body), &oidcConfig) - - if oidcConfig.JwkURI == "" { - return "", errors.New("OIDC configuration didn't contain a valid jwks_uri") - } - - // Now we need to read in the jwks info. - oidcJwksResp, err := http.Get(oidcConfig.JwkURI) - if err != nil { - return "", err - } - - // Read in the jwks and unmarshall it. - defer oidcJwksResp.Body.Close() - body, err = ioutil.ReadAll(oidcJwksResp.Body) - if err != nil { - return "", err - } - - type OidcJwk struct { - JWKeyID string `json:"kid"` // JWTs will have a "kid" that will match one of these - SigningKeys []string `json:"x5c"` // We only support x5c at the moment. - } - - type OidcJwksConfiguration struct { - Jwks []OidcJwk `json:"keys"` + var jwk = discoverOidcJwk(tokenKeyID.(string)) + if jwk == nil { + return "", errors.New("Failed to find JWK for JWT") } - var oidcJwks OidcJwksConfiguration - - json.Unmarshal([]byte(body), &oidcJwks) - - // Find the key that matches the one in the JWT - if oidcJwks.Jwks == nil { - return "", errors.New("Missing keys in the jwk config") + if len(jwk.SigningKeys) > 0 { + return extractJWTSigningKeyFromX5C(jwk) + } else if jwk.Exponent != "" && jwk.Modulus != "" { + return calculateJWTSigningKey(jwk) + } else { + return "", errors.New("JWK is invalid") } - - var numKeys = len(oidcJwks.Jwks) - if numKeys == 0 { - return "", errors.New("No keys in the jwk config") - } - - matchingKey := -1 - for currentKey := 0; currentKey < numKeys; currentKey++ { - if oidcJwks.Jwks[currentKey].JWKeyID == tokenKeyID { - matchingKey = currentKey - break - } - } - - if matchingKey == -1 { - return "", errors.New("No matching key found in the jwk config") - } - - correctKey := oidcJwks.Jwks[matchingKey].SigningKeys[0] - - // Now we will take the first key and convert it into a cert - // that the library we user are familiar with. This cert requires - // a very specific format that is dependent on whitespace :/ - // There can only be 64 characters on a line or else it won't read it - // so we have to do that manually. Also the BEGIN and END lines need to - // be their own lines. - numCharsInKey := len(correctKey) - var numLines int = numCharsInKey / 64 - if numLines%64 > 0 { - numLines++ - } - - var jwtCert string = "-----BEGIN PUBLIC KEY-----\n" - for currentLine := 0; currentLine < numLines; currentLine++ { - startCharacter := currentLine * 64 - endCharacter := startCharacter + 64 - if startCharacter+64 > numCharsInKey { - endCharacter = numCharsInKey - } - - jwtCert += correctKey[startCharacter:endCharacter] + "\n" - } - jwtCert += "-----END PUBLIC KEY-----" - - return jwt.ParseRSAPublicKeyFromPEM([]byte(jwtCert)) } } From 8b863b9ec855f245a098cf689eff1c0fb86ee256 Mon Sep 17 00:00:00 2001 From: dseabolt Date: Thu, 7 May 2020 15:18:03 -0600 Subject: [PATCH 05/12] Stripping the eval context from batch responses in order to reduce the size of the response tremendously --- docs/api_docs/bundle.yaml | 3 ++ pkg/handler/eval.go | 41 +++++++++++++------ pkg/handler/eval_test.go | 20 ++++----- swagger/index.yaml | 3 ++ .../models/evaluation_batch_request.go | 3 ++ swagger_gen/restapi/embedded_spec.go | 8 ++++ 6 files changed, 56 insertions(+), 22 deletions(-) diff --git a/docs/api_docs/bundle.yaml b/docs/api_docs/bundle.yaml index 7e8512b08..24d4d502e 100644 --- a/docs/api_docs/bundle.yaml +++ b/docs/api_docs/bundle.yaml @@ -1197,6 +1197,9 @@ definitions: minItems: 1 enableDebug: type: boolean + stripEvalContext: + type: boolean + default: false flagIDs: description: flagIDs type: array diff --git a/pkg/handler/eval.go b/pkg/handler/eval.go index b1be769d5..d01cc2db4 100644 --- a/pkg/handler/eval.go +++ b/pkg/handler/eval.go @@ -39,7 +39,7 @@ func (e *eval) PostEvaluation(params evaluation.PostEvaluationParams) middleware ErrorMessage("empty body")) } - evalResult := EvalFlag(*evalContext) + evalResult := EvalFlag(*evalContext, false) resp := evaluation.NewPostEvaluationOK() resp.SetPayload(evalResult) return resp @@ -51,28 +51,38 @@ func (e *eval) PostEvaluationBatch(params evaluation.PostEvaluationBatchParams) flagKeys := params.Body.FlagKeys results := &models.EvaluationBatchResponse{} + stripEvalContextFromResults := false + if params.Body.StripEvalContext != nil && *params.Body.StripEvalContext { + stripEvalContextFromResults = *params.Body.StripEvalContext + } + // TODO make it concurrent for _, entity := range entities { for _, flagID := range flagIDs { - evalContext := models.EvalContext{ + + var evalContext models.EvalContext + evalContext = models.EvalContext{ EnableDebug: params.Body.EnableDebug, EntityContext: entity.EntityContext, EntityID: entity.EntityID, EntityType: entity.EntityType, FlagID: flagID, } - evalResult := EvalFlag(evalContext) + + evalResult := EvalFlag(evalContext, stripEvalContextFromResults) results.EvaluationResults = append(results.EvaluationResults, evalResult) } for _, flagKey := range flagKeys { - evalContext := models.EvalContext{ + var evalContext models.EvalContext + evalContext = models.EvalContext{ EnableDebug: params.Body.EnableDebug, EntityContext: entity.EntityContext, EntityID: entity.EntityID, EntityType: entity.EntityType, FlagKey: flagKey, } - evalResult := EvalFlag(evalContext) + + evalResult := EvalFlag(evalContext, stripEvalContextFromResults) results.EvaluationResults = append(results.EvaluationResults, evalResult) } } @@ -83,7 +93,7 @@ func (e *eval) PostEvaluationBatch(params evaluation.PostEvaluationBatchParams) } // BlankResult creates a blank result -func BlankResult(f *entity.Flag, evalContext models.EvalContext, msg string) *models.EvalResult { +func BlankResult(f *entity.Flag, evalContext *models.EvalContext, msg string) *models.EvalResult { flagID := uint(0) flagKey := "" flagSnapshotID := uint(0) @@ -93,7 +103,7 @@ func BlankResult(f *entity.Flag, evalContext models.EvalContext, msg string) *mo flagKey = f.Key } return &models.EvalResult{ - EvalContext: &evalContext, + EvalContext: evalContext, EvalDebugLog: &models.EvalDebugLog{ Msg: msg, SegmentDebugLogs: nil, @@ -105,26 +115,33 @@ func BlankResult(f *entity.Flag, evalContext models.EvalContext, msg string) *mo } } -var EvalFlag = func(evalContext models.EvalContext) *models.EvalResult { +// Evaluates a flag for a given context and determines what segment, if any, that applies. +var EvalFlag = func(evalContext models.EvalContext, stripEvalContextFromResults bool) *models.EvalResult { cache := GetEvalCache() flagID := util.SafeUint(evalContext.FlagID) flagKey := util.SafeString(evalContext.FlagKey) f := cache.GetByFlagKeyOrID(flagID) + + outputEvalContext := &evalContext + if stripEvalContextFromResults { + outputEvalContext = nil + } + if f == nil { f = cache.GetByFlagKeyOrID(flagKey) } if f == nil { emptyFlag := &entity.Flag{Model: gorm.Model{ID: flagID}, Key: flagKey} - return BlankResult(emptyFlag, evalContext, fmt.Sprintf("flagID %v not found or deleted", flagID)) + return BlankResult(emptyFlag, outputEvalContext, fmt.Sprintf("flagID %v not found or deleted", flagID)) } if !f.Enabled { - return BlankResult(f, evalContext, fmt.Sprintf("flagID %v is not enabled", f.ID)) + return BlankResult(f, outputEvalContext, fmt.Sprintf("flagID %v is not enabled", f.ID)) } if len(f.Segments) == 0 { - return BlankResult(f, evalContext, fmt.Sprintf("flagID %v has no segments", f.ID)) + return BlankResult(f, outputEvalContext, fmt.Sprintf("flagID %v has no segments", f.ID)) } if evalContext.EntityID == "" { @@ -152,7 +169,7 @@ var EvalFlag = func(evalContext models.EvalContext) *models.EvalResult { break } } - evalResult := BlankResult(f, evalContext, "") + evalResult := BlankResult(f, outputEvalContext, "") evalResult.EvalDebugLog.SegmentDebugLogs = logs evalResult.SegmentID = sID evalResult.VariantID = vID diff --git a/pkg/handler/eval_test.go b/pkg/handler/eval_test.go index 3455d4ea9..1748c621f 100644 --- a/pkg/handler/eval_test.go +++ b/pkg/handler/eval_test.go @@ -146,7 +146,7 @@ func TestEvalFlag(t *testing.T) { t.Run("test empty evalContext", func(t *testing.T) { defer gostub.StubFunc(&GetEvalCache, GenFixtureEvalCache()).Reset() - result := EvalFlag(models.EvalContext{FlagID: int64(100)}) + result := EvalFlag(models.EvalContext{FlagID: int64(100)}, false) assert.Zero(t, result.VariantID) assert.NotZero(t, result.FlagID) assert.NotEmpty(t, result.EvalContext.EntityID) @@ -160,7 +160,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagID: int64(100), - }) + }, false) assert.NotNil(t, result) assert.NotNil(t, result.VariantID) }) @@ -173,7 +173,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagKey: "flag_key_100", - }) + }, false) assert.NotNil(t, result) assert.NotNil(t, result.VariantID) }) @@ -186,7 +186,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagKey: "flag_key_100", - }) + }, false) assert.NotNil(t, result) assert.NotNil(t, result.VariantID) }) @@ -225,7 +225,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagID: int64(100), - }) + }, false) assert.NotNil(t, result) assert.NotNil(t, result.VariantID) }) @@ -245,7 +245,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagID: int64(100), - }) + }, false) assert.NotNil(t, result) assert.Zero(t, result.VariantID) }) @@ -277,7 +277,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagID: int64(100), - }) + }, false) assert.NotNil(t, result) assert.Zero(t, result.VariantID) }) @@ -293,7 +293,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagID: int64(100), - }) + }, false) assert.NotNil(t, result) assert.Zero(t, result.VariantID) }) @@ -310,7 +310,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagID: int64(100), - }) + }, false) assert.NotNil(t, result) assert.NotNil(t, result.VariantID) assert.Equal(t, "entityType1", result.EvalContext.EntityType) @@ -326,7 +326,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagID: int64(100), - }) + }, false) assert.NotNil(t, result) assert.NotNil(t, result.VariantID) assert.NotEqual(t, "entityType1", result.EvalContext.EntityType) diff --git a/swagger/index.yaml b/swagger/index.yaml index eea09f40c..dcf352fb4 100644 --- a/swagger/index.yaml +++ b/swagger/index.yaml @@ -485,6 +485,9 @@ definitions: minItems: 1 enableDebug: type: boolean + stripEvalContext: + type: boolean + default: false flagIDs: description: flagIDs type: array diff --git a/swagger_gen/models/evaluation_batch_request.go b/swagger_gen/models/evaluation_batch_request.go index c819134ec..69808dac4 100644 --- a/swagger_gen/models/evaluation_batch_request.go +++ b/swagger_gen/models/evaluation_batch_request.go @@ -34,6 +34,9 @@ type EvaluationBatchRequest struct { // flagKeys. Either flagIDs or flagKeys works. If pass in both, Flagr may return duplicate results. // Min Items: 1 FlagKeys []string `json:"flagKeys"` + + // strip eval context + StripEvalContext *bool `json:"stripEvalContext,omitempty"` } // Validate validates this evaluation batch request diff --git a/swagger_gen/restapi/embedded_spec.go b/swagger_gen/restapi/embedded_spec.go index 6a2f88f5f..7f4c6e851 100644 --- a/swagger_gen/restapi/embedded_spec.go +++ b/swagger_gen/restapi/embedded_spec.go @@ -1424,6 +1424,10 @@ func init() { "type": "string", "minLength": 1 } + }, + "stripEvalContext": { + "type": "boolean", + "default": false } } }, @@ -3203,6 +3207,10 @@ func init() { "type": "string", "minLength": 1 } + }, + "stripEvalContext": { + "type": "boolean", + "default": false } } }, From beabf1f0caf98703ecd421502d0f6cc664155601 Mon Sep 17 00:00:00 2001 From: dseabolt Date: Thu, 7 May 2020 15:43:03 -0600 Subject: [PATCH 06/12] Removing print statements --- pkg/config/middleware.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/pkg/config/middleware.go b/pkg/config/middleware.go index 538ec8685..d626ca589 100644 --- a/pkg/config/middleware.go +++ b/pkg/config/middleware.go @@ -250,7 +250,6 @@ func calculateJWTSigningKey(jwk *OidcJwk) (interface{}, error) { } // Decode the modulus and move it into a big int. - logrus.Printf("Modulus found: " + jwk.Modulus) decN, err := base64.RawURLEncoding.DecodeString(jwk.Modulus) if err != nil { logrus.Errorf("Failed to decode modulus string.") @@ -260,7 +259,6 @@ func calculateJWTSigningKey(jwk *OidcJwk) (interface{}, error) { n.SetBytes(decN) // Decode the exponent - logrus.Printf("Exponent found: " + jwk.Exponent) eStr := jwk.Exponent decE, err := base64.RawURLEncoding.DecodeString(eStr) if err != nil { @@ -283,7 +281,6 @@ func calculateJWTSigningKey(jwk *OidcJwk) (interface{}, error) { return "", err } pKey := &rsa.PublicKey{N: n, E: int(e)} - logrus.Printf("Created public key.") return pKey, nil } From c9da7158c2d6b486a0d226a3818d19100ab49126 Mon Sep 17 00:00:00 2001 From: dseabolt Date: Thu, 14 May 2020 13:21:46 -0600 Subject: [PATCH 07/12] Adding support for Prometheus namespaces --- pkg/config/config.go | 15 +++++++++------ pkg/config/env.go | 2 ++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index 22401cdf9..60114f1d3 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -115,18 +115,21 @@ func setupPrometheus() { if Config.PrometheusEnabled { Global.Prometheus.ScrapePath = Config.PrometheusPath Global.Prometheus.EvalCounter = promauto.NewCounterVec(prometheus.CounterOpts{ - Name: "flagr_eval_results", - Help: "A counter of eval results", + Namespace: Config.PrometheusNamespace, + Name: "flagr_eval_results", + Help: "A counter of eval results", }, []string{"EntityType", "FlagID", "VariantID", "VariantKey"}) Global.Prometheus.RequestCounter = promauto.NewCounterVec(prometheus.CounterOpts{ - Name: "flagr_requests_total", - Help: "The total http requests received", + Namespace: Config.PrometheusNamespace, + Name: "flagr_requests_total", + Help: "The total http requests received", }, []string{"status", "path", "method"}) if Config.PrometheusIncludeLatencyHistogram { Global.Prometheus.RequestHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{ - Name: "flagr_requests_buckets", - Help: "A histogram of latencies for requests received", + Namespace: Config.PrometheusNamespace, + Name: "flagr_requests_buckets", + Help: "A histogram of latencies for requests received", }, []string{"status", "path", "method"}) } } diff --git a/pkg/config/env.go b/pkg/config/env.go index 16f2237b0..954deb2c8 100644 --- a/pkg/config/env.go +++ b/pkg/config/env.go @@ -94,6 +94,8 @@ var Config = struct { PrometheusEnabled bool `env:"FLAGR_PROMETHEUS_ENABLED" envDefault:"false"` // PrometheusPath - set the path on which prometheus metrics are available to scrape PrometheusPath string `env:"FLAGR_PROMETHEUS_PATH" envDefault:"/metrics"` + // PrometheusNamespace - set an optional namespace to prepend to all stats + PrometheusNamespace string `env:"FLAGR_PROMETHEUS_NAMESPACE" envDefault:""` // PrometheusIncludeLatencyHistogram - set whether Prometheus should also export a histogram of request latencies (this increases cardinality significantly) PrometheusIncludeLatencyHistogram bool `env:"FLAGR_PROMETHEUS_INCLUDE_LATENCY_HISTOGRAM" envDefault:"false"` From 12ad5a08b5dbb131c10d244212e56dc83b2eb219 Mon Sep 17 00:00:00 2001 From: dseabolt Date: Fri, 12 Jun 2020 16:33:53 -0600 Subject: [PATCH 08/12] Fixing a crash in prometheus stats when we've stripped the eval context --- pkg/handler/eval.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/pkg/handler/eval.go b/pkg/handler/eval.go index d01cc2db4..0b82b6296 100644 --- a/pkg/handler/eval.go +++ b/pkg/handler/eval.go @@ -224,13 +224,21 @@ var logEvalResultToPrometheus = func(r *models.EvalResult) { if config.Global.Prometheus.EvalCounter == nil { return } - config.Global.Prometheus.EvalCounter.WithLabelValues( - util.SafeStringWithDefault(r.EvalContext.EntityType, "null"), - util.SafeStringWithDefault(r.FlagID, "null"), - util.SafeStringWithDefault(r.VariantID, "null"), - util.SafeStringWithDefault(r.VariantKey, "null"), - ).Inc() + if r.EvalContext != nil { + config.Global.Prometheus.EvalCounter.WithLabelValues( + util.SafeStringWithDefault(r.FlagID, "null"), + util.SafeStringWithDefault(r.VariantID, "null"), + util.SafeStringWithDefault(r.VariantKey, "null"), + ).Inc() + } else { + config.Global.Prometheus.EvalCounter.WithLabelValues( + util.SafeStringWithDefault(r.EvalContext.EntityType, "null"), + util.SafeStringWithDefault(r.FlagID, "null"), + util.SafeStringWithDefault(r.VariantID, "null"), + util.SafeStringWithDefault(r.VariantKey, "null"), + ).Inc() + } } var evalSegment = func( From 736b27eebdad09ba1d9a74e672c00866f35dbb14 Mon Sep 17 00:00:00 2001 From: dseabolt Date: Fri, 12 Jun 2020 17:16:00 -0600 Subject: [PATCH 09/12] Fixing a crash in prometheus stats when we've stripped the eval context --- pkg/handler/eval.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/handler/eval.go b/pkg/handler/eval.go index 0b82b6296..afd051b0b 100644 --- a/pkg/handler/eval.go +++ b/pkg/handler/eval.go @@ -225,7 +225,7 @@ var logEvalResultToPrometheus = func(r *models.EvalResult) { return } - if r.EvalContext != nil { + if r.EvalContext == nil { config.Global.Prometheus.EvalCounter.WithLabelValues( util.SafeStringWithDefault(r.FlagID, "null"), util.SafeStringWithDefault(r.VariantID, "null"), From 4b941cad57138ea9d977c0ed3681b3f86dd8ea53 Mon Sep 17 00:00:00 2001 From: dseabolt Date: Mon, 15 Jun 2020 15:54:26 -0600 Subject: [PATCH 10/12] Making prometheus have an empty label when we've stripped out the eval context --- pkg/handler/eval.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/handler/eval.go b/pkg/handler/eval.go index afd051b0b..1b0dd83a5 100644 --- a/pkg/handler/eval.go +++ b/pkg/handler/eval.go @@ -227,6 +227,7 @@ var logEvalResultToPrometheus = func(r *models.EvalResult) { if r.EvalContext == nil { config.Global.Prometheus.EvalCounter.WithLabelValues( + util.SafeStringWithDefault("stripped", "null"), util.SafeStringWithDefault(r.FlagID, "null"), util.SafeStringWithDefault(r.VariantID, "null"), util.SafeStringWithDefault(r.VariantKey, "null"), From 00d61ead5fc4c2f852446104075815ccf1631053 Mon Sep 17 00:00:00 2001 From: Trav f Date: Fri, 8 Dec 2023 10:47:39 -0700 Subject: [PATCH 11/12] Updating Flagr-Source with source repo - openflagr/flagr --- pkg/handler/eval.go | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/pkg/handler/eval.go b/pkg/handler/eval.go index 106f8af14..a97cfabb0 100644 --- a/pkg/handler/eval.go +++ b/pkg/handler/eval.go @@ -40,7 +40,7 @@ func (e *eval) PostEvaluation(params evaluation.PostEvaluationParams) middleware ErrorMessage("empty body")) } - evalResult := EvalFlag(*evalContext, false) + evalResult := EvalFlag(*evalContext) resp := evaluation.NewPostEvaluationOK() resp.SetPayload(evalResult) return resp @@ -54,11 +54,6 @@ func (e *eval) PostEvaluationBatch(params evaluation.PostEvaluationBatchParams) flagTagsOperator := params.Body.FlagTagsOperator results := &models.EvaluationBatchResponse{} - stripEvalContextFromResults := false - if params.Body.StripEvalContext != nil && *params.Body.StripEvalContext { - stripEvalContextFromResults = *params.Body.StripEvalContext - } - // TODO make it concurrent for _, entity := range entities { if len(flagTags) > 0 { @@ -74,9 +69,7 @@ func (e *eval) PostEvaluationBatch(params evaluation.PostEvaluationBatchParams) results.EvaluationResults = append(results.EvaluationResults, evalResults...) } for _, flagID := range flagIDs { - - var evalContext models.EvalContext - evalContext = models.EvalContext{ + evalContext := models.EvalContext{ EnableDebug: params.Body.EnableDebug, EntityContext: entity.EntityContext, EntityID: entity.EntityID, @@ -84,12 +77,11 @@ func (e *eval) PostEvaluationBatch(params evaluation.PostEvaluationBatchParams) FlagID: flagID, } - evalResult := EvalFlag(evalContext, stripEvalContextFromResults) + evalResult := EvalFlag(evalContext) results.EvaluationResults = append(results.EvaluationResults, evalResult) } for _, flagKey := range flagKeys { - var evalContext models.EvalContext - evalContext = models.EvalContext{ + evalContext := models.EvalContext{ EnableDebug: params.Body.EnableDebug, EntityContext: entity.EntityContext, EntityID: entity.EntityID, @@ -97,7 +89,7 @@ func (e *eval) PostEvaluationBatch(params evaluation.PostEvaluationBatchParams) FlagKey: flagKey, } - evalResult := EvalFlag(evalContext, stripEvalContextFromResults) + evalResult := EvalFlag(evalContext) results.EvaluationResults = append(results.EvaluationResults, evalResult) } } @@ -108,7 +100,7 @@ func (e *eval) PostEvaluationBatch(params evaluation.PostEvaluationBatchParams) } // BlankResult creates a blank result -func BlankResult(f *entity.Flag, evalContext *models.EvalContext, msg string) *models.EvalResult { +func BlankResult(f *entity.Flag, evalContext models.EvalContext, msg string) *models.EvalResult { flagID := uint(0) flagKey := "" flagSnapshotID := uint(0) @@ -118,7 +110,7 @@ func BlankResult(f *entity.Flag, evalContext *models.EvalContext, msg string) *m flagKey = f.Key } return &models.EvalResult{ - EvalContext: evalContext, + EvalContext: &evalContext, EvalDebugLog: &models.EvalDebugLog{ Msg: msg, SegmentDebugLogs: nil, @@ -130,18 +122,12 @@ func BlankResult(f *entity.Flag, evalContext *models.EvalContext, msg string) *m } } -// Evaluates a flag for a given context and determines what segment, if any, that applies. var LookupFlag = func(evalContext models.EvalContext) *entity.Flag { cache := GetEvalCache() flagID := util.SafeUint(evalContext.FlagID) flagKey := util.SafeString(evalContext.FlagKey) f := cache.GetByFlagKeyOrID(flagID) - outputEvalContext := &evalContext - if stripEvalContextFromResults { - outputEvalContext = nil - } - if f == nil { f = cache.GetByFlagKeyOrID(flagKey) } @@ -169,7 +155,7 @@ var EvalFlagWithContext = func(flag *entity.Flag, evalContext models.EvalContext if flag == nil { emptyFlag := &entity.Flag{Model: gorm.Model{ID: flagID}, Key: flagKey} - return BlankResult(emptyFlag, outputEvalContext, fmt.Sprintf("flagID %v not found or deleted", flagID)) + return BlankResult(emptyFlag, evalContext, fmt.Sprintf("flagID %v not found or deleted", flagID)) } if !flag.Enabled { From 0cd05c38c06b2b3bd1510bbc7b04ec8a14444bf0 Mon Sep 17 00:00:00 2001 From: Trav f Date: Fri, 8 Dec 2023 10:48:21 -0700 Subject: [PATCH 12/12] Attempt at blocking Transfer-Encoding header --- pkg/config/middleware.go | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/pkg/config/middleware.go b/pkg/config/middleware.go index 4b30b80a8..39cfeb8d5 100644 --- a/pkg/config/middleware.go +++ b/pkg/config/middleware.go @@ -43,6 +43,8 @@ func ServerShutdown() { func SetupGlobalMiddleware(handler http.Handler) http.Handler { n := negroni.New() + n.Use(&headerCheckMiddleware{}) + if Config.MiddlewareGzipEnabled { n.Use(gzip.Gzip(gzip.DefaultCompression)) } @@ -134,6 +136,25 @@ func setupRecoveryMiddleware() *negroni.Recovery { return r } +type headerCheckMiddleware struct{} + +func (hcm *headerCheckMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + // LOG A STATEMENT HERE + for key, values := range r.Header { + logrus.Infof("Header %s: %s", key, strings.Join(values, ", ")) + } + logrus.Infof("headers::") + logrus.Infof("headers: %#v\n", r.Header) + // logrus.Infof("requestttt") + // logrus.Infof(r) + + if r.Header.Get("Content-Length") != "" && r.Header.Get("Transfer-Encoding") != "" { + http.Error(w, "Invalid request with both Content-Length and Transfer-Encoding headers", http.StatusBadRequest) + return + } + next(w, r) +} + type OidcConfiguation struct { JwkURI string `json:"jwks_uri"` // This is the only field we care about for now. } @@ -290,7 +311,8 @@ func calculateJWTSigningKey(jwk *OidcJwk) (interface{}, error) { return pKey, nil } -/** +/* +* setupJWTAuthMiddleware setup an JWTMiddleware from the ENV config */ func setupJWTAuthMiddleware() *jwtAuth { @@ -410,7 +432,8 @@ func (a *jwtAuth) ServeHTTP(w http.ResponseWriter, req *http.Request, next http. a.JWTMiddleware.HandlerWithNext(w, req, next) } -/** +/* +* setupBasicAuthMiddleware setup an BasicMiddleware from the ENV config */ func setupBasicAuthMiddleware() *basicAuth {