diff --git a/resolver.go b/resolver.go index 421104f..a0a38ff 100644 --- a/resolver.go +++ b/resolver.go @@ -2,6 +2,7 @@ package doh import ( "context" + "math" "net" "strings" "sync" @@ -17,8 +18,9 @@ type Resolver struct { url string // RR cache - ipCache map[string]ipAddrEntry - txtCache map[string]txtEntry + ipCache map[string]ipAddrEntry + txtCache map[string]txtEntry + maxCacheTTL time.Duration } type ipAddrEntry struct { @@ -31,16 +33,43 @@ type txtEntry struct { expire time.Time } -func NewResolver(url string) *Resolver { +type Option func(*Resolver) error + +// Specifies the maximum time entries are valid in the cache +// A maxCacheTTL of zero is equivalent to `WithCacheDisabled` +func WithMaxCacheTTL(maxCacheTTL time.Duration) Option { + return func(tr *Resolver) error { + tr.maxCacheTTL = maxCacheTTL + return nil + } +} + +func WithCacheDisabled() Option { + return func(tr *Resolver) error { + tr.maxCacheTTL = 0 + return nil + } +} + +func NewResolver(url string, opts ...Option) (*Resolver, error) { if !strings.HasPrefix(url, "https:") { url = "https://" + url } - return &Resolver{ - url: url, - ipCache: make(map[string]ipAddrEntry), - txtCache: make(map[string]txtEntry), + r := &Resolver{ + url: url, + ipCache: make(map[string]ipAddrEntry), + txtCache: make(map[string]txtEntry), + maxCacheTTL: time.Duration(math.MaxUint32) * time.Second, } + + for _, o := range opts { + if err := o(r); err != nil { + return nil, err + } + } + + return r, nil } var _ madns.BasicResolver = (*Resolver)(nil) @@ -81,7 +110,8 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, domain string) (result []ne } } - r.cacheIPAddr(domain, result, ttl) + cacheTTL := minTTL(time.Duration(ttl)*time.Second, r.maxCacheTTL) + r.cacheIPAddr(domain, result, cacheTTL) return result, nil } @@ -96,7 +126,8 @@ func (r *Resolver) LookupTXT(ctx context.Context, domain string) ([]string, erro return nil, err } - r.cacheTXT(domain, result, ttl) + cacheTTL := minTTL(time.Duration(ttl)*time.Second, r.maxCacheTTL) + r.cacheTXT(domain, result, cacheTTL) return result, nil } @@ -118,7 +149,7 @@ func (r *Resolver) getCachedIPAddr(domain string) ([]net.IPAddr, bool) { return entry.ips, true } -func (r *Resolver) cacheIPAddr(domain string, ips []net.IPAddr, ttl uint32) { +func (r *Resolver) cacheIPAddr(domain string, ips []net.IPAddr, ttl time.Duration) { if ttl == 0 { return } @@ -127,7 +158,7 @@ func (r *Resolver) cacheIPAddr(domain string, ips []net.IPAddr, ttl uint32) { defer r.mx.Unlock() fqdn := dns.Fqdn(domain) - r.ipCache[fqdn] = ipAddrEntry{ips, time.Now().Add(time.Duration(ttl) * time.Second)} + r.ipCache[fqdn] = ipAddrEntry{ips, time.Now().Add(ttl)} } func (r *Resolver) getCachedTXT(domain string) ([]string, bool) { @@ -148,7 +179,7 @@ func (r *Resolver) getCachedTXT(domain string) ([]string, bool) { return entry.txt, true } -func (r *Resolver) cacheTXT(domain string, txt []string, ttl uint32) { +func (r *Resolver) cacheTXT(domain string, txt []string, ttl time.Duration) { if ttl == 0 { return } @@ -157,5 +188,12 @@ func (r *Resolver) cacheTXT(domain string, txt []string, ttl uint32) { defer r.mx.Unlock() fqdn := dns.Fqdn(domain) - r.txtCache[fqdn] = txtEntry{txt, time.Now().Add(time.Duration(ttl) * time.Second)} + r.txtCache[fqdn] = txtEntry{txt, time.Now().Add(ttl)} +} + +func minTTL(a, b time.Duration) time.Duration { + if a < b { + return a + } + return b } diff --git a/resolver_test.go b/resolver_test.go index 93584f5..aa21c1c 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/miekg/dns" ) @@ -76,7 +77,10 @@ func TestLookupIPAddr(t *testing.T) { }) defer resolver.Close() - r := NewResolver("") + r, err := NewResolver("https://cloudflare-dns.com/dns-query") + if err != nil { + t.Fatal("resolver cannot be initialised") + } r.url = resolver.URL ips, err := r.LookupIPAddr(context.Background(), domain) @@ -120,7 +124,42 @@ func TestLookupTXT(t *testing.T) { }) defer resolver.Close() - r := NewResolver("") + r, err := NewResolver("") + if err != nil { + t.Fatal("resolver cannot be initialised") + } + r.url = resolver.URL + + txt, err := r.LookupTXT(context.Background(), domain) + if err != nil { + t.Fatal(err) + } + if len(txt) == 0 { + t.Fatal("got no TXT entries") + } + + // check the cache + txt2, ok := r.getCachedTXT(domain) + if !ok { + t.Fatal("expected cache to be populated") + } + if !sameTXT(txt, txt2) { + t.Fatal("expected cache to contain the same txt entries") + } +} + +func TestLookupCache(t *testing.T) { + domain := "example.com" + resolver := mockDoHResolver(t, map[uint16]*dns.Msg{ + dns.TypeTXT: mockDNSAnswerTXT(dns.Fqdn(domain), []string{"dnslink=/ipns/example.com"}), + }) + defer resolver.Close() + + const cacheTTL = time.Second + r, err := NewResolver("", WithMaxCacheTTL(cacheTTL)) + if err != nil { + t.Fatal("resolver cannot be initialised") + } r.url = resolver.URL txt, err := r.LookupTXT(context.Background(), domain) @@ -140,6 +179,15 @@ func TestLookupTXT(t *testing.T) { t.Fatal("expected cache to contain the same txt entries") } + // check cache is empty after its maxTTL + time.Sleep(cacheTTL) + txt2, ok = r.getCachedTXT(domain) + if ok { + t.Fatal("expected cache to be empty") + } + if txt2 != nil { + t.Fatal("expected cache to not contain a txt entry") + } } func sameIPs(a, b []net.IPAddr) bool {