Skip to content

Commit

Permalink
Merge pull request #12 from thibmeu/feature/maximum-ttl-cache
Browse files Browse the repository at this point in the history
Add a max TTL for cached entries
  • Loading branch information
marten-seemann authored Oct 12, 2021
2 parents fadcce3 + 909421f commit a5fdce2
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 15 deletions.
64 changes: 51 additions & 13 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package doh

import (
"context"
"math"
"net"
"strings"
"sync"
Expand All @@ -17,8 +18,9 @@ type Resolver struct {
url string

// RR cache
ipCache map[string]ipAddrEntry
txtCache map[string]txtEntry
ipCache map[string]ipAddrEntry
txtCache map[string]txtEntry
maxCacheTTL time.Duration
}

type ipAddrEntry struct {
Expand All @@ -31,16 +33,43 @@ type txtEntry struct {
expire time.Time
}

func NewResolver(url string) *Resolver {
type Option func(*Resolver) error

// Specifies the maximum time entries are valid in the cache
// A maxCacheTTL of zero is equivalent to `WithCacheDisabled`
func WithMaxCacheTTL(maxCacheTTL time.Duration) Option {
return func(tr *Resolver) error {
tr.maxCacheTTL = maxCacheTTL
return nil
}
}

func WithCacheDisabled() Option {
return func(tr *Resolver) error {
tr.maxCacheTTL = 0
return nil
}
}

func NewResolver(url string, opts ...Option) (*Resolver, error) {
if !strings.HasPrefix(url, "https:") {
url = "https://" + url
}

return &Resolver{
url: url,
ipCache: make(map[string]ipAddrEntry),
txtCache: make(map[string]txtEntry),
r := &Resolver{
url: url,
ipCache: make(map[string]ipAddrEntry),
txtCache: make(map[string]txtEntry),
maxCacheTTL: time.Duration(math.MaxUint32) * time.Second,
}

for _, o := range opts {
if err := o(r); err != nil {
return nil, err
}
}

return r, nil
}

var _ madns.BasicResolver = (*Resolver)(nil)
Expand Down Expand Up @@ -81,7 +110,8 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, domain string) (result []ne
}
}

r.cacheIPAddr(domain, result, ttl)
cacheTTL := minTTL(time.Duration(ttl)*time.Second, r.maxCacheTTL)
r.cacheIPAddr(domain, result, cacheTTL)
return result, nil
}

Expand All @@ -96,7 +126,8 @@ func (r *Resolver) LookupTXT(ctx context.Context, domain string) ([]string, erro
return nil, err
}

r.cacheTXT(domain, result, ttl)
cacheTTL := minTTL(time.Duration(ttl)*time.Second, r.maxCacheTTL)
r.cacheTXT(domain, result, cacheTTL)
return result, nil
}

Expand All @@ -118,7 +149,7 @@ func (r *Resolver) getCachedIPAddr(domain string) ([]net.IPAddr, bool) {
return entry.ips, true
}

func (r *Resolver) cacheIPAddr(domain string, ips []net.IPAddr, ttl uint32) {
func (r *Resolver) cacheIPAddr(domain string, ips []net.IPAddr, ttl time.Duration) {
if ttl == 0 {
return
}
Expand All @@ -127,7 +158,7 @@ func (r *Resolver) cacheIPAddr(domain string, ips []net.IPAddr, ttl uint32) {
defer r.mx.Unlock()

fqdn := dns.Fqdn(domain)
r.ipCache[fqdn] = ipAddrEntry{ips, time.Now().Add(time.Duration(ttl) * time.Second)}
r.ipCache[fqdn] = ipAddrEntry{ips, time.Now().Add(ttl)}
}

func (r *Resolver) getCachedTXT(domain string) ([]string, bool) {
Expand All @@ -148,7 +179,7 @@ func (r *Resolver) getCachedTXT(domain string) ([]string, bool) {
return entry.txt, true
}

func (r *Resolver) cacheTXT(domain string, txt []string, ttl uint32) {
func (r *Resolver) cacheTXT(domain string, txt []string, ttl time.Duration) {
if ttl == 0 {
return
}
Expand All @@ -157,5 +188,12 @@ func (r *Resolver) cacheTXT(domain string, txt []string, ttl uint32) {
defer r.mx.Unlock()

fqdn := dns.Fqdn(domain)
r.txtCache[fqdn] = txtEntry{txt, time.Now().Add(time.Duration(ttl) * time.Second)}
r.txtCache[fqdn] = txtEntry{txt, time.Now().Add(ttl)}
}

func minTTL(a, b time.Duration) time.Duration {
if a < b {
return a
}
return b
}
52 changes: 50 additions & 2 deletions resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/miekg/dns"
)
Expand Down Expand Up @@ -76,7 +77,10 @@ func TestLookupIPAddr(t *testing.T) {
})
defer resolver.Close()

r := NewResolver("")
r, err := NewResolver("https://cloudflare-dns.com/dns-query")
if err != nil {
t.Fatal("resolver cannot be initialised")
}
r.url = resolver.URL

ips, err := r.LookupIPAddr(context.Background(), domain)
Expand Down Expand Up @@ -120,7 +124,42 @@ func TestLookupTXT(t *testing.T) {
})
defer resolver.Close()

r := NewResolver("")
r, err := NewResolver("")
if err != nil {
t.Fatal("resolver cannot be initialised")
}
r.url = resolver.URL

txt, err := r.LookupTXT(context.Background(), domain)
if err != nil {
t.Fatal(err)
}
if len(txt) == 0 {
t.Fatal("got no TXT entries")
}

// check the cache
txt2, ok := r.getCachedTXT(domain)
if !ok {
t.Fatal("expected cache to be populated")
}
if !sameTXT(txt, txt2) {
t.Fatal("expected cache to contain the same txt entries")
}
}

func TestLookupCache(t *testing.T) {
domain := "example.com"
resolver := mockDoHResolver(t, map[uint16]*dns.Msg{
dns.TypeTXT: mockDNSAnswerTXT(dns.Fqdn(domain), []string{"dnslink=/ipns/example.com"}),
})
defer resolver.Close()

const cacheTTL = time.Second
r, err := NewResolver("", WithMaxCacheTTL(cacheTTL))
if err != nil {
t.Fatal("resolver cannot be initialised")
}
r.url = resolver.URL

txt, err := r.LookupTXT(context.Background(), domain)
Expand All @@ -140,6 +179,15 @@ func TestLookupTXT(t *testing.T) {
t.Fatal("expected cache to contain the same txt entries")
}

// check cache is empty after its maxTTL
time.Sleep(cacheTTL)
txt2, ok = r.getCachedTXT(domain)
if ok {
t.Fatal("expected cache to be empty")
}
if txt2 != nil {
t.Fatal("expected cache to not contain a txt entry")
}
}

func sameIPs(a, b []net.IPAddr) bool {
Expand Down

0 comments on commit a5fdce2

Please sign in to comment.