Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/update with original #560

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/api_docs/bundle.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1408,6 +1408,9 @@ definitions:
minItems: 1
enableDebug:
type: boolean
stripEvalContext:
type: boolean
default: false
flagIDs:
description: flagIDs
type: array
Expand Down
17 changes: 10 additions & 7 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,21 @@ func setupPrometheus() {
if Config.PrometheusEnabled {
Global.Prometheus.ScrapePath = Config.PrometheusPath
Global.Prometheus.EvalCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "flagr_eval_results",
Help: "A counter of eval results",
}, []string{"EntityType", "FlagID", "FlagKey", "VariantID", "VariantKey"})
Namespace: Config.PrometheusNamespace,
Name: "flagr_eval_results",
Help: "A counter of eval results",
}, []string{"EntityType", "FlagID", "VariantID", "VariantKey"})
Global.Prometheus.RequestCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "flagr_requests_total",
Help: "The total http requests received",
Namespace: Config.PrometheusNamespace,
Name: "flagr_requests_total",
Help: "The total http requests received",
}, []string{"status", "path", "method"})

if Config.PrometheusIncludeLatencyHistogram {
Global.Prometheus.RequestHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{
Name: "flagr_requests_buckets",
Help: "A histogram of latencies for requests received",
Namespace: Config.PrometheusNamespace,
Name: "flagr_requests_buckets",
Help: "A histogram of latencies for requests received",
}, []string{"status", "path", "method"})
}
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/config/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ var Config = struct {
PrometheusEnabled bool `env:"FLAGR_PROMETHEUS_ENABLED" envDefault:"false"`
// PrometheusPath - set the path on which prometheus metrics are available to scrape
PrometheusPath string `env:"FLAGR_PROMETHEUS_PATH" envDefault:"/metrics"`
// PrometheusNamespace - set an optional namespace to prepend to all stats
PrometheusNamespace string `env:"FLAGR_PROMETHEUS_NAMESPACE" envDefault:""`
// PrometheusIncludeLatencyHistogram - set whether Prometheus should also export a histogram of request latencies (this increases cardinality significantly)
PrometheusIncludeLatencyHistogram bool `env:"FLAGR_PROMETHEUS_INCLUDE_LATENCY_HISTOGRAM" envDefault:"false"`

Expand Down Expand Up @@ -201,6 +203,9 @@ var Config = struct {
// "HS256" and "RS256" supported
JWTAuthSigningMethod string `env:"FLAGR_JWT_AUTH_SIGNING_METHOD" envDefault:"HS256"`

// Verify the JWT through OIDC flows
JWTAuthOIDCWellKnownURL string `env:"FLAGR_JWT_AUTH_OIDC_WELL_KNOWN_URL" envDefault:""`

// Identify users through headers
HeaderAuthEnabled bool `env:"FLAGR_HEADER_AUTH_ENABLED" envDefault:"false"`
HeaderAuthUserField string `env:"FLAGR_HEADER_AUTH_USER_FIELD" envDefault:"X-Email"`
Expand Down
225 changes: 219 additions & 6 deletions pkg/config/middleware.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
package config

import (
"bytes"
"crypto/rsa"
"crypto/subtle"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"math/big"
"net/http"
"strconv"
"strings"
Expand Down Expand Up @@ -35,6 +43,8 @@ func ServerShutdown() {
func SetupGlobalMiddleware(handler http.Handler) http.Handler {
n := negroni.New()

n.Use(&headerCheckMiddleware{})

if Config.MiddlewareGzipEnabled {
n.Use(gzip.Gzip(gzip.DefaultCompression))
}
Expand Down Expand Up @@ -126,7 +136,183 @@ func setupRecoveryMiddleware() *negroni.Recovery {
return r
}

/**
type headerCheckMiddleware struct{}

func (hcm *headerCheckMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
// LOG A STATEMENT HERE
for key, values := range r.Header {
logrus.Infof("Header %s: %s", key, strings.Join(values, ", "))
}
logrus.Infof("headers::")
logrus.Infof("headers: %#v\n", r.Header)
// logrus.Infof("requestttt")
// logrus.Infof(r)

if r.Header.Get("Content-Length") != "" && r.Header.Get("Transfer-Encoding") != "" {
http.Error(w, "Invalid request with both Content-Length and Transfer-Encoding headers", http.StatusBadRequest)
return
}
next(w, r)
}

type OidcConfiguation struct {
JwkURI string `json:"jwks_uri"` // This is the only field we care about for now.
}

type OidcJwk struct {
JWKeyID string `json:"kid"` // JWTs will have a "kid" that will match one of these
SigningKeys []string `json:"x5c"` // We prefer the x5c since it's easier
Exponent string `json:"e"` // The least preferred way is recreating the key through exponent.
Modulus string `json:"n"` // The modulus
}

type OidcJwksConfiguration struct {
Jwks []OidcJwk `json:"keys"`
}

func discoverOidcJwk(tokenKeyID string) *OidcJwk {
// First we need to access the well known url to find out what the jwk_uri is.
resp, err := http.Get(Config.JWTAuthOIDCWellKnownURL)
if err != nil {
logrus.Errorln("Failed to contact OIDC well known URL")
return nil
}

// Read in the oidc config information, and unmarshal it into our oidc config.
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
logrus.Errorln("Failed to read from OIDC well known URL")
return nil
}

var oidcConfig OidcConfiguation
json.Unmarshal([]byte(body), &oidcConfig)

if oidcConfig.JwkURI == "" {
logrus.Errorln("OIDC configuration didn't contain a valid jwks_uri")
return nil
}

// Now we need to read in the jwks info.
oidcJwksResp, err := http.Get(oidcConfig.JwkURI)
if err != nil {
logrus.Errorln("Failed to contact the JWK URI")
return nil
}

// Read in the jwks and unmarshall it.
defer oidcJwksResp.Body.Close()
body, err = ioutil.ReadAll(oidcJwksResp.Body)
if err != nil {
logrus.Errorln("Failed to read the JWKs body")
return nil
}

var oidcJwks OidcJwksConfiguration

json.Unmarshal([]byte(body), &oidcJwks)

// Find the key that matches the one in the JWT
if oidcJwks.Jwks == nil {
logrus.Errorln("Missing keys in the jwk config")
return nil
}

var numKeys = len(oidcJwks.Jwks)
if numKeys == 0 {
logrus.Errorln("No keys in the jwk config")
return nil
}

matchingKey := -1
for currentKey := 0; currentKey < numKeys; currentKey++ {
if oidcJwks.Jwks[currentKey].JWKeyID == tokenKeyID {
matchingKey = currentKey
break
}
}

if matchingKey == -1 {
logrus.Errorln("No matching key found in the jwk config")
return nil
}

return &oidcJwks.Jwks[matchingKey]
}

func extractJWTSigningKeyFromX5C(oidcJwk *OidcJwk) (interface{}, error) {
correctKey := oidcJwk.SigningKeys[0]

// Now we will take the first key and convert it into a cert
// that the library we user are familiar with. This cert requires
// a very specific format that is dependent on whitespace :/
// There can only be 64 characters on a line or else it won't read it
// so we have to do that manually. Also the BEGIN and END lines need to
// be their own lines.
numCharsInKey := len(correctKey)
var numLines int = numCharsInKey / 64
if numLines%64 > 0 {
numLines++
}

var jwtCert string = "-----BEGIN PUBLIC KEY-----\n"
for currentLine := 0; currentLine < numLines; currentLine++ {
startCharacter := currentLine * 64
endCharacter := startCharacter + 64
if startCharacter+64 > numCharsInKey {
endCharacter = numCharsInKey
}

jwtCert += correctKey[startCharacter:endCharacter] + "\n"
}
jwtCert += "-----END PUBLIC KEY-----"

return jwt.ParseRSAPublicKeyFromPEM([]byte(jwtCert))
}

func calculateJWTSigningKey(jwk *OidcJwk) (interface{}, error) {
if jwk.Modulus == "" || jwk.Exponent == "" {
return "", errors.New("Invalid Modulus or Exponent provided")
}

// Decode the modulus and move it into a big int.
decN, err := base64.RawURLEncoding.DecodeString(jwk.Modulus)
if err != nil {
logrus.Errorf("Failed to decode modulus string.")
return "", err
}
n := big.NewInt(0)
n.SetBytes(decN)

// Decode the exponent
eStr := jwk.Exponent
decE, err := base64.RawURLEncoding.DecodeString(eStr)
if err != nil {
logrus.Errorf("Failed to decode exponent string")
return "", err
}

var eBytes []byte
if len(decE) < 8 {
eBytes = make([]byte, 8-len(decE), 8)
eBytes = append(eBytes, decE...)
} else {
eBytes = decE
}
eReader := bytes.NewReader(eBytes)
var e uint64
err = binary.Read(eReader, binary.BigEndian, &e)
if err != nil {
logrus.Errorf("Failed to read exponent bytes")
return "", err
}
pKey := &rsa.PublicKey{N: n, E: int(e)}
return pKey, nil
}

/*
*
setupJWTAuthMiddleware setup an JWTMiddleware from the ENV config
*/
func setupJWTAuthMiddleware() *jwtAuth {
Expand All @@ -149,14 +335,40 @@ func setupJWTAuthMiddleware() *jwtAuth {
validationKey = []byte("")
}

var validationKeyGetter = func(token *jwt.Token) (interface{}, error) {
return validationKey, errParsingKey
}

if Config.JWTAuthOIDCWellKnownURL != "" {
validationKeyGetter = func(token *jwt.Token) (interface{}, error) {

// If this is truly an OIDC token, it should have a "kid" header in the token.
var tokenKeyID = token.Header["kid"]
if tokenKeyID == nil {
return "", errors.New("Missing key id in the JWT")
}

var jwk = discoverOidcJwk(tokenKeyID.(string))
if jwk == nil {
return "", errors.New("Failed to find JWK for JWT")
}

if len(jwk.SigningKeys) > 0 {
return extractJWTSigningKeyFromX5C(jwk)
} else if jwk.Exponent != "" && jwk.Modulus != "" {
return calculateJWTSigningKey(jwk)
} else {
return "", errors.New("JWK is invalid")
}
}
}

return &jwtAuth{
PrefixWhitelistPaths: Config.JWTAuthPrefixWhitelistPaths,
ExactWhitelistPaths: Config.JWTAuthExactWhitelistPaths,
JWTMiddleware: jwtmiddleware.New(jwtmiddleware.Options{
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
return validationKey, errParsingKey
},
SigningMethod: signingMethod,
ValidationKeyGetter: validationKeyGetter,
SigningMethod: signingMethod,
Extractor: jwtmiddleware.FromFirst(
func(r *http.Request) (string, error) {
c, err := r.Cookie(Config.JWTAuthCookieTokenName)
Expand Down Expand Up @@ -220,7 +432,8 @@ func (a *jwtAuth) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.
a.JWTMiddleware.HandlerWithNext(w, req, next)
}

/**
/*
*
setupBasicAuthMiddleware setup an BasicMiddleware from the ENV config
*/
func setupBasicAuthMiddleware() *basicAuth {
Expand Down
16 changes: 16 additions & 0 deletions pkg/handler/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ var LookupFlag = func(evalContext models.EvalContext) *entity.Flag {
flagID := util.SafeUint(evalContext.FlagID)
flagKey := util.SafeString(evalContext.FlagKey)
f := cache.GetByFlagKeyOrID(flagID)

if f == nil {
f = cache.GetByFlagKeyOrID(flagKey)
}
Expand Down Expand Up @@ -253,6 +254,21 @@ var logEvalResultToPrometheus = func(r *models.EvalResult) {
util.SafeStringWithDefault(r.VariantKey, "null"),
).Inc()

if r.EvalContext == nil {
config.Global.Prometheus.EvalCounter.WithLabelValues(
util.SafeStringWithDefault("stripped", "null"),
util.SafeStringWithDefault(r.FlagID, "null"),
util.SafeStringWithDefault(r.VariantID, "null"),
util.SafeStringWithDefault(r.VariantKey, "null"),
).Inc()
} else {
config.Global.Prometheus.EvalCounter.WithLabelValues(
util.SafeStringWithDefault(r.EvalContext.EntityType, "null"),
util.SafeStringWithDefault(r.FlagID, "null"),
util.SafeStringWithDefault(r.VariantID, "null"),
util.SafeStringWithDefault(r.VariantKey, "null"),
).Inc()
}
}

var evalSegment = func(
Expand Down
Loading