diff --git a/docs/api_docs/bundle.yaml b/docs/api_docs/bundle.yaml index a486b390..165a1c33 100644 --- a/docs/api_docs/bundle.yaml +++ b/docs/api_docs/bundle.yaml @@ -1408,6 +1408,9 @@ definitions: minItems: 1 enableDebug: type: boolean + stripEvalContext: + type: boolean + default: false flagIDs: description: flagIDs type: array diff --git a/pkg/config/config.go b/pkg/config/config.go index fd77f587..a37b7cf9 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -118,18 +118,21 @@ func setupPrometheus() { if Config.PrometheusEnabled { Global.Prometheus.ScrapePath = Config.PrometheusPath Global.Prometheus.EvalCounter = promauto.NewCounterVec(prometheus.CounterOpts{ - Name: "flagr_eval_results", - Help: "A counter of eval results", - }, []string{"EntityType", "FlagID", "FlagKey", "VariantID", "VariantKey"}) + Namespace: Config.PrometheusNamespace, + Name: "flagr_eval_results", + Help: "A counter of eval results", + }, []string{"EntityType", "FlagID", "VariantID", "VariantKey"}) Global.Prometheus.RequestCounter = promauto.NewCounterVec(prometheus.CounterOpts{ - Name: "flagr_requests_total", - Help: "The total http requests received", + Namespace: Config.PrometheusNamespace, + Name: "flagr_requests_total", + Help: "The total http requests received", }, []string{"status", "path", "method"}) if Config.PrometheusIncludeLatencyHistogram { Global.Prometheus.RequestHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{ - Name: "flagr_requests_buckets", - Help: "A histogram of latencies for requests received", + Namespace: Config.PrometheusNamespace, + Name: "flagr_requests_buckets", + Help: "A histogram of latencies for requests received", }, []string{"status", "path", "method"}) } } diff --git a/pkg/config/env.go b/pkg/config/env.go index 45bb69cd..5ab97741 100644 --- a/pkg/config/env.go +++ b/pkg/config/env.go @@ -102,6 +102,8 @@ var Config = struct { PrometheusEnabled bool `env:"FLAGR_PROMETHEUS_ENABLED" envDefault:"false"` // PrometheusPath - set the path on which prometheus metrics are available to scrape PrometheusPath string `env:"FLAGR_PROMETHEUS_PATH" envDefault:"/metrics"` + // PrometheusNamespace - set an optional namespace to prepend to all stats + PrometheusNamespace string `env:"FLAGR_PROMETHEUS_NAMESPACE" envDefault:""` // PrometheusIncludeLatencyHistogram - set whether Prometheus should also export a histogram of request latencies (this increases cardinality significantly) PrometheusIncludeLatencyHistogram bool `env:"FLAGR_PROMETHEUS_INCLUDE_LATENCY_HISTOGRAM" envDefault:"false"` @@ -201,6 +203,9 @@ var Config = struct { // "HS256" and "RS256" supported JWTAuthSigningMethod string `env:"FLAGR_JWT_AUTH_SIGNING_METHOD" envDefault:"HS256"` + // Verify the JWT through OIDC flows + JWTAuthOIDCWellKnownURL string `env:"FLAGR_JWT_AUTH_OIDC_WELL_KNOWN_URL" envDefault:""` + // Identify users through headers HeaderAuthEnabled bool `env:"FLAGR_HEADER_AUTH_ENABLED" envDefault:"false"` HeaderAuthUserField string `env:"FLAGR_HEADER_AUTH_USER_FIELD" envDefault:"X-Email"` diff --git a/pkg/config/middleware.go b/pkg/config/middleware.go index 55da4ade..39cfeb8d 100644 --- a/pkg/config/middleware.go +++ b/pkg/config/middleware.go @@ -1,8 +1,16 @@ package config import ( + "bytes" + "crypto/rsa" "crypto/subtle" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" "fmt" + "io/ioutil" + "math/big" "net/http" "strconv" "strings" @@ -35,6 +43,8 @@ func ServerShutdown() { func SetupGlobalMiddleware(handler http.Handler) http.Handler { n := negroni.New() + n.Use(&headerCheckMiddleware{}) + if Config.MiddlewareGzipEnabled { n.Use(gzip.Gzip(gzip.DefaultCompression)) } @@ -126,7 +136,183 @@ func setupRecoveryMiddleware() *negroni.Recovery { return r } -/** +type headerCheckMiddleware struct{} + +func (hcm *headerCheckMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + // LOG A STATEMENT HERE + for key, values := range r.Header { + logrus.Infof("Header %s: %s", key, strings.Join(values, ", ")) + } + logrus.Infof("headers::") + logrus.Infof("headers: %#v\n", r.Header) + // logrus.Infof("requestttt") + // logrus.Infof(r) + + if r.Header.Get("Content-Length") != "" && r.Header.Get("Transfer-Encoding") != "" { + http.Error(w, "Invalid request with both Content-Length and Transfer-Encoding headers", http.StatusBadRequest) + return + } + next(w, r) +} + +type OidcConfiguation struct { + JwkURI string `json:"jwks_uri"` // This is the only field we care about for now. +} + +type OidcJwk struct { + JWKeyID string `json:"kid"` // JWTs will have a "kid" that will match one of these + SigningKeys []string `json:"x5c"` // We prefer the x5c since it's easier + Exponent string `json:"e"` // The least preferred way is recreating the key through exponent. + Modulus string `json:"n"` // The modulus +} + +type OidcJwksConfiguration struct { + Jwks []OidcJwk `json:"keys"` +} + +func discoverOidcJwk(tokenKeyID string) *OidcJwk { + // First we need to access the well known url to find out what the jwk_uri is. + resp, err := http.Get(Config.JWTAuthOIDCWellKnownURL) + if err != nil { + logrus.Errorln("Failed to contact OIDC well known URL") + return nil + } + + // Read in the oidc config information, and unmarshal it into our oidc config. + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + logrus.Errorln("Failed to read from OIDC well known URL") + return nil + } + + var oidcConfig OidcConfiguation + json.Unmarshal([]byte(body), &oidcConfig) + + if oidcConfig.JwkURI == "" { + logrus.Errorln("OIDC configuration didn't contain a valid jwks_uri") + return nil + } + + // Now we need to read in the jwks info. + oidcJwksResp, err := http.Get(oidcConfig.JwkURI) + if err != nil { + logrus.Errorln("Failed to contact the JWK URI") + return nil + } + + // Read in the jwks and unmarshall it. + defer oidcJwksResp.Body.Close() + body, err = ioutil.ReadAll(oidcJwksResp.Body) + if err != nil { + logrus.Errorln("Failed to read the JWKs body") + return nil + } + + var oidcJwks OidcJwksConfiguration + + json.Unmarshal([]byte(body), &oidcJwks) + + // Find the key that matches the one in the JWT + if oidcJwks.Jwks == nil { + logrus.Errorln("Missing keys in the jwk config") + return nil + } + + var numKeys = len(oidcJwks.Jwks) + if numKeys == 0 { + logrus.Errorln("No keys in the jwk config") + return nil + } + + matchingKey := -1 + for currentKey := 0; currentKey < numKeys; currentKey++ { + if oidcJwks.Jwks[currentKey].JWKeyID == tokenKeyID { + matchingKey = currentKey + break + } + } + + if matchingKey == -1 { + logrus.Errorln("No matching key found in the jwk config") + return nil + } + + return &oidcJwks.Jwks[matchingKey] +} + +func extractJWTSigningKeyFromX5C(oidcJwk *OidcJwk) (interface{}, error) { + correctKey := oidcJwk.SigningKeys[0] + + // Now we will take the first key and convert it into a cert + // that the library we user are familiar with. This cert requires + // a very specific format that is dependent on whitespace :/ + // There can only be 64 characters on a line or else it won't read it + // so we have to do that manually. Also the BEGIN and END lines need to + // be their own lines. + numCharsInKey := len(correctKey) + var numLines int = numCharsInKey / 64 + if numLines%64 > 0 { + numLines++ + } + + var jwtCert string = "-----BEGIN PUBLIC KEY-----\n" + for currentLine := 0; currentLine < numLines; currentLine++ { + startCharacter := currentLine * 64 + endCharacter := startCharacter + 64 + if startCharacter+64 > numCharsInKey { + endCharacter = numCharsInKey + } + + jwtCert += correctKey[startCharacter:endCharacter] + "\n" + } + jwtCert += "-----END PUBLIC KEY-----" + + return jwt.ParseRSAPublicKeyFromPEM([]byte(jwtCert)) +} + +func calculateJWTSigningKey(jwk *OidcJwk) (interface{}, error) { + if jwk.Modulus == "" || jwk.Exponent == "" { + return "", errors.New("Invalid Modulus or Exponent provided") + } + + // Decode the modulus and move it into a big int. + decN, err := base64.RawURLEncoding.DecodeString(jwk.Modulus) + if err != nil { + logrus.Errorf("Failed to decode modulus string.") + return "", err + } + n := big.NewInt(0) + n.SetBytes(decN) + + // Decode the exponent + eStr := jwk.Exponent + decE, err := base64.RawURLEncoding.DecodeString(eStr) + if err != nil { + logrus.Errorf("Failed to decode exponent string") + return "", err + } + + var eBytes []byte + if len(decE) < 8 { + eBytes = make([]byte, 8-len(decE), 8) + eBytes = append(eBytes, decE...) + } else { + eBytes = decE + } + eReader := bytes.NewReader(eBytes) + var e uint64 + err = binary.Read(eReader, binary.BigEndian, &e) + if err != nil { + logrus.Errorf("Failed to read exponent bytes") + return "", err + } + pKey := &rsa.PublicKey{N: n, E: int(e)} + return pKey, nil +} + +/* +* setupJWTAuthMiddleware setup an JWTMiddleware from the ENV config */ func setupJWTAuthMiddleware() *jwtAuth { @@ -149,14 +335,40 @@ func setupJWTAuthMiddleware() *jwtAuth { validationKey = []byte("") } + var validationKeyGetter = func(token *jwt.Token) (interface{}, error) { + return validationKey, errParsingKey + } + + if Config.JWTAuthOIDCWellKnownURL != "" { + validationKeyGetter = func(token *jwt.Token) (interface{}, error) { + + // If this is truly an OIDC token, it should have a "kid" header in the token. + var tokenKeyID = token.Header["kid"] + if tokenKeyID == nil { + return "", errors.New("Missing key id in the JWT") + } + + var jwk = discoverOidcJwk(tokenKeyID.(string)) + if jwk == nil { + return "", errors.New("Failed to find JWK for JWT") + } + + if len(jwk.SigningKeys) > 0 { + return extractJWTSigningKeyFromX5C(jwk) + } else if jwk.Exponent != "" && jwk.Modulus != "" { + return calculateJWTSigningKey(jwk) + } else { + return "", errors.New("JWK is invalid") + } + } + } + return &jwtAuth{ PrefixWhitelistPaths: Config.JWTAuthPrefixWhitelistPaths, ExactWhitelistPaths: Config.JWTAuthExactWhitelistPaths, JWTMiddleware: jwtmiddleware.New(jwtmiddleware.Options{ - ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { - return validationKey, errParsingKey - }, - SigningMethod: signingMethod, + ValidationKeyGetter: validationKeyGetter, + SigningMethod: signingMethod, Extractor: jwtmiddleware.FromFirst( func(r *http.Request) (string, error) { c, err := r.Cookie(Config.JWTAuthCookieTokenName) @@ -220,7 +432,8 @@ func (a *jwtAuth) ServeHTTP(w http.ResponseWriter, req *http.Request, next http. a.JWTMiddleware.HandlerWithNext(w, req, next) } -/** +/* +* setupBasicAuthMiddleware setup an BasicMiddleware from the ENV config */ func setupBasicAuthMiddleware() *basicAuth { diff --git a/pkg/handler/eval.go b/pkg/handler/eval.go index 79b4e7fb..a97cfabb 100644 --- a/pkg/handler/eval.go +++ b/pkg/handler/eval.go @@ -127,6 +127,7 @@ var LookupFlag = func(evalContext models.EvalContext) *entity.Flag { flagID := util.SafeUint(evalContext.FlagID) flagKey := util.SafeString(evalContext.FlagKey) f := cache.GetByFlagKeyOrID(flagID) + if f == nil { f = cache.GetByFlagKeyOrID(flagKey) } @@ -253,6 +254,21 @@ var logEvalResultToPrometheus = func(r *models.EvalResult) { util.SafeStringWithDefault(r.VariantKey, "null"), ).Inc() + if r.EvalContext == nil { + config.Global.Prometheus.EvalCounter.WithLabelValues( + util.SafeStringWithDefault("stripped", "null"), + util.SafeStringWithDefault(r.FlagID, "null"), + util.SafeStringWithDefault(r.VariantID, "null"), + util.SafeStringWithDefault(r.VariantKey, "null"), + ).Inc() + } else { + config.Global.Prometheus.EvalCounter.WithLabelValues( + util.SafeStringWithDefault(r.EvalContext.EntityType, "null"), + util.SafeStringWithDefault(r.FlagID, "null"), + util.SafeStringWithDefault(r.VariantID, "null"), + util.SafeStringWithDefault(r.VariantKey, "null"), + ).Inc() + } } var evalSegment = func( diff --git a/pkg/handler/eval_test.go b/pkg/handler/eval_test.go index 0ebb0c33..9bc07383 100644 --- a/pkg/handler/eval_test.go +++ b/pkg/handler/eval_test.go @@ -150,7 +150,7 @@ func TestEvalFlag(t *testing.T) { t.Run("test empty evalContext", func(t *testing.T) { defer gostub.StubFunc(&GetEvalCache, GenFixtureEvalCache()).Reset() - result := EvalFlag(models.EvalContext{FlagID: int64(100)}) + result := EvalFlag(models.EvalContext{FlagID: int64(100)}, false) assert.Zero(t, result.VariantID) assert.NotZero(t, result.FlagID) assert.NotEmpty(t, result.EvalContext.EntityID) @@ -164,7 +164,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagID: int64(100), - }) + }, false) assert.NotNil(t, result) assert.NotZero(t, result.VariantID) }) @@ -177,7 +177,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagKey: "flag_key_100", - }) + }, false) assert.NotNil(t, result) assert.NotZero(t, result.VariantID) }) @@ -190,7 +190,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagKey: "flag_key_100", - }) + }, false) assert.NotNil(t, result) assert.NotZero(t, result.VariantID) }) @@ -263,7 +263,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagID: int64(100), - }) + }, false) assert.NotNil(t, result) assert.Zero(t, result.VariantID) }) @@ -297,7 +297,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagID: int64(100), - }) + }, false) assert.NotNil(t, result) assert.Zero(t, result.VariantID) }) @@ -314,7 +314,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagID: int64(100), - }) + }, false) assert.NotNil(t, result) assert.Zero(t, result.VariantID) }) @@ -332,7 +332,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagID: int64(100), - }) + }, false) assert.NotNil(t, result) assert.NotZero(t, result.VariantID) assert.Equal(t, "entityType1", result.EvalContext.EntityType) @@ -349,7 +349,7 @@ func TestEvalFlag(t *testing.T) { EntityID: "entityID1", EntityType: "entityType1", FlagID: int64(100), - }) + }, false) assert.NotNil(t, result) assert.NotZero(t, result.VariantID) assert.NotEqual(t, "entityType1", result.EvalContext.EntityType) diff --git a/swagger/index.yaml b/swagger/index.yaml index ee26ab93..9c76527b 100644 --- a/swagger/index.yaml +++ b/swagger/index.yaml @@ -538,6 +538,9 @@ definitions: minItems: 1 enableDebug: type: boolean + stripEvalContext: + type: boolean + default: false flagIDs: description: flagIDs type: array diff --git a/swagger_gen/models/evaluation_batch_request.go b/swagger_gen/models/evaluation_batch_request.go index fd780b12..686d0195 100644 --- a/swagger_gen/models/evaluation_batch_request.go +++ b/swagger_gen/models/evaluation_batch_request.go @@ -37,6 +37,8 @@ type EvaluationBatchRequest struct { // Min Items: 1 FlagKeys []string `json:"flagKeys"` + // strip eval context + StripEvalContext *bool `json:"stripEvalContext,omitempty"` // flagTags. Either flagIDs, flagKeys or flagTags works. If pass in multiples, Flagr may return duplicate results. // Min Items: 1 FlagTags []string `json:"flagTags"` diff --git a/swagger_gen/restapi/embedded_spec.go b/swagger_gen/restapi/embedded_spec.go index d3daaf85..5e4cf8cb 100644 --- a/swagger_gen/restapi/embedded_spec.go +++ b/swagger_gen/restapi/embedded_spec.go @@ -1693,6 +1693,10 @@ func init() { "minLength": 1 } }, + "stripEvalContext": { + "type": "boolean", + "default": false + }, "flagTags": { "description": "flagTags. Either flagIDs, flagKeys or flagTags works. If pass in multiples, Flagr may return duplicate results.", "type": "array", @@ -3784,6 +3788,10 @@ func init() { "minLength": 1 } }, + "stripEvalContext": { + "type": "boolean", + "default": false + }, "flagTags": { "description": "flagTags. Either flagIDs, flagKeys or flagTags works. If pass in multiples, Flagr may return duplicate results.", "type": "array",