diff --git a/README.md b/README.md index cf7a192..2d72239 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ func main() { // Query Types: dns.TypeA, dns.TypeNS, dns.TypeCNAME, dns.TypeSOA, dns.TypePTR, dns.TypeMX, dns.TypeANY // dns.TypeTXT, dns.TypeAAAA, dns.TypeSRV (from github.com/miekg/dns) + // retryabledns.ErrRetriesExceeded will be returned if a result isn't returned in max retries dnsResponses, err := dnsClient.Query(hostname, dns.TypeA) if err != nil { log.Fatal(err) @@ -69,4 +70,4 @@ func main() { Credits: - `https://github.com/lixiangzhong/dnsutil` -- `https://github.com/rs/dnstrace` \ No newline at end of file +- `https://github.com/rs/dnstrace` diff --git a/client.go b/client.go index 5d74157..6f90dac 100644 --- a/client.go +++ b/client.go @@ -22,6 +22,8 @@ import ( sliceutil "github.com/projectdiscovery/utils/slice" ) +var ErrRetriesExceeded = errors.New("could not resolve, max retries exceeded") + var internalRangeCheckerInstance *internalRangeChecker func init() { @@ -187,7 +189,7 @@ func (c *Client) Do(msg *dns.Msg) (*dns.Msg, error) { // In case we get a non empty answer stop retrying return resp, nil } - return resp, errors.New("could not resolve, max retries exceeded") + return resp, ErrRetriesExceeded } // Query sends a provided dns request and return enriched response @@ -322,8 +324,9 @@ func (c *Client) queryMultiple(host string, requestTypes []uint16, resolver Reso var ( resp *dns.Msg trResp chan *dns.Envelope + i int ) - for i := 0; i < c.options.MaxRetries; i++ { + for i = 0; i < c.options.MaxRetries; i++ { index := atomic.AddUint32(&c.serversIndex, 1) if !hasResolver { resolver = c.resolvers[index%uint32(len(c.resolvers))] @@ -421,6 +424,11 @@ func (c *Client) queryMultiple(host string, requestTypes []uint16, resolver Reso break } } + // Finished retry loop at limit, bail out + if i == c.options.MaxRetries { + err = ErrRetriesExceeded + break + } } return &dnsdata, err diff --git a/client_test.go b/client_test.go index 2bed4c3..7b47075 100644 --- a/client_test.go +++ b/client_test.go @@ -123,6 +123,30 @@ func TestQueryMultiple(t *testing.T) { require.NotZero(t, d.TTL) } +func TestRetries(t *testing.T) { + client, _ := New([]string{"127.0.0.1"}, 5) + + // Test that error is returned on max retries, should conn refused 5 times then err + _, err := client.QueryMultiple("scanme.sh", []uint16{dns.TypeA}) + require.True(t, err == ErrRetriesExceeded) + + msg := &dns.Msg{} + msg.Id = dns.Id() + msg.SetEdns0(4096, false) + msg.Question = make([]dns.Question, 1) + msg.RecursionDesired = true + question := dns.Question{ + Name: "scanme.sh", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + } + msg.Question[0] = question + + // Test with raw Do() interface as well + _, err = client.Do(msg) + require.True(t, err == ErrRetriesExceeded) +} + func TestTrace(t *testing.T) { client, _ := New([]string{"8.8.8.8:53", "1.1.1.1:53"}, 5)