diff --git a/config/http_config.go b/config/http_config.go index 4dd88758..22902901 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -80,6 +80,9 @@ func (u URL) MarshalYAML() (interface{}, error) { return nil, nil } +// RoundTripWrapper is used in HTTPClientConfig to add custom functionality to the http request path +type RoundTripWrapper func(rt http.RoundTripper) http.RoundTripper + // HTTPClientConfig configures an HTTP client. type HTTPClientConfig struct { // The HTTP basic authentication credentials for the targets. @@ -92,6 +95,8 @@ type HTTPClientConfig struct { ProxyURL URL `yaml:"proxy_url,omitempty"` // TLSConfig to use to connect to the targets. TLSConfig TLSConfig `yaml:"tls_config,omitempty"` + // WrapBaseRoundTripper can be used to add custom functionality in the http request path + WrapBaseRoundTripper RoundTripWrapper `yaml:"-"` } // SetDirectory joins any relative file paths with dir. @@ -186,6 +191,10 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, disableKeepAli } } + if cfg.WrapBaseRoundTripper != nil { + rt = cfg.WrapBaseRoundTripper(rt) + } + // If a bearer token is provided, create a round tripper that will set the // Authorization header correctly on each request. if len(cfg.BearerToken) > 0 { diff --git a/config/http_config_test.go b/config/http_config_test.go index 8596e80b..47438484 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -56,6 +56,8 @@ const ( ExpectedBearer = "Bearer " + BearerToken ExpectedUsername = "arthurdent" ExpectedPassword = "42" + ExpectedHeader = "slartibartfast" + ExpectedHeaderValue = "fjords" ) var invalidHTTPClientConfigs = []struct { @@ -76,6 +78,13 @@ var invalidHTTPClientConfigs = []struct { }, } +type roundTripperFunc func(req *http.Request) (*http.Response, error) + +// RoundTrip implements the RoundTripper interface. +func (rt roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return rt(r) +} + func newTestServer(handler func(w http.ResponseWriter, r *http.Request)) (*httptest.Server, error) { testServer := httptest.NewUnstartedServer(http.HandlerFunc(handler)) @@ -104,6 +113,14 @@ func newTestServer(handler func(w http.ResponseWriter, r *http.Request)) (*httpt } func TestNewClientFromConfig(t *testing.T) { + wrapper := func(rt http.RoundTripper) http.RoundTripper { + return roundTripperFunc(func(req *http.Request) (*http.Response, error) { + req.Header.Add(ExpectedHeader, ExpectedHeaderValue) + + return rt.RoundTrip(req) + }) + } + var newClientValidConfig = []struct { clientConfig HTTPClientConfig handler func(w http.ResponseWriter, r *http.Request) @@ -195,6 +212,25 @@ func TestNewClientFromConfig(t *testing.T) { fmt.Fprint(w, ExpectedMessage) } }, + }, { + clientConfig: HTTPClientConfig{ + TLSConfig: TLSConfig{ + CAFile: "", + CertFile: ClientCertificatePath, + KeyFile: ClientKeyNoPassPath, + ServerName: "", + InsecureSkipVerify: true}, + WrapBaseRoundTripper: wrapper, + }, + handler: func(w http.ResponseWriter, r *http.Request) { + val := r.Header.Get(ExpectedHeader) + if val != ExpectedHeaderValue { + fmt.Fprintf(w, "The expected Header Value (%s) differs from the obtained Header Value (%s)", + ExpectedHeaderValue, val) + } else { + fmt.Fprint(w, ExpectedMessage) + } + }, }, }