Skip to content

Commit

Permalink
Allow FQDN lookup functions to take context (#199)
Browse files Browse the repository at this point in the history
This PR changes the FQDN lookup functions to take a `context.Context`,
thus allowing callers of these functions to pass a context with a
timeout to prevent indefinitely waiting for DNS resolution.
  • Loading branch information
ycombinator authored Feb 1, 2024
1 parent 35e55cd commit 489579d
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 20 deletions.
3 changes: 3 additions & 0 deletions .changelog/199.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
Allow FQDN lookup functions to take a context.
```
7 changes: 6 additions & 1 deletion providers/aix/host_aix_ppc64.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ package aix
import "C"

import (
"context"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -128,8 +129,12 @@ func (*host) Memory() (*types.HostMemoryInfo, error) {
return &mem, nil
}

func (h *host) FQDNWithContext(ctx context.Context) (string, error) {
return shared.FQDNWithContext(ctx)
}

func (h *host) FQDN() (string, error) {
return shared.FQDN()
return h.FQDNWithContext(context.Background())
}

func newHost() (*host, error) {
Expand Down
7 changes: 6 additions & 1 deletion providers/darwin/host_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package darwin

import (
"context"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -139,8 +140,12 @@ func (h *host) Memory() (*types.HostMemoryInfo, error) {
return &mem, nil
}

func (h *host) FQDNWithContext(ctx context.Context) (string, error) {
return shared.FQDNWithContext(ctx)
}

func (h *host) FQDN() (string, error) {
return shared.FQDN()
return h.FQDNWithContext(context.Background())
}

func (h *host) LoadAverage() (*types.LoadAverageInfo, error) {
Expand Down
12 changes: 10 additions & 2 deletions providers/linux/host_fqdn_integration_docker_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
package linux

import (
"context"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/require"
)
Expand All @@ -32,7 +34,10 @@ func TestHost_FQDN_set(t *testing.T) {
t.Fatal(fmt.Errorf("could not get host information: %w", err))
}

gotFQDN, err := host.FQDN()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

gotFQDN, err := host.FQDNWithContext(ctx)
require.NoError(t, err)
if gotFQDN != wantFQDN {
t.Errorf("got FQDN %q, want: %q", gotFQDN, wantFQDN)
Expand All @@ -45,7 +50,10 @@ func TestHost_FQDN_not_set(t *testing.T) {
t.Fatal(fmt.Errorf("could not get host information: %w", err))
}

gotFQDN, err := host.FQDN()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

gotFQDN, err := host.FQDNWithContext(ctx)
require.NoError(t, err)
hostname := host.Info().Hostname
if gotFQDN != hostname {
Expand Down
7 changes: 6 additions & 1 deletion providers/linux/host_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package linux

import (
"context"
"errors"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -73,8 +74,12 @@ func (h *host) Memory() (*types.HostMemoryInfo, error) {
return parseMemInfo(content)
}

func (h *host) FQDNWithContext(ctx context.Context) (string, error) {
return shared.FQDNWithContext(ctx)
}

func (h *host) FQDN() (string, error) {
return shared.FQDN()
return h.FQDNWithContext(context.Background())
}

// VMStat reports data from /proc/vmstat on linux.
Expand Down
20 changes: 13 additions & 7 deletions providers/shared/fqdn.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
package shared

import (
"context"
"fmt"
"net"
"os"
"strings"
)

// FQDN attempts to lookup the host's fully-qualified domain name and returns it.
// FQDNWithContext attempts to lookup the host's fully-qualified domain name and returns it.
// It does so using the following algorithm:
//
// 1. It gets the hostname from the OS. If this step fails, it returns an error.
Expand All @@ -40,18 +41,23 @@ import (
//
// 4. If steps 2 and 3 both fail, an empty string is returned as the FQDN along with
// errors from those steps.
func FQDN() (string, error) {
func FQDNWithContext(ctx context.Context) (string, error) {
hostname, err := os.Hostname()
if err != nil {
return "", fmt.Errorf("could not get hostname to look for FQDN: %w", err)
}

return fqdn(hostname)
return fqdn(ctx, hostname)
}

// FQDN just calls FQDNWithContext with a background context.
func FQDN() (string, error) {
return FQDNWithContext(context.Background())
}

func fqdn(hostname string) (string, error) {
func fqdn(ctx context.Context, hostname string) (string, error) {
var errs error
cname, err := net.LookupCNAME(hostname)
cname, err := net.DefaultResolver.LookupCNAME(ctx, hostname)
if err != nil {
errs = fmt.Errorf("could not get FQDN, all methods failed: failed looking up CNAME: %w",
err)
Expand All @@ -60,13 +66,13 @@ func fqdn(hostname string) (string, error) {
return strings.ToLower(strings.TrimSuffix(cname, ".")), nil
}

ips, err := net.LookupIP(hostname)
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", hostname)
if err != nil {
errs = fmt.Errorf("%s: failed looking up IP: %w", errs, err)
}

for _, ip := range ips {
names, err := net.LookupAddr(ip.String())
names, err := net.DefaultResolver.LookupAddr(ctx, ip.String())
if err != nil || len(names) == 0 {
continue
}
Expand Down
32 changes: 27 additions & 5 deletions providers/shared/fqdn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
package shared

import (
"context"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/require"
)
Expand All @@ -31,6 +33,7 @@ func TestFQDN(t *testing.T) {
osHostname string
expectedFQDN string
expectedErrRegex string
timeout time.Duration
}{
// This test case depends on network, particularly DNS,
// being available. If it starts to fail often enough
Expand All @@ -44,23 +47,37 @@ func TestFQDN(t *testing.T) {
"long_nonexistent_hostname": {
osHostname: "foo.bar.elastic.co",
expectedFQDN: "",
expectedErrRegex: makeErrorRegex("foo.bar.elastic.co"),
expectedErrRegex: makeErrorRegex("foo.bar.elastic.co", false),
},
"short_nonexistent_hostname": {
osHostname: "foobarbaz",
expectedFQDN: "",
expectedErrRegex: makeErrorRegex("foobarbaz"),
expectedErrRegex: makeErrorRegex("foobarbaz", false),
},
"long_mixed_case_hostname": {
osHostname: "eLaSTic.co",
expectedFQDN: "elastic.co",
expectedErrRegex: "",
},
"nonexistent_timeout": {
osHostname: "foobarbaz",
expectedFQDN: "",
expectedErrRegex: makeErrorRegex("foobarbaz", true),
timeout: 1 * time.Millisecond,
},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
actualFQDN, err := fqdn(test.osHostname)
timeout := test.timeout
if timeout == 0 {
timeout = 10 * time.Second
}

ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

actualFQDN, err := fqdn(ctx, test.osHostname)
require.Equal(t, test.expectedFQDN, actualFQDN)

if test.expectedErrRegex == "" {
Expand All @@ -72,11 +89,16 @@ func TestFQDN(t *testing.T) {
}
}

func makeErrorRegex(osHostname string) string {
func makeErrorRegex(osHostname string, withTimeout bool) string {
timeoutStr := ""
if withTimeout {
timeoutStr = ": i/o timeout"
}

return fmt.Sprintf(
"could not get FQDN, all methods failed: "+
"failed looking up CNAME: lookup %s.*: "+
"failed looking up IP: lookup %s.*",
"failed looking up IP: lookup %s"+timeoutStr,
osHostname,
osHostname,
)
Expand Down
7 changes: 6 additions & 1 deletion providers/windows/host_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package windows

import (
"context"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -84,7 +85,7 @@ func (h *host) Memory() (*types.HostMemoryInfo, error) {
}, nil
}

func (h *host) FQDN() (string, error) {
func (h *host) FQDNWithContext(_ context.Context) (string, error) {
fqdn, err := getComputerNameEx(stdwindows.ComputerNamePhysicalDnsFullyQualified)
if err != nil {
return "", fmt.Errorf("could not get windows FQDN: %s", err)
Expand All @@ -93,6 +94,10 @@ func (h *host) FQDN() (string, error) {
return strings.ToLower(strings.TrimSuffix(fqdn, ".")), nil
}

func (h *host) FQDN() (string, error) {
return h.FQDNWithContext(context.Background())
}

func newHost() (*host, error) {
h := &host{}
r := &reader{}
Expand Down
11 changes: 9 additions & 2 deletions types/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

package types

import "time"
import (
"context"
"time"
)

// Host is the interface that wraps methods for returning Host stats
// It may return partial information if the provider
Expand All @@ -27,7 +30,11 @@ type Host interface {
Info() HostInfo
Memory() (*HostMemoryInfo, error)

// FQDN returns the fully-qualified domain name of the host, lowercased.
// FQDNWithContext returns the fully-qualified domain name of the host, lowercased.
FQDNWithContext(ctx context.Context) (string, error)

// FQDN calls FQDNWithContext with a background context.
// Deprecated: Use FQDNWithContext instead.
FQDN() (string, error)
}

Expand Down

0 comments on commit 489579d

Please sign in to comment.