From 442a2693ba2711f7222fc63a534ab28d9e24f96c Mon Sep 17 00:00:00 2001 From: Marcin Rataj Date: Tue, 17 Jan 2023 19:02:22 +0100 Subject: [PATCH] feat(routing/http/client): allow custom User-Agent (#31) Closes #17 --- routing/http/client/client.go | 22 ++++++++++++++ routing/http/client/client_test.go | 18 +++++++++++- routing/http/client/transport.go | 41 +++++++++++++++++++++++++++ routing/http/client/transport_test.go | 7 +++++ 4 files changed, 87 insertions(+), 1 deletion(-) diff --git a/routing/http/client/client.go b/routing/http/client/client.go index 225b39f0e..72ab5a102 100644 --- a/routing/http/client/client.go +++ b/routing/http/client/client.go @@ -40,6 +40,10 @@ type client struct { afterSignCallback func(req *types.WriteBitswapProviderRecord) } +// defaultUserAgent is used as a fallback to inform HTTP server which library +// version sent a request +var defaultUserAgent = moduleVersion() + var _ contentrouter.Client = &client{} type httpClient interface { @@ -60,6 +64,23 @@ func WithHTTPClient(h httpClient) option { } } +func WithUserAgent(ua string) option { + return func(c *client) { + if ua == "" { + return + } + httpClient, ok := c.httpClient.(*http.Client) + if !ok { + return + } + transport, ok := httpClient.Transport.(*ResponseBodyLimitedTransport) + if !ok { + return + } + transport.UserAgent = ua + } +} + func WithProviderInfo(peerID peer.ID, addrs []multiaddr.Multiaddr) option { return func(c *client) { c.peerID = peerID @@ -76,6 +97,7 @@ func New(baseURL string, opts ...option) (*client, error) { Transport: &ResponseBodyLimitedTransport{ RoundTripper: http.DefaultTransport, LimitBytes: 1 << 20, + UserAgent: defaultUserAgent, }, } client := &client{ diff --git a/routing/http/client/client_test.go b/routing/http/client/client_test.go index 22737b3a9..82e9e3b51 100644 --- a/routing/http/client/client_test.go +++ b/routing/http/client/client_test.go @@ -48,15 +48,17 @@ type testDeps struct { } func makeTestDeps(t *testing.T) testDeps { + const testUserAgent = "testUserAgent" peerID, addrs, identity := makeProviderAndIdentity() router := &mockContentRouter{} server := httptest.NewServer(server.Handler(router)) t.Cleanup(server.Close) serverAddr := "http://" + server.Listener.Addr().String() - c, err := New(serverAddr, WithProviderInfo(peerID, addrs), WithIdentity(identity)) + c, err := New(serverAddr, WithProviderInfo(peerID, addrs), WithIdentity(identity), WithUserAgent(testUserAgent)) if err != nil { panic(err) } + assertUserAgentOverride(t, c, testUserAgent) return testDeps{ router: router, server: server, @@ -66,6 +68,20 @@ func makeTestDeps(t *testing.T) testDeps { } } +func assertUserAgentOverride(t *testing.T, c *client, expected string) { + httpClient, ok := c.httpClient.(*http.Client) + if !ok { + t.Error("invalid c.httpClient") + } + transport, ok := httpClient.Transport.(*ResponseBodyLimitedTransport) + if !ok { + t.Error("invalid httpClient.Transport") + } + if transport.UserAgent != expected { + t.Error("invalid httpClient.Transport.UserAgent") + } +} + func makeCID() cid.Cid { buf := make([]byte, 63) _, err := rand.Read(buf) diff --git a/routing/http/client/transport.go b/routing/http/client/transport.go index ea9920463..357d25cb2 100644 --- a/routing/http/client/transport.go +++ b/routing/http/client/transport.go @@ -4,14 +4,21 @@ import ( "fmt" "io" "net/http" + "reflect" + "runtime/debug" + "strings" ) type ResponseBodyLimitedTransport struct { http.RoundTripper LimitBytes int64 + UserAgent string } func (r *ResponseBodyLimitedTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if r.UserAgent != "" { + req.Header.Set("User-Agent", r.UserAgent) + } resp, err := r.RoundTripper.RoundTrip(req) if resp != nil && resp.Body != nil { resp.Body = &limitReadCloser{ @@ -36,3 +43,37 @@ func (l *limitReadCloser) Read(p []byte) (int, error) { } return n, err } + +// ImportPath is the canonical import path that allows us to identify +// official client builds vs modified forks, and use that info in User-Agent header. +var ImportPath = importPath() + +// importPath returns the path that library consumers would have in go.mod +func importPath() string { + p := reflect.ValueOf(ResponseBodyLimitedTransport{}).Type().PkgPath() + // we have monorepo, so stripping the remainder + return strings.TrimSuffix(p, "/routing/http/client") +} + +// moduleVersion returns a useful user agent version string allowing us to +// identify requests coming from official releases of this module vs forks. +func moduleVersion() (ua string) { + ua = ImportPath + var module *debug.Module + if bi, ok := debug.ReadBuildInfo(); ok { + // If debug.ReadBuildInfo was successful, we can read Version by finding + // this client in the dependency list of the app that has it in go.mod + for _, dep := range bi.Deps { + if dep.Path == ImportPath { + module = dep + break + } + } + if module != nil { + ua += "@" + module.Version + return + } + ua += "@unknown" + } + return +} diff --git a/routing/http/client/transport_test.go b/routing/http/client/transport_test.go index 9e50a76ed..3db46a99f 100644 --- a/routing/http/client/transport_test.go +++ b/routing/http/client/transport_test.go @@ -75,3 +75,10 @@ func TestResponseBodyLimitedTransport(t *testing.T) { }) } } + +func TestUserAgentVersionString(t *testing.T) { + // forks will have to update below lines to pass test + assert.Equal(t, importPath(), "github.com/ipfs/go-libipfs") + // @unknown because we run in tests + assert.Equal(t, moduleVersion(), "github.com/ipfs/go-libipfs@unknown") +}