-
Notifications
You must be signed in to change notification settings - Fork 0
/
resolver.go
148 lines (123 loc) · 3.12 KB
/
resolver.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package resolv
import (
"fmt"
"log"
"net"
"sync"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
// Resolver is a DNS resolver. It holds no state of its own, and an
// application may create as many resolvers as it needs.
type Resolver struct{}
// NewResolver allocates a Resolver in its zero state and returns a
// pointer to it.
func NewResolver() *Resolver {
	return new(Resolver)
}
// Resolve issues the DNS request described by req and returns immediately.
//
// The exchange runs in its own goroutine. One Response is delivered on the
// returned channel, after which the channel is closed. The Response either
// carries the reply message and its round-trip time, or has Err set when:
//   - the exchange failed (timeouts are reported as a DNS error with
//     IsTimeout set, every other error is passed through unchanged);
//   - the reply RCODE is not success (IsNameError is set for NXDOMAIN);
//   - the reply is truncated;
//   - a panic was recovered while building or exchanging the message.
func (r *Resolver) Resolve(req *Request) <-chan *Response {
	// Buffered so the goroutine can hand over its single response and exit
	// even if the caller never reads from the channel.
	c := make(chan *Response, 1)

	// Launch a goroutine which resolves the request and reports the
	// result once it is ready.
	go func() {
		defer func() {
			// Turn an unexpected panic (for example from the DNS library)
			// into an error response instead of crashing the process.
			if err := recover(); err != nil {
				log.Println("resolv: unexpected error:", err)
				c <- &Response{Err: fmt.Errorf("resolve: %v", err)}
			}
			close(c)
		}()

		// Prepare the message. The question uses keyed fields so the
		// literal stays correct if dns.Question ever changes its layout.
		m := new(dns.Msg)
		m.Id = dns.Id()
		m.RecursionDesired = req.Recurse
		m.Question = []dns.Question{{
			Name:   req.Name,
			Qtype:  req.Type,
			Qclass: req.Class,
		}}

		cli := new(dns.Client)
		cli.Net = req.Mode

		// Issue the synchronous request.
		in, rtt, err := cli.Exchange(m, req.Addr)
		if err != nil {
			// Match on the net.Error interface rather than the concrete
			// *net.OpError so every error reporting Timeout() is
			// classified as a timeout.
			if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
				terr := NewDNSError("timeout", req)
				terr.IsTimeout = true
				c <- &Response{Err: terr}
				return
			}
			// Any other error is passed through as-is.
			c <- &Response{Err: err}
			return
		}

		// Check the RCODE from the message.
		if in.Rcode != dns.RcodeSuccess {
			rerr := NewDNSError(dns.RcodeToString[in.Rcode], req)
			if in.Rcode == dns.RcodeNameError {
				rerr.IsNameError = true
			}
			c <- &Response{Err: rerr}
			return
		}

		// Handle truncated messages.
		if in.MsgHdr.Truncated {
			c <- &Response{Err: NewDNSError("truncated", req)}
			return
		}

		resp := NewResponse(req)
		resp.Msg = in
		resp.Rtt = rtt
		c <- resp
	}()

	return c
}
// FanIn starts one Resolve call per entry in reqs and funnels all of the
// resulting responses into the single channel it returns. The per-request
// channels and ctx are handed to merge, which does the forwarding.
func (r *Resolver) FanIn(ctx context.Context, reqs ...*Request) <-chan *Response {
	pending := make([]<-chan *Response, 0, len(reqs))
	for _, req := range reqs {
		pending = append(pending, r.Resolve(req))
	}
	return r.merge(ctx, pending...)
}
// merge fans the responses from every channel in cs into a single channel.
//
// One forwarding goroutine is started per input channel. Each forwarder
// copies responses from its input to the returned channel until the input
// is closed or ctx is done. The returned channel is closed once every
// forwarder has exited.
//
// When ctx is done a forwarder makes one non-blocking attempt to deliver a
// Response whose Err is ctx.Err() and then exits. The attempt is
// best-effort: it only succeeds if a receiver is waiting on the returned
// channel at that moment. A blocking send here would leak the forwarder,
// and keep the returned channel open forever, whenever the consumer stops
// reading after cancellation.
func (r *Resolver) merge(ctx context.Context, cs ...<-chan *Response) <-chan *Response {
	var wg sync.WaitGroup
	out := make(chan *Response)

	// notifyCancelled offers the context error to a waiting receiver
	// without ever blocking the calling forwarder.
	notifyCancelled := func() {
		select {
		case out <- &Response{Err: ctx.Err()}:
		default:
		}
	}

	// output copies values from c to out until c is closed or ctx is done,
	// then calls wg.Done.
	output := func(c <-chan *Response) {
		defer wg.Done()
		for {
			select {
			case resp, ok := <-c:
				if !ok {
					return
				}
				select {
				case out <- resp:
				case <-ctx.Done():
					notifyCancelled()
					return
				}
			case <-ctx.Done():
				// Stop waiting on the input as soon as ctx is done.
				notifyCancelled()
				return
			}
		}
	}

	wg.Add(len(cs))
	for _, c := range cs {
		go output(c)
	}

	// Close out once all the forwarders are done. This goroutine must
	// start after the wg.Add call.
	go func() {
		wg.Wait()
		close(out)
	}()

	return out
}