Skip to content

Commit

Permalink
feat: support 'global.tproxy_port_protect' and 'global.so_mark_from_dae'
Browse files Browse the repository at this point in the history
  • Loading branch information
mzz2017 committed May 13, 2023
1 parent f81d3dc commit e9499d5
Show file tree
Hide file tree
Showing 22 changed files with 282 additions and 206 deletions.
36 changes: 35 additions & 1 deletion cmd/run.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package cmd

import (
"context"
"errors"
"fmt"
"github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/protocol/direct"
"math/rand"
"net"
"net/http"
Expand Down Expand Up @@ -247,6 +250,20 @@ func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*c
if !conf.Global.DisableWaitingNetwork && len(conf.Subscription) > 0 {
epo := 5 * time.Second
client := http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
cd := netproxy.ContextDialer{Dialer: direct.SymmetricDirect}
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae), addr)
if err != nil {
return nil, err
}
return &netproxy.FakeNetConn{
Conn: conn,
LAddr: nil,
RAddr: nil,
}, nil
},
},
Timeout: epo,
}
log.Infoln("Waiting for network...")
Expand Down Expand Up @@ -274,8 +291,25 @@ func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*c
if len(conf.Subscription) > 0 {
log.Infoln("Fetching subscriptions...")
}
client := http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
cd := netproxy.ContextDialer{Dialer: direct.SymmetricDirect}
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae), addr)
if err != nil {
return nil, err
}
return &netproxy.FakeNetConn{
Conn: conn,
LAddr: nil,
RAddr: nil,
}, nil
},
},
Timeout: 30 * time.Second,
}
for _, sub := range conf.Subscription {
tag, nodes, err := subscription.ResolveSubscription(log, filepath.Dir(cfgFile), string(sub))
tag, nodes, err := subscription.ResolveSubscription(log, &client, filepath.Dir(cfgFile), string(sub))
if err != nil {
log.Warnf(`failed to resolve subscription "%v": %v`, sub, err)
resolvingfailed = true
Expand Down
1 change: 1 addition & 0 deletions common/consts/ebpf.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ var (

const (
TproxyMark uint32 = 0x8000000
Recognize uint16 = 0x2017
LoopbackIfIndex = 1
)

Expand Down
30 changes: 13 additions & 17 deletions common/netutils/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/pkg/fastrand"
"github.com/mzz2017/softwind/pool"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)

Expand Down Expand Up @@ -91,8 +90,8 @@ func SystemDns() (dns netip.AddrPort, err error) {
return systemDns, nil
}

func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, tcp bool) (addrs []netip.Addr, err error) {
resources, err := resolve(ctx, d, dns, host, typ, tcp)
func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, network string) (addrs []netip.Addr, err error) {
resources, err := resolve(ctx, d, dns, host, typ, network)
if err != nil {
return nil, err
}
Expand All @@ -118,16 +117,14 @@ func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, ho
return addrs, nil
}

func ResolveNS(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, tcp bool) (records []string, err error) {
func ResolveNS(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, network string) (records []string, err error) {
typ := dnsmessage.TypeNS
resources, err := resolve(ctx, d, dns, host, typ, tcp)
resources, err := resolve(ctx, d, dns, host, typ, network)
if err != nil {
return nil, err
}
logrus.Println(host, len(resources))
for _, ans := range resources {
if ans.Header.Type != typ {
logrus.Println(host, ans.Header.Type)
continue
}
ns, ok := ans.Body.(*dnsmessage.NSResource)
Expand All @@ -139,7 +136,7 @@ func ResolveNS(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host
return records, nil
}

func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, tcp bool) (ans []dnsmessage.Resource, err error) {
func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, network string) (ans []dnsmessage.Resource, err error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
fqdn := host
Expand Down Expand Up @@ -202,7 +199,11 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
if err != nil {
return nil, err
}
if tcp {
magicNetwork, err := netproxy.ParseMagicNetwork(network)
if err != nil {
return nil, err
}
if magicNetwork.Network == "tcp" {
// Put DNS request length
buf := pool.Get(2 + len(b))
defer pool.Put(buf)
Expand All @@ -213,12 +214,7 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st

// Dial and write.
cd := &netproxy.ContextDialer{Dialer: d}
var c netproxy.Conn
if tcp {
c, err = cd.DialTcpContext(ctx, dns.String())
} else {
c, err = cd.DialUdpContext(ctx, dns.String())
}
c, err := cd.DialContext(ctx, network, dns.String())
if err != nil {
return nil, err
}
Expand All @@ -228,7 +224,7 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
return nil, err
}
ch := make(chan error, 2)
if !tcp {
if magicNetwork.Network == "udp" {
go func() {
// Resend every 3 seconds for UDP.
for {
Expand All @@ -249,7 +245,7 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
go func() {
buf := pool.Get(512)
defer pool.Put(buf)
if tcp {
if magicNetwork.Network == "tcp" {
// Read DNS response length
_, err := io.ReadFull(c, buf[:2])
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions common/netutils/ip46.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type Ip46 struct {
Ip6 netip.Addr
}

func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort, host string, tcp bool, race bool) (ipv46 *Ip46, err error) {
func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort, host string, network string, race bool) (ipv46 *Ip46, err error) {
var log *logrus.Logger
if _log := ctx.Value("logger"); _log != nil {
log = _log.(*logrus.Logger)
Expand All @@ -49,7 +49,7 @@ func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort
}
}()
var e error
addrs4, e = ResolveNetip(ctx4, dialer, dns, host, dnsmessage.TypeA, tcp)
addrs4, e = ResolveNetip(ctx4, dialer, dns, host, dnsmessage.TypeA, network)
if err != nil && !errors.Is(e, context.Canceled) {
err4 = e
return
Expand All @@ -67,7 +67,7 @@ func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort
}
}()
var e error
addrs6, e = ResolveNetip(ctx6, dialer, dns, host, dnsmessage.TypeAAAA, tcp)
addrs6, e = ResolveNetip(ctx6, dialer, dns, host, dnsmessage.TypeAAAA, network)
if err != nil && !errors.Is(e, context.Canceled) {
err6 = e
return
Expand Down
4 changes: 2 additions & 2 deletions common/subscription/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func ResolveFile(u *url.URL, configDir string) (b []byte, err error) {
return bytes.TrimSpace(b), err
}

func ResolveSubscription(log *logrus.Logger, configDir string, subscription string) (tag string, nodes []string, err error) {
func ResolveSubscription(log *logrus.Logger, client *http.Client, configDir string, subscription string) (tag string, nodes []string, err error) {
/// Get tag.
tag, subscription = common.GetTagFromLinkLikePlaintext(subscription)

Expand All @@ -160,7 +160,7 @@ func ResolveSubscription(log *logrus.Logger, configDir string, subscription stri
goto resolve
default:
}
resp, err = http.Get(subscription)
resp, err = client.Get(subscription)
if err != nil {
return "", nil, err
}
Expand Down
32 changes: 22 additions & 10 deletions common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"encoding/binary"
"encoding/hex"
"fmt"
"github.com/mzz2017/softwind/netproxy"
"net/netip"
"net/url"
"path/filepath"
Expand Down Expand Up @@ -221,25 +222,25 @@ func FuzzyDecode(to interface{}, val string) bool {
v := reflect.Indirect(reflect.ValueOf(to))
switch v.Kind() {
case reflect.Int:
i, err := strconv.ParseInt(val, 10, strconv.IntSize)
i, err := strconv.ParseInt(val, 0, strconv.IntSize)
if err != nil {
return false
}
v.SetInt(i)
case reflect.Int8:
i, err := strconv.ParseInt(val, 10, 8)
i, err := strconv.ParseInt(val, 0, 8)
if err != nil {
return false
}
v.SetInt(i)
case reflect.Int16:
i, err := strconv.ParseInt(val, 10, 16)
i, err := strconv.ParseInt(val, 0, 16)
if err != nil {
return false
}
v.SetInt(i)
case reflect.Int32:
i, err := strconv.ParseInt(val, 10, 32)
i, err := strconv.ParseInt(val, 0, 32)
if err != nil {
return false
}
Expand All @@ -253,38 +254,38 @@ func FuzzyDecode(to interface{}, val string) bool {
}
v.Set(reflect.ValueOf(duration))
default:
i, err := strconv.ParseInt(val, 10, 64)
i, err := strconv.ParseInt(val, 0, 64)
if err != nil {
return false
}
v.SetInt(i)
}
case reflect.Uint:
i, err := strconv.ParseUint(val, 10, strconv.IntSize)
i, err := strconv.ParseUint(val, 0, strconv.IntSize)
if err != nil {
return false
}
v.SetUint(i)
case reflect.Uint8:
i, err := strconv.ParseUint(val, 10, 8)
i, err := strconv.ParseUint(val, 0, 8)
if err != nil {
return false
}
v.SetUint(i)
case reflect.Uint16:
i, err := strconv.ParseUint(val, 10, 16)
i, err := strconv.ParseUint(val, 0, 16)
if err != nil {
return false
}
v.SetUint(i)
case reflect.Uint32:
i, err := strconv.ParseUint(val, 10, 32)
i, err := strconv.ParseUint(val, 0, 32)
if err != nil {
return false
}
v.SetUint(i)
case reflect.Uint64:
i, err := strconv.ParseUint(val, 10, 64)
i, err := strconv.ParseUint(val, 0, 64)
if err != nil {
return false
}
Expand Down Expand Up @@ -466,3 +467,14 @@ func IsValidHttpMethod(method string) bool {
return false
}
}

func MagicNetwork(network string, mark uint32) string {
if mark == 0 {
return network
} else {
return netproxy.MagicNetwork{
Network: network,
Mark: mark,
}.Encode()
}
}
13 changes: 9 additions & 4 deletions component/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ type Dns struct {
}

type NewOption struct {
Logger *logrus.Logger
LocationFinder *assets.LocationFinder
UpstreamReadyCallback func(dnsUpstream *Upstream) (err error)
Logger *logrus.Logger
LocationFinder *assets.LocationFinder
UpstreamReadyCallback func(dnsUpstream *Upstream) (err error)
UpstreamResolverNetwork string
}

func New(dns *config.Dns, opt *NewOption) (s *Dns, err error) {
Expand Down Expand Up @@ -62,7 +63,8 @@ func New(dns *config.Dns, opt *NewOption) (s *Dns, err error) {
return nil, fmt.Errorf("%w: %v", BadUpstreamFormatError, err)
}
r := &UpstreamResolver{
Raw: u,
Raw: u,
Network: opt.UpstreamResolverNetwork,
FinishInitCallback: func(i int) func(raw *url.URL, upstream *Upstream) (err error) {
return func(raw *url.URL, upstream *Upstream) (err error) {
if opt != nil && opt.UpstreamReadyCallback != nil {
Expand All @@ -77,6 +79,9 @@ func New(dns *config.Dns, opt *NewOption) (s *Dns, err error) {
return nil
}
}(i),
mu: sync.Mutex{},
upstream: nil,
init: false,
}
upstreamName2Id[tag] = uint8(len(s.upstream))
s.upstream = append(s.upstream, r)
Expand Down
9 changes: 5 additions & 4 deletions component/dns/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ type Upstream struct {
*netutils.Ip46
}

func NewUpstream(ctx context.Context, upstream *url.URL) (up *Upstream, err error) {
func NewUpstream(ctx context.Context, upstream *url.URL, resolverNetwork string) (up *Upstream, err error) {
scheme, hostname, port, err := ParseRawUpstream(upstream)
if err != nil {
return nil, fmt.Errorf("%w: %v", FormatError, err)
Expand All @@ -88,7 +88,7 @@ func NewUpstream(ctx context.Context, upstream *url.URL) (up *Upstream, err erro
}
}()

ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, hostname, false, false)
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, hostname, resolverNetwork, false)
if err != nil {
return nil, fmt.Errorf("failed to resolve dns_upstream: %w", err)
}
Expand Down Expand Up @@ -131,7 +131,8 @@ func (u *Upstream) String() string {
}

type UpstreamResolver struct {
Raw *url.URL
Raw *url.URL
Network string
// FinishInitCallback may be invoked again if err is not nil
FinishInitCallback func(raw *url.URL, upstream *Upstream) (err error)
mu sync.Mutex
Expand All @@ -154,7 +155,7 @@ func (u *UpstreamResolver) GetUpstream() (_ *Upstream, err error) {
}()
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
if u.upstream, err = NewUpstream(ctx, u.Raw); err != nil {
if u.upstream, err = NewUpstream(ctx, u.Raw, u.Network); err != nil {
return nil, fmt.Errorf("failed to init dns upstream: %w", err)
}
}
Expand Down
Loading

0 comments on commit e9499d5

Please sign in to comment.