Skip to content

Commit

Permalink
Added functional option to allow to customize DialContext() in HTTP c…
Browse files Browse the repository at this point in the history
…lient

Signed-off-by: Marco Pracucci <[email protected]>
  • Loading branch information
pracucci committed Apr 16, 2021
1 parent 4240322 commit 50bd6ae
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 7 deletions.
49 changes: 42 additions & 7 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ package config

import (
"bytes"
"context"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -194,15 +196,33 @@ func (a *BasicAuth) UnmarshalYAML(unmarshal func(interface{}) error) error {
return unmarshal((*plain)(a))
}

// DialContextFunc defines the signature of the DialContext() function implemented
// by net.Dialer.
type DialContextFunc func(context.Context, string, string) (net.Conn, error)

type httpClientOptions struct {
dialContextFunc DialContextFunc
}

// HTTPClientOption defines an option that can be applied to the HTTP client.
type HTTPClientOption func(options *httpClientOptions)

// WithDialContextFunc allows you to override func gets used for the actual dialing. The default is `net.Dialer.DialContext`.
func WithDialContextFunc(fn DialContextFunc) HTTPClientOption {
return func(opts *httpClientOptions) {
opts.dialContextFunc = fn
}
}

// NewClient returns a http.Client using the specified http.RoundTripper.
func newClient(rt http.RoundTripper) *http.Client {
return &http.Client{Transport: rt}
}

// NewClientFromConfig returns a new HTTP client configured for the
// given config.HTTPClientConfig. The name is used as go-conntrack metric label.
func NewClientFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool) (*http.Client, error) {
rt, err := NewRoundTripperFromConfig(cfg, name, disableKeepAlives, enableHTTP2)
func NewClientFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool, optFuncs ...HTTPClientOption) (*http.Client, error) {
rt, err := NewRoundTripperFromConfig(cfg, name, disableKeepAlives, enableHTTP2, optFuncs...)
if err != nil {
return nil, err
}
Expand All @@ -217,7 +237,25 @@ func NewClientFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, e

// NewRoundTripperFromConfig returns a new HTTP RoundTripper configured for the
// given config.HTTPClientConfig. The name is used as go-conntrack metric label.
func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool) (http.RoundTripper, error) {
func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool, optFuncs ...HTTPClientOption) (http.RoundTripper, error) {
opts := &httpClientOptions{}
for _, f := range optFuncs {
f(opts)
}

var dialContext func(ctx context.Context, network, addr string) (net.Conn, error)

if opts.dialContextFunc != nil {
dialContext = conntrack.NewDialContextFunc(
conntrack.DialWithDialContextFunc((func(context.Context, string, string) (net.Conn, error))(opts.dialContextFunc)),
conntrack.DialWithTracing(),
conntrack.DialWithName(name))
} else {
dialContext = conntrack.NewDialContextFunc(
conntrack.DialWithTracing(),
conntrack.DialWithName(name))
}

newRT := func(tlsConfig *tls.Config) (http.RoundTripper, error) {
// The only timeout we care about is the configured scrape timeout.
// It is applied on request. So we leave out any timings here.
Expand All @@ -233,10 +271,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, disableKeepAli
IdleConnTimeout: 5 * time.Minute,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DialContext: conntrack.NewDialContextFunc(
conntrack.DialWithTracing(),
conntrack.DialWithName(name),
),
DialContext: dialContext,
}
if enableHTTP2 {
// HTTP/2 support is golang has many problematic cornercases where
Expand Down
21 changes: 21 additions & 0 deletions config/http_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
package config

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -50,6 +53,7 @@ const (
MissingKey = "missing/secret.key"

ExpectedMessage = "I'm here to serve you!!!"
ExpectedError = "expected error"
AuthorizationCredentials = "theanswertothegreatquestionoflifetheuniverseandeverythingisfortytwo"
AuthorizationCredentialsFile = "testdata/bearer.token"
AuthorizationType = "APIKEY"
Expand Down Expand Up @@ -413,6 +417,23 @@ func TestNewClientFromInvalidConfig(t *testing.T) {
}
}

func TestCustomDialContextFunc(t *testing.T) {
dialFn := func(_ context.Context, _, _ string) (net.Conn, error) {
return nil, errors.New(ExpectedError)
}

cfg := HTTPClientConfig{}
client, err := NewClientFromConfig(cfg, "test", false, true, WithDialContextFunc(dialFn))
if err != nil {
t.Fatalf("Can't create a client from this config: %+v", cfg)
}

_, err = client.Get("http://localhost")
if err == nil || !strings.Contains(err.Error(), ExpectedError) {
t.Errorf("Expected error %q but got %q", ExpectedError, err)
}
}

func TestMissingBearerAuthFile(t *testing.T) {
cfg := HTTPClientConfig{
BearerTokenFile: MissingBearerTokenFile,
Expand Down

0 comments on commit 50bd6ae

Please sign in to comment.