Skip to content

Commit

Permalink
adding support for proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
Mzack9999 committed Sep 6, 2024
1 parent 994665e commit 0b895e3
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 43 deletions.
124 changes: 92 additions & 32 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"math/rand"
"net"
"net/url"
"strings"
"sync"
"sync/atomic"
Expand All @@ -20,10 +21,9 @@ import (
iputil "github.com/projectdiscovery/utils/ip"
mapsutil "github.com/projectdiscovery/utils/maps"
sliceutil "github.com/projectdiscovery/utils/slice"
"golang.org/x/net/proxy"
)

var ()

var (
// DefaultMaxPerCNAMEFollows is the default number of times a CNAME can be followed within a trace
DefaultMaxPerCNAMEFollows = 32
Expand Down Expand Up @@ -53,6 +53,9 @@ type Client struct {
tcpClient *dns.Client
dohClient *doh.Client
dotClient *dns.Client
udpProxy proxy.Dialer
tcpProxy proxy.Dialer
dotProxy proxy.Dialer
knownHosts map[string][]string
}

Expand All @@ -76,39 +79,70 @@ func NewWithOptions(options Options) (*Client, error) {
options.MaxPerCNAMEFollows = DefaultMaxPerCNAMEFollows
}

httpClient := doh.NewHttpClientWithTimeout(options.Timeout)
httpClient := doh.NewHttpClient(
doh.WithTimeout(options.Timeout),
doh.WithInsecureSkipVerify(),
doh.WithProxy(options.Proxy), // no-op if empty
)

client := Client{
options: options,
resolvers: parsedBaseResolvers,
udpClient: &dns.Client{
Net: "",
Timeout: options.Timeout,
Dialer: &net.Dialer{
LocalAddr: options.GetLocalAddr(UDP),
},
},
tcpClient: &dns.Client{
Net: TCP.String(),
Timeout: options.Timeout,
Dialer: &net.Dialer{
LocalAddr: options.GetLocalAddr(TCP),
},
},
dohClient: doh.NewWithOptions(
doh.Options{
HttpClient: httpClient,
},
),
dotClient: &dns.Client{
Net: "tcp-tls",
Timeout: options.Timeout,
Dialer: &net.Dialer{
LocalAddr: options.GetLocalAddr(TCP),
},
udpDialer := &net.Dialer{LocalAddr: options.GetLocalAddr(UDP)}
tcpDialer := &net.Dialer{LocalAddr: options.GetLocalAddr(TCP)}
dotDialer := &net.Dialer{LocalAddr: options.GetLocalAddr(TCP)}

udpClient := &dns.Client{
Net: "",
Timeout: options.Timeout,
Dialer: udpDialer,
}
tcpClient := &dns.Client{
Net: TCP.String(),
Timeout: options.Timeout,
Dialer: tcpDialer,
}
dohClient := doh.NewWithOptions(
doh.Options{
HttpClient: httpClient,
},
)
dotClient := &dns.Client{
Net: "tcp-tls",
Timeout: options.Timeout,
Dialer: dotDialer,
}

client := Client{
options: options,
resolvers: parsedBaseResolvers,
udpClient: udpClient,
tcpClient: tcpClient,
dohClient: dohClient,
dotClient: dotClient,
knownHosts: knownHosts,
}

if options.Proxy != "" {
proxyURL, err := url.Parse(options.Proxy)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %v", err)
}
proxyDialer, err := proxy.FromURL(proxyURL, udpDialer)
if err != nil {
return nil, fmt.Errorf("error creating proxy dialer: %v", err)
}
tcpProxyDialer, err := proxy.FromURL(proxyURL, tcpDialer)
if err != nil {
return nil, fmt.Errorf("error creating proxy dialer: %v", err)
}
dotProxyDialer, err := proxy.FromURL(proxyURL, dotDialer)
if err != nil {
return nil, fmt.Errorf("error creating proxy dialer: %v", err)
}

client.udpProxy = proxyDialer
client.tcpProxy = tcpProxyDialer
client.dotProxy = dotProxyDialer
}

if options.ConnectionPoolThreads > 1 {
client.udpConnPool = mapsutil.SyncLockMap[string, *ConnPool]{
Map: make(mapsutil.Map[string, *ConnPool]),
Expand Down Expand Up @@ -170,12 +204,30 @@ func (c *Client) Do(msg *dns.Msg) (*dns.Msg, error) {
case *NetworkResolver:
switch r.Protocol {
case TCP:
resp, _, err = c.tcpClient.Exchange(msg, resolver.String())
if c.tcpProxy != nil {
var tcpConn *dns.Conn
tcpConn, err = c.dialWithProxy(c.tcpProxy, "tcp", resolver.String())
if err != nil {
break
}
defer tcpConn.Close()
resp, _, err = c.tcpClient.ExchangeWithConn(msg, tcpConn)
} else {
resp, _, err = c.tcpClient.Exchange(msg, resolver.String())
}
case UDP:
if c.options.ConnectionPoolThreads > 1 {
if udpConnPool, ok := c.udpConnPool.Get(resolver.String()); ok {
resp, _, err = udpConnPool.Exchange(context.TODO(), c.udpClient, msg)
}
} else if c.udpProxy != nil {
var udpConn *dns.Conn
udpConn, err = c.dialWithProxy(c.udpProxy, "udp", resolver.String())
if err != nil {
break
}
defer udpConn.Close()
resp, _, err = c.udpClient.ExchangeWithConn(msg, udpConn)
} else {
resp, _, err = c.udpClient.Exchange(msg, resolver.String())
}
Expand Down Expand Up @@ -204,6 +256,14 @@ func (c *Client) Do(msg *dns.Msg) (*dns.Msg, error) {
return resp, ErrRetriesExceeded
}

func (c *Client) dialWithProxy(dialer proxy.Dialer, network, addr string) (*dns.Conn, error) {
conn, err := dialer.Dial(network, addr)
if err != nil {
return nil, err
}
return &dns.Conn{Conn: conn}, nil
}

// Query sends a provided dns request and return enriched response
func (c *Client) Query(host string, requestType uint16) (*DNSData, error) {
return c.QueryMultiple(host, []uint16{requestType})
Expand Down
5 changes: 4 additions & 1 deletion doh/doh_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ func NewWithOptions(options Options) *Client {
}

func New() *Client {
httpClient := NewHttpClientWithTimeout(DefaultTimeout)
httpClient := NewHttpClient(
WithTimeout(DefaultTimeout),
WithInsecureSkipVerify(),
)
return NewWithOptions(Options{DefaultResolver: Cloudflare, HttpClient: httpClient})
}

Expand Down
62 changes: 53 additions & 9 deletions doh/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,61 @@ package doh
import (
"crypto/tls"
"net/http"
"net/url"
"time"
)

func NewHttpClientWithTimeout(timeout time.Duration) *http.Client {
httpClient := &http.Client{
Timeout: timeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
// ClientOption defines a function type for configuring an http.Client
type ClientOption func(*http.Client)

// WithTimeout sets the timeout for the http.Client
func WithTimeout(timeout time.Duration) ClientOption {
return func(c *http.Client) {
c.Timeout = timeout
}
}

// WithInsecureSkipVerify sets the InsecureSkipVerify option for the TLS config
func WithInsecureSkipVerify() ClientOption {
return func(c *http.Client) {
transport, ok := c.Transport.(*http.Transport)
if !ok {
transport = &http.Transport{}
c.Transport = transport
}
if transport.TLSClientConfig == nil {
transport.TLSClientConfig = &tls.Config{}
}
transport.TLSClientConfig.InsecureSkipVerify = true
}
}

// WithProxy sets a proxy for the http.Client
func WithProxy(proxyURL string) ClientOption {
return func(c *http.Client) {
if proxyURL == "" {
return
}
proxyURL, err := url.Parse(proxyURL)
if err != nil {
return
}

transport, ok := c.Transport.(*http.Transport)
if !ok {
transport = &http.Transport{}
c.Transport = transport
}

transport.Proxy = http.ProxyURL(proxyURL)
}
}

// NewHttpClient creates a new http.Client with the given options
func NewHttpClient(opts ...ClientOption) *http.Client {
client := &http.Client{}
for _, opt := range opts {
opt(client)
}
return httpClient
return client
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ require (
github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/net v0.23.0 // indirect
golang.org/x/net v0.23.0
golang.org/x/sys v0.18.0 // indirect
golang.org/x/tools v0.13.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
1 change: 1 addition & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Options struct {
LocalAddrPort uint16
ConnectionPoolThreads int
MaxPerCNAMEFollows int
Proxy string
}

// Returns a net.Addr of a UDP or TCP type depending on whats required
Expand Down

0 comments on commit 0b895e3

Please sign in to comment.