diff --git a/client.go b/client.go index e98797f..658bc43 100644 --- a/client.go +++ b/client.go @@ -9,6 +9,7 @@ import ( "fmt" "math/rand" "net" + "net/url" "strings" "sync" "sync/atomic" @@ -20,10 +21,9 @@ import ( iputil "github.com/projectdiscovery/utils/ip" mapsutil "github.com/projectdiscovery/utils/maps" sliceutil "github.com/projectdiscovery/utils/slice" + "golang.org/x/net/proxy" ) -var () - var ( // DefaultMaxPerCNAMEFollows is the default number of times a CNAME can be followed within a trace DefaultMaxPerCNAMEFollows = 32 @@ -53,6 +53,9 @@ type Client struct { tcpClient *dns.Client dohClient *doh.Client dotClient *dns.Client + udpProxy proxy.Dialer + tcpProxy proxy.Dialer + dotProxy proxy.Dialer knownHosts map[string][]string } @@ -76,39 +79,70 @@ func NewWithOptions(options Options) (*Client, error) { options.MaxPerCNAMEFollows = DefaultMaxPerCNAMEFollows } - httpClient := doh.NewHttpClientWithTimeout(options.Timeout) + httpClient := doh.NewHttpClient( + doh.WithTimeout(options.Timeout), + doh.WithInsecureSkipVerify(), + doh.WithProxy(options.Proxy), // no-op if empty + ) - client := Client{ - options: options, - resolvers: parsedBaseResolvers, - udpClient: &dns.Client{ - Net: "", - Timeout: options.Timeout, - Dialer: &net.Dialer{ - LocalAddr: options.GetLocalAddr(UDP), - }, - }, - tcpClient: &dns.Client{ - Net: TCP.String(), - Timeout: options.Timeout, - Dialer: &net.Dialer{ - LocalAddr: options.GetLocalAddr(TCP), - }, - }, - dohClient: doh.NewWithOptions( - doh.Options{ - HttpClient: httpClient, - }, - ), - dotClient: &dns.Client{ - Net: "tcp-tls", - Timeout: options.Timeout, - Dialer: &net.Dialer{ - LocalAddr: options.GetLocalAddr(TCP), - }, + udpDialer := &net.Dialer{LocalAddr: options.GetLocalAddr(UDP)} + tcpDialer := &net.Dialer{LocalAddr: options.GetLocalAddr(TCP)} + dotDialer := &net.Dialer{LocalAddr: options.GetLocalAddr(TCP)} + + udpClient := &dns.Client{ + Net: "", + Timeout: options.Timeout, + Dialer: udpDialer, + } + tcpClient := &dns.Client{ + Net: TCP.String(), + Timeout: options.Timeout, + Dialer: tcpDialer, + } + dohClient := doh.NewWithOptions( + doh.Options{ + HttpClient: httpClient, }, + ) + dotClient := &dns.Client{ + Net: "tcp-tls", + Timeout: options.Timeout, + Dialer: dotDialer, + } + + client := Client{ + options: options, + resolvers: parsedBaseResolvers, + udpClient: udpClient, + tcpClient: tcpClient, + dohClient: dohClient, + dotClient: dotClient, knownHosts: knownHosts, } + + if options.Proxy != "" { + proxyURL, err := url.Parse(options.Proxy) + if err != nil { + return nil, fmt.Errorf("invalid proxy URL: %v", err) + } + proxyDialer, err := proxy.FromURL(proxyURL, udpDialer) + if err != nil { + return nil, fmt.Errorf("error creating proxy dialer: %v", err) + } + tcpProxyDialer, err := proxy.FromURL(proxyURL, tcpDialer) + if err != nil { + return nil, fmt.Errorf("error creating proxy dialer: %v", err) + } + dotProxyDialer, err := proxy.FromURL(proxyURL, dotDialer) + if err != nil { + return nil, fmt.Errorf("error creating proxy dialer: %v", err) + } + + client.udpProxy = proxyDialer + client.tcpProxy = tcpProxyDialer + client.dotProxy = dotProxyDialer + } + if options.ConnectionPoolThreads > 1 { client.udpConnPool = mapsutil.SyncLockMap[string, *ConnPool]{ Map: make(mapsutil.Map[string, *ConnPool]), @@ -170,12 +204,30 @@ func (c *Client) Do(msg *dns.Msg) (*dns.Msg, error) { case *NetworkResolver: switch r.Protocol { case TCP: - resp, _, err = c.tcpClient.Exchange(msg, resolver.String()) + if c.tcpProxy != nil { + var tcpConn *dns.Conn + tcpConn, err = c.dialWithProxy(c.tcpProxy, "tcp", resolver.String()) + if err != nil { + break + } + defer tcpConn.Close() + resp, _, err = c.tcpClient.ExchangeWithConn(msg, tcpConn) + } else { + resp, _, err = c.tcpClient.Exchange(msg, resolver.String()) + } case UDP: if c.options.ConnectionPoolThreads > 1 { if udpConnPool, ok := c.udpConnPool.Get(resolver.String()); ok { resp, _, err = udpConnPool.Exchange(context.TODO(), c.udpClient, msg) } + } else if c.udpProxy != nil { + var udpConn *dns.Conn + udpConn, err = c.dialWithProxy(c.udpProxy, "udp", resolver.String()) + if err != nil { + break + } + defer udpConn.Close() + resp, _, err = c.udpClient.ExchangeWithConn(msg, udpConn) } else { resp, _, err = c.udpClient.Exchange(msg, resolver.String()) } @@ -204,6 +256,14 @@ func (c *Client) Do(msg *dns.Msg) (*dns.Msg, error) { return resp, ErrRetriesExceeded } +func (c *Client) dialWithProxy(dialer proxy.Dialer, network, addr string) (*dns.Conn, error) { + conn, err := dialer.Dial(network, addr) + if err != nil { + return nil, err + } + return &dns.Conn{Conn: conn}, nil +} + // Query sends a provided dns request and return enriched response func (c *Client) Query(host string, requestType uint16) (*DNSData, error) { return c.QueryMultiple(host, []uint16{requestType}) diff --git a/doh/doh_client.go b/doh/doh_client.go index 58a61e2..8f47b12 100644 --- a/doh/doh_client.go +++ b/doh/doh_client.go @@ -21,7 +21,10 @@ func NewWithOptions(options Options) *Client { } func New() *Client { - httpClient := NewHttpClientWithTimeout(DefaultTimeout) + httpClient := NewHttpClient( + WithTimeout(DefaultTimeout), + WithInsecureSkipVerify(), + ) return NewWithOptions(Options{DefaultResolver: Cloudflare, HttpClient: httpClient}) } diff --git a/doh/util.go b/doh/util.go index e23b2c4..db742f6 100644 --- a/doh/util.go +++ b/doh/util.go @@ -3,17 +3,61 @@ package doh import ( "crypto/tls" "net/http" + "net/url" "time" ) -func NewHttpClientWithTimeout(timeout time.Duration) *http.Client { - httpClient := &http.Client{ - Timeout: timeout, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - }, +// ClientOption defines a function type for configuring an http.Client +type ClientOption func(*http.Client) + +// WithTimeout sets the timeout for the http.Client +func WithTimeout(timeout time.Duration) ClientOption { + return func(c *http.Client) { + c.Timeout = timeout + } +} + +// WithInsecureSkipVerify sets the InsecureSkipVerify option for the TLS config +func WithInsecureSkipVerify() ClientOption { + return func(c *http.Client) { + transport, ok := c.Transport.(*http.Transport) + if !ok { + transport = &http.Transport{} + c.Transport = transport + } + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{} + } + transport.TLSClientConfig.InsecureSkipVerify = true + } +} + +// WithProxy sets a proxy for the http.Client +func WithProxy(proxyURL string) ClientOption { + return func(c *http.Client) { + if proxyURL == "" { + return + } + proxyURL, err := url.Parse(proxyURL) + if err != nil { + return + } + + transport, ok := c.Transport.(*http.Transport) + if !ok { + transport = &http.Transport{} + c.Transport = transport + } + + transport.Proxy = http.ProxyURL(proxyURL) + } +} + +// NewHttpClient creates a new http.Client with the given options +func NewHttpClient(opts ...ClientOption) *http.Client { + client := &http.Client{} + for _, opt := range opts { + opt(client) } - return httpClient + return client } diff --git a/go.mod b/go.mod index 01f05f3..32a94c6 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,7 @@ require ( github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/mod v0.12.0 // indirect - golang.org/x/net v0.23.0 // indirect + golang.org/x/net v0.23.0 golang.org/x/sys v0.18.0 // indirect golang.org/x/tools v0.13.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/options.go b/options.go index 7e0967e..5a56fa8 100644 --- a/options.go +++ b/options.go @@ -21,6 +21,7 @@ type Options struct { LocalAddrPort uint16 ConnectionPoolThreads int MaxPerCNAMEFollows int + Proxy string } // Returns a net.Addr of a UDP or TCP type depending on whats required