From ec6f0d28afd962901fc9a22e2b021f30f089098e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 23 Oct 2024 13:44:08 +0800 Subject: [PATCH] Implement dns-hijack --- cmd/sing-box/cmd_tools.go | 6 +- go.mod | 8 +-- go.sum | 12 ++-- inbound/tun.go | 6 +- outbound/dns.go | 25 +++++--- route/dns.go | 91 +++++++++++++++++++++++++++ route/route.go | 125 +++++++++++++++++++++++++++----------- route/rule/rule_action.go | 72 ++++++++-------------- 8 files changed, 239 insertions(+), 106 deletions(-) create mode 100644 route/dns.go diff --git a/cmd/sing-box/cmd_tools.go b/cmd/sing-box/cmd_tools.go index c45f585576..804863be27 100644 --- a/cmd/sing-box/cmd_tools.go +++ b/cmd/sing-box/cmd_tools.go @@ -1,9 +1,11 @@ package main import ( + "errors" "github.com/sagernet/sing-box" E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" + "os" "github.com/spf13/cobra" ) @@ -23,7 +25,9 @@ func init() { func createPreStartedClient() (*box.Box, error) { options, err := readConfigAndMerge() if err != nil { - return nil, err + if !(errors.Is(err, os.ErrNotExist) && len(configDirectories) == 0 && len(configPaths) == 1) || configPaths[0] != "config.json" { + return nil, err + } } instance, err := box.New(box.Options{Options: options}) if err != nil { diff --git a/go.mod b/go.mod index 1a7042f497..40ad0e3c6b 100644 --- a/go.mod +++ b/go.mod @@ -27,14 +27,14 @@ require ( github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3 github.com/sagernet/quic-go v0.48.0-beta.1 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 - github.com/sagernet/sing v0.5.0-rc.4.0.20241022031908-cd17884118cb - github.com/sagernet/sing-dns v0.3.0-rc.2.0.20241021154031-a59e0fbba3ce + github.com/sagernet/sing v0.5.0-rc.4.0.20241023053048-94f058276959 + github.com/sagernet/sing-dns v0.3.0-rc.2.0.20241023053951-feb6d5403f2a github.com/sagernet/sing-mux v0.2.1-0.20241020175909-fe6153f7a9ec github.com/sagernet/sing-quic v0.3.0-rc.1 github.com/sagernet/sing-shadowsocks v0.2.7 github.com/sagernet/sing-shadowsocks2 v0.2.0 github.com/sagernet/sing-shadowtls v0.1.4 - github.com/sagernet/sing-tun v0.4.0-rc.4.0.20241022132441-8ae8c915af9e + github.com/sagernet/sing-tun v0.4.0-rc.4.0.20241023054150-3b5b396d06f7 github.com/sagernet/sing-vmess v0.1.12 github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 github.com/sagernet/utls v1.6.7 @@ -55,8 +55,6 @@ require ( howett.net/plist v1.0.1 ) -//replace github.com/sagernet/sing => ../sing - require ( github.com/ajg/form v1.5.1 // indirect github.com/andybalholm/brotli v1.0.6 // indirect diff --git a/go.sum b/go.sum index e048e4605d..78be296340 100644 --- a/go.sum +++ b/go.sum @@ -115,10 +115,10 @@ github.com/sagernet/quic-go v0.48.0-beta.1/go.mod h1:1WgdDIVD1Gybp40JTWketeSfKA/ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byLGkEnIYp6grlXfo1QYUfiYFGjewIdc= github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= -github.com/sagernet/sing v0.5.0-rc.4.0.20241022031908-cd17884118cb h1:3IhGq2UmcbQfAcuqyE8RYKFapqEEa3eItS/MrZr+5l8= -github.com/sagernet/sing v0.5.0-rc.4.0.20241022031908-cd17884118cb/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= -github.com/sagernet/sing-dns v0.3.0-rc.2.0.20241021154031-a59e0fbba3ce h1:OfpxE5qnXMyU/9LtNgX4M7bGP11lJx4s+KZ3Sijb0HE= -github.com/sagernet/sing-dns v0.3.0-rc.2.0.20241021154031-a59e0fbba3ce/go.mod h1:TqLIelI+FAbVEdiTRolhGLOwvhVjY7oT+wezlOJUQ7M= +github.com/sagernet/sing v0.5.0-rc.4.0.20241023053048-94f058276959 h1:8BzTt5cU8h6HK4CcRq1UQHKsgUi942GjO0by/ntFZIs= +github.com/sagernet/sing v0.5.0-rc.4.0.20241023053048-94f058276959/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing-dns v0.3.0-rc.2.0.20241023053951-feb6d5403f2a h1:jpAlbmZxc1LymZrmJacsvHI57Wito5xy8qASZJMWoOQ= +github.com/sagernet/sing-dns v0.3.0-rc.2.0.20241023053951-feb6d5403f2a/go.mod h1:TqLIelI+FAbVEdiTRolhGLOwvhVjY7oT+wezlOJUQ7M= github.com/sagernet/sing-mux v0.2.1-0.20241020175909-fe6153f7a9ec h1:6Fd/VsEsw9qIjaGi1IBTZSb4b4v5JYtNcoiBtGsQC48= github.com/sagernet/sing-mux v0.2.1-0.20241020175909-fe6153f7a9ec/go.mod h1:RSwqqHwbtTOX3vs6ms8vMtBGH/0ZNyLm/uwt6TlmR84= github.com/sagernet/sing-quic v0.3.0-rc.1 h1:SlzL1yfEAKJyRduub8vzOVtbyTLAX7RZEEBZxO5utts= @@ -129,8 +129,8 @@ github.com/sagernet/sing-shadowsocks2 v0.2.0 h1:wpZNs6wKnR7mh1wV9OHwOyUr21VkS3wK github.com/sagernet/sing-shadowsocks2 v0.2.0/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ= github.com/sagernet/sing-shadowtls v0.1.4 h1:aTgBSJEgnumzFenPvc+kbD9/W0PywzWevnVpEx6Tw3k= github.com/sagernet/sing-shadowtls v0.1.4/go.mod h1:F8NBgsY5YN2beQavdgdm1DPlhaKQlaL6lpDdcBglGK4= -github.com/sagernet/sing-tun v0.4.0-rc.4.0.20241022132441-8ae8c915af9e h1:dKvwQKyNc6/tfwsCU3vlvZo1zwGC9ztsJ3qiQghhBpA= -github.com/sagernet/sing-tun v0.4.0-rc.4.0.20241022132441-8ae8c915af9e/go.mod h1:ZDv85YIANyV7ZTuHx9Vn3dBiEBjuOnebVrL7ccXC9CM= +github.com/sagernet/sing-tun v0.4.0-rc.4.0.20241023054150-3b5b396d06f7 h1:wWfRBSP8v0Gc9yUeMgoKCiG+LIs/+bYLWWwVVYSbGFI= +github.com/sagernet/sing-tun v0.4.0-rc.4.0.20241023054150-3b5b396d06f7/go.mod h1:2v1L3BQKzoOpGuKMwC6pcs/5/Xb5PBqzqL6Lq88IoS8= github.com/sagernet/sing-vmess v0.1.12 h1:2gFD8JJb+eTFMoa8FIVMnknEi+vCSfaiTXTfEYAYAPg= github.com/sagernet/sing-vmess v0.1.12/go.mod h1:luTSsfyBGAc9VhtCqwjR+dt1QgqBhuYBCONB/POhF8I= github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 h1:DImB4lELfQhplLTxeq2z31Fpv8CQqqrUwTbrIRumZqQ= diff --git a/inbound/tun.go b/inbound/tun.go index 588d4b42e3..54fdb66d49 100644 --- a/inbound/tun.go +++ b/inbound/tun.go @@ -35,8 +35,9 @@ type TUN struct { router adapter.Router logger log.ContextLogger // Deprecated - inboundOptions option.InboundOptions - tunOptions tun.Options + inboundOptions option.InboundOptions + tunOptions tun.Options + // Deprecated endpointIndependentNat bool udpTimeout time.Duration stack string @@ -316,7 +317,6 @@ func (t *TUN) Start() error { Context: t.ctx, Tun: tunInterface, TunOptions: t.tunOptions, - EndpointIndependentNat: t.endpointIndependentNat, UDPTimeout: t.udpTimeout, Handler: t, Logger: t.logger, diff --git a/outbound/dns.go b/outbound/dns.go index 08661a99aa..d9c92f19ec 100644 --- a/outbound/dns.go +++ b/outbound/dns.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "net" "os" + "time" "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" @@ -50,14 +51,15 @@ func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter metadata.Destination = M.Socksaddr{} defer conn.Close() for { - err := d.handleConnection(ctx, conn, metadata) + conn.SetReadDeadline(time.Now().Add(C.DNSTimeout)) + err := HandleStreamDNSRequest(ctx, d.router, conn, metadata) if err != nil { return err } } } -func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { +func HandleStreamDNSRequest(ctx context.Context, router adapter.Router, conn net.Conn, metadata adapter.InboundContext) error { var queryLength uint16 err := binary.Read(conn, binary.BigEndian, &queryLength) if err != nil { @@ -79,7 +81,7 @@ func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adap } metadataInQuery := metadata go func() error { - response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) + response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) if err != nil { return err } @@ -100,10 +102,14 @@ func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adap // Deprecated func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return NewDNSPacketConnection(ctx, d.router, conn, nil, metadata) +} + +func NewDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.PacketConn, cachedPackets []*N.PacketBuffer, metadata adapter.InboundContext) error { metadata.Destination = M.Socksaddr{} var reader N.PacketReader = conn var counters []N.CountFunc - var cachedPackets []*N.PacketBuffer + cachedPackets = common.Reverse(cachedPackets) for { reader, counters = N.UnwrapCountPacketReader(reader, counters) if cachedReader, isCached := reader.(N.CachedPacketReader); isCached { @@ -115,7 +121,7 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada } if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created { readWaiter.InitializeReadWaiter(N.ReadWaitOptions{}) - return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedPackets, metadata) + return newDNSPacketConnection(ctx, router, conn, readWaiter, counters, cachedPackets, metadata) } break } @@ -161,7 +167,7 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada } metadataInQuery := metadata go func() error { - response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) + response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) if err != nil { cancel(err) return err @@ -186,7 +192,7 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada return group.Run(fastClose) } -func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error { +func newDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error { fastClose, cancel := common.ContextWithCancelCause(ctx) timeout := canceler.New(fastClose, cancel, C.DNSTimeout) var group task.Group @@ -206,11 +212,12 @@ func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWa } err = message.Unpack(packet.Buffer.Bytes()) packet.Buffer.Release() + destination = packet.Destination + N.PutPacketBuffer(packet) if err != nil { cancel(err) return err } - destination = packet.Destination } else { buffer, destination, err = readWaiter.WaitReadPacket() if err != nil { @@ -230,7 +237,7 @@ func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWa } metadataInQuery := metadata go func() error { - response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) + response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) if err != nil { cancel(err) return err diff --git a/route/dns.go b/route/dns.go new file mode 100644 index 0000000000..34299ebfd7 --- /dev/null +++ b/route/dns.go @@ -0,0 +1,91 @@ +package route + +import ( + "context" + "net" + "time" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/outbound" + "github.com/sagernet/sing-dns" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/udpnat2" + + mDNS "github.com/miekg/dns" +) + +func (r *Router) hijackDNSStream(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + metadata.Destination = M.Socksaddr{} + for { + conn.SetReadDeadline(time.Now().Add(C.DNSTimeout)) + err := outbound.HandleStreamDNSRequest(ctx, r, conn, metadata) + if err != nil { + return err + } + } +} + +func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetBuffers []*N.PacketBuffer, metadata adapter.InboundContext) { + if uConn, isUDPNAT2 := conn.(*udpnat.Conn); isUDPNAT2 { + metadata.Destination = M.Socksaddr{} + for _, packet := range packetBuffers { + buffer := packet.Buffer + destination := packet.Destination + N.PutPacketBuffer(packet) + go ExchangeDNSPacket(ctx, r, uConn, buffer, metadata, destination) + } + uConn.SetHandler(&dnsHijacker{ + router: r, + conn: conn, + ctx: ctx, + metadata: metadata, + }) + return + } + err := outbound.NewDNSPacketConnection(ctx, r, conn, packetBuffers, metadata) + if err != nil && !E.IsClosedOrCanceled(err) { + r.dnsLogger.ErrorContext(ctx, E.Cause(err, "process packet connection")) + } +} + +func ExchangeDNSPacket(ctx context.Context, router *Router, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext, destination M.Socksaddr) { + err := exchangeDNSPacket(ctx, router, conn, buffer, metadata, destination) + if err != nil && !E.IsClosedOrCanceled(err) { + router.dnsLogger.ErrorContext(ctx, E.Cause(err, "process packet connection")) + } +} + +func exchangeDNSPacket(ctx context.Context, router *Router, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext, destination M.Socksaddr) error { + var message mDNS.Msg + err := message.Unpack(buffer.Bytes()) + buffer.Release() + if err != nil { + return E.Cause(err, "unpack request") + } + response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message) + if err != nil { + return err + } + responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024) + if err != nil { + return err + } + err = conn.WritePacket(responseBuffer, destination) + responseBuffer.Release() + return err +} + +type dnsHijacker struct { + router *Router + conn N.PacketConn + ctx context.Context + metadata adapter.InboundContext +} + +func (h *dnsHijacker) NewPacketEx(buffer *buf.Buffer, destination M.Socksaddr) { + go ExchangeDNSPacket(h.ctx, h.router, h.conn, buffer, h.metadata, destination) +} diff --git a/route/route.go b/route/route.go index 56493bd176..b440efe4f1 100644 --- a/route/route.go +++ b/route/route.go @@ -88,7 +88,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad if deadline.NeedAdditionalReadDeadline(conn) { conn = deadline.NewConn(conn) } - selectedRule, _, buffers, err := r.matchRule(ctx, &metadata, false, conn, nil, -1) + selectedRule, _, buffers, _, err := r.matchRule(ctx, &metadata, false, conn, nil, -1) if err != nil { return err } @@ -109,6 +109,12 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad buf.ReleaseMulti(buffers) N.CloseOnHandshakeFailure(conn, onClose, action.Error()) return nil + case *rule.RuleActionHijackDNS: + for _, buffer := range buffers { + conn = bufio.NewCachedConn(conn, buffer) + } + r.hijackDNSStream(ctx, conn, metadata) + return nil } } if selectedRule == nil || selectReturn { @@ -226,7 +232,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m conn = deadline.NewPacketConn(bufio.NewNetPacketConn(conn)) }*/ - selectedRule, _, buffers, err := r.matchRule(ctx, &metadata, false, nil, conn, -1) + selectedRule, _, _, packetBuffers, err := r.matchRule(ctx, &metadata, false, nil, conn, -1) if err != nil { return err } @@ -238,32 +244,35 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m var loaded bool selectedOutbound, loaded = r.Outbound(action.Outbound) if !loaded { - buf.ReleaseMulti(buffers) + N.ReleaseMultiPacketBuffer(packetBuffers) return E.New("outbound not found: ", action.Outbound) } metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping case *rule.RuleActionReturn: selectReturn = true case *rule.RuleActionReject: - buf.ReleaseMulti(buffers) + N.ReleaseMultiPacketBuffer(packetBuffers) N.CloseOnHandshakeFailure(conn, onClose, syscall.ECONNREFUSED) return nil + case *rule.RuleActionHijackDNS: + r.hijackDNSPacket(ctx, conn, packetBuffers, metadata) + return nil } } if selectedRule == nil || selectReturn { if r.defaultOutboundForPacketConnection == nil { - buf.ReleaseMulti(buffers) + N.ReleaseMultiPacketBuffer(packetBuffers) return E.New("missing default outbound with UDP support") } selectedOutbound = r.defaultOutboundForPacketConnection } if !common.Contains(selectedOutbound.Network(), N.NetworkUDP) { - buf.ReleaseMulti(buffers) + N.ReleaseMultiPacketBuffer(packetBuffers) return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag()) } - for _, buffer := range buffers { - // TODO: check if metadata.Destination == packet destination - conn = bufio.NewCachedPacketConn(conn, buffer, metadata.Destination) + for _, buffer := range packetBuffers { + conn = bufio.NewCachedPacketConn(conn, buffer.Buffer, buffer.Destination) + N.PutPacketBuffer(buffer) } if r.clashServer != nil { trackerConn, tracker := r.clashServer.RoutedPacketConnection(ctx, conn, metadata, selectedRule) @@ -297,7 +306,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m } func (r *Router) PreMatch(metadata adapter.InboundContext) error { - selectedRule, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil, -1) + selectedRule, _, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil, -1) if err != nil { return err } @@ -314,7 +323,10 @@ func (r *Router) PreMatch(metadata adapter.InboundContext) error { func (r *Router) matchRule( ctx context.Context, metadata *adapter.InboundContext, preMatch bool, inputConn net.Conn, inputPacketConn N.PacketConn, ruleIndex int, -) (selectedRule adapter.Rule, selectedRuleIndex int, buffers []*buf.Buffer, fatalErr error) { +) ( + selectedRule adapter.Rule, selectedRuleIndex int, + buffers []*buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error, +) { if r.processSearcher != nil && metadata.ProcessInfo == nil { var originDestination netip.AddrPort if metadata.OriginDestination.IsValid() { @@ -376,7 +388,7 @@ func (r *Router) matchRule( //nolint:staticcheck if metadata.InboundOptions != common.DefaultValue[option.InboundOptions]() { if !preMatch && metadata.InboundOptions.SniffEnabled { - newBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{ + newBuffer, newPackerBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{ OverrideDestination: metadata.InboundOptions.SniffOverrideDestination, Timeout: time.Duration(metadata.InboundOptions.SniffTimeout), }, inputConn, inputPacketConn) @@ -384,7 +396,11 @@ func (r *Router) matchRule( fatalErr = newErr return } - buffers = append(buffers, newBuffers...) + if newBuffer != nil { + buffers = []*buf.Buffer{newBuffer} + } else if len(newPackerBuffers) > 0 { + packetBuffers = newPackerBuffers + } } if dns.DomainStrategy(metadata.InboundOptions.DomainStrategy) != dns.DomainStrategyAsIS { fatalErr = r.actionResolve(ctx, metadata, &rule.RuleActionResolve{ @@ -431,12 +447,16 @@ match: switch action := currentRule.Action().(type) { case *rule.RuleActionSniff: if !preMatch { - newBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn) + newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn) if newErr != nil { fatalErr = newErr return } - buffers = append(buffers, newBuffers...) + if newBuffer != nil { + buffers = append(buffers, newBuffer) + } else if len(newPacketBuffers) > 0 { + packetBuffers = append(packetBuffers, newPacketBuffers...) + } } else { selectedRule = currentRule selectedRuleIndex = currentRuleIndex @@ -455,12 +475,16 @@ match: ruleIndex = currentRuleIndex } if !preMatch && metadata.Destination.Addr.IsUnspecified() { - newBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{}, inputConn, inputPacketConn) + newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{}, inputConn, inputPacketConn) if newErr != nil { fatalErr = newErr return } - buffers = append(buffers, newBuffers...) + if newBuffer != nil { + buffers = append(buffers, newBuffer) + } else if len(newPacketBuffers) > 0 { + packetBuffers = append(packetBuffers, newPacketBuffers...) + } } return } @@ -468,18 +492,31 @@ match: func (r *Router) actionSniff( ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionSniff, inputConn net.Conn, inputPacketConn N.PacketConn, -) (buffers []*buf.Buffer, fatalErr error) { +) (buffer *buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error) { if sniff.Skip(metadata) { return - } else if inputConn != nil && len(action.StreamSniffers) > 0 { - buffer := buf.NewPacket() + } else if inputConn != nil { + sniffBuffer := buf.NewPacket() + var streamSniffers []sniff.StreamSniffer + if len(action.StreamSniffers) > 0 { + streamSniffers = action.StreamSniffers + } else { + streamSniffers = []sniff.StreamSniffer{ + sniff.TLSClientHello, + sniff.HTTPHost, + sniff.StreamDomainNameQuery, + sniff.BitTorrent, + sniff.SSH, + sniff.RDP, + } + } err := sniff.PeekStream( ctx, metadata, inputConn, - buffer, + sniffBuffer, action.Timeout, - action.StreamSniffers..., + streamSniffers..., ) if err == nil { //goland:noinspection GoDeprecation @@ -497,15 +534,15 @@ func (r *Router) actionSniff( r.logger.DebugContext(ctx, "sniffed protocol: ", metadata.Protocol) } } - if !buffer.IsEmpty() { - buffers = append(buffers, buffer) + if !sniffBuffer.IsEmpty() { + buffer = sniffBuffer } else { - buffer.Release() + sniffBuffer.Release() } - } else if inputPacketConn != nil && len(action.PacketSniffers) > 0 { + } else if inputPacketConn != nil { for { var ( - buffer = buf.NewPacket() + sniffBuffer = buf.NewPacket() destination M.Socksaddr done = make(chan struct{}) err error @@ -516,7 +553,7 @@ func (r *Router) actionSniff( sniffTimeout = action.Timeout } inputPacketConn.SetReadDeadline(time.Now().Add(sniffTimeout)) - destination, err = inputPacketConn.ReadPacket(buffer) + destination, err = inputPacketConn.ReadPacket(sniffBuffer) inputPacketConn.SetReadDeadline(time.Time{}) close(done) }() @@ -528,7 +565,7 @@ func (r *Router) actionSniff( return } if err != nil { - buffer.Release() + sniffBuffer.Release() if !errors.Is(err, os.ErrDeadlineExceeded) { fatalErr = err return @@ -538,22 +575,40 @@ func (r *Router) actionSniff( if metadata.Destination.Addr.IsUnspecified() { metadata.Destination = destination } - if len(buffers) > 0 { + if len(packetBuffers) > 0 { err = sniff.PeekPacket( ctx, metadata, - buffer.Bytes(), + sniffBuffer.Bytes(), sniff.QUICClientHello, ) } else { + var packetSniffers []sniff.PacketSniffer + if len(action.PacketSniffers) > 0 { + packetSniffers = action.PacketSniffers + } else { + packetSniffers = []sniff.PacketSniffer{ + sniff.DomainNameQuery, + sniff.QUICClientHello, + sniff.STUNMessage, + sniff.UTP, + sniff.UDPTracker, + sniff.DTLSRecord, + } + } err = sniff.PeekPacket( ctx, metadata, - buffer.Bytes(), - action.PacketSniffers..., + sniffBuffer.Bytes(), + packetSniffers..., ) } - buffers = append(buffers, buffer) - if E.IsMulti(err, sniff.ErrClientHelloFragmented) && len(buffers) == 0 { + packetBuffer := N.NewPacketBuffer() + *packetBuffer = N.PacketBuffer{ + Buffer: sniffBuffer, + Destination: destination, + } + packetBuffers = append(packetBuffers, packetBuffer) + if E.IsMulti(err, sniff.ErrClientHelloFragmented) && len(packetBuffers) == 0 { r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello") continue } diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go index a157e94e52..57b7364794 100644 --- a/route/rule/rule_action.go +++ b/route/rule/rule_action.go @@ -162,53 +162,31 @@ func (r *RuleActionSniff) Type() string { } func (r *RuleActionSniff) build() error { - if len(r.StreamSniffers) > 0 || len(r.PacketSniffers) > 0 { - return nil - } - if len(r.snifferNames) > 0 { - for _, name := range r.snifferNames { - switch name { - case C.ProtocolTLS: - r.StreamSniffers = append(r.StreamSniffers, sniff.TLSClientHello) - case C.ProtocolHTTP: - r.StreamSniffers = append(r.StreamSniffers, sniff.HTTPHost) - case C.ProtocolQUIC: - r.PacketSniffers = append(r.PacketSniffers, sniff.QUICClientHello) - case C.ProtocolDNS: - r.StreamSniffers = append(r.StreamSniffers, sniff.StreamDomainNameQuery) - r.PacketSniffers = append(r.PacketSniffers, sniff.DomainNameQuery) - case C.ProtocolSTUN: - r.PacketSniffers = append(r.PacketSniffers, sniff.STUNMessage) - case C.ProtocolBitTorrent: - r.StreamSniffers = append(r.StreamSniffers, sniff.BitTorrent) - r.PacketSniffers = append(r.PacketSniffers, sniff.UTP) - r.PacketSniffers = append(r.PacketSniffers, sniff.UDPTracker) - case C.ProtocolDTLS: - r.PacketSniffers = append(r.PacketSniffers, sniff.DTLSRecord) - case C.ProtocolSSH: - r.StreamSniffers = append(r.StreamSniffers, sniff.SSH) - case C.ProtocolRDP: - r.StreamSniffers = append(r.StreamSniffers, sniff.RDP) - default: - return E.New("unknown sniffer: ", name) - } - } - } else { - r.StreamSniffers = []sniff.StreamSniffer{ - sniff.TLSClientHello, - sniff.HTTPHost, - sniff.StreamDomainNameQuery, - sniff.BitTorrent, - sniff.SSH, - sniff.RDP, - } - r.PacketSniffers = []sniff.PacketSniffer{ - sniff.DomainNameQuery, - sniff.QUICClientHello, - sniff.STUNMessage, - sniff.UTP, - sniff.UDPTracker, - sniff.DTLSRecord, + for _, name := range r.snifferNames { + switch name { + case C.ProtocolTLS: + r.StreamSniffers = append(r.StreamSniffers, sniff.TLSClientHello) + case C.ProtocolHTTP: + r.StreamSniffers = append(r.StreamSniffers, sniff.HTTPHost) + case C.ProtocolQUIC: + r.PacketSniffers = append(r.PacketSniffers, sniff.QUICClientHello) + case C.ProtocolDNS: + r.StreamSniffers = append(r.StreamSniffers, sniff.StreamDomainNameQuery) + r.PacketSniffers = append(r.PacketSniffers, sniff.DomainNameQuery) + case C.ProtocolSTUN: + r.PacketSniffers = append(r.PacketSniffers, sniff.STUNMessage) + case C.ProtocolBitTorrent: + r.StreamSniffers = append(r.StreamSniffers, sniff.BitTorrent) + r.PacketSniffers = append(r.PacketSniffers, sniff.UTP) + r.PacketSniffers = append(r.PacketSniffers, sniff.UDPTracker) + case C.ProtocolDTLS: + r.PacketSniffers = append(r.PacketSniffers, sniff.DTLSRecord) + case C.ProtocolSSH: + r.StreamSniffers = append(r.StreamSniffers, sniff.SSH) + case C.ProtocolRDP: + r.StreamSniffers = append(r.StreamSniffers, sniff.RDP) + default: + return E.New("unknown sniffer: ", name) } } return nil