diff --git a/transport/internet/dialer.go b/transport/internet/dialer.go index 989bfa03e3eb..8ab51551b99b 100644 --- a/transport/internet/dialer.go +++ b/transport/internet/dialer.go @@ -111,7 +111,7 @@ func canLookupIP(ctx context.Context, dst net.Destination, sockopt *SocketConfig } func redirect(ctx context.Context, dst net.Destination, obt string) net.Conn { - errors.LogInfo(ctx, "redirecting request " + dst.String() + " to " + obt) + errors.LogInfo(ctx, "redirecting request "+dst.String()+" to "+obt) h := obm.GetHandler(obt) outbounds := session.OutboundsFromContext(ctx) ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{ @@ -123,10 +123,16 @@ func redirect(ctx context.Context, dst net.Destination, obt string) net.Conn { ur, uw := pipe.New(pipe.OptionsFromContext(ctx)...) dr, dw := pipe.New(pipe.OptionsFromContext(ctx)...) - go h.Dispatch(ctx, &transport.Link{Reader: ur, Writer: dw}) + go h.Dispatch(context.WithoutCancel(ctx), &transport.Link{Reader: ur, Writer: dw}) + var readerOpt cnc.ConnectionOption + if dst.Network == net.Network_TCP { + readerOpt = cnc.ConnectionOutputMulti(dr) + } else { + readerOpt = cnc.ConnectionOutputMultiUDP(dr) + } nc := cnc.NewConnection( cnc.ConnectionInputMulti(uw), - cnc.ConnectionOutputMulti(dr), + readerOpt, cnc.ConnectionOnClose(common.ChainedClosable{uw, dw}), ) return nc @@ -150,7 +156,7 @@ func DialSystem(ctx context.Context, dest net.Destination, sockopt *SocketConfig ips, err := lookupIP(dest.Address.String(), sockopt.DomainStrategy, src) if err == nil && len(ips) > 0 { dest.Address = net.IPAddress(ips[dice.Roll(len(ips))]) - errors.LogInfo(ctx, "replace destination with " + dest.String()) + errors.LogInfo(ctx, "replace destination with "+dest.String()) } else if err != nil { errors.LogWarningInner(ctx, err, "failed to resolve ip") } diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index 45bdc6459c7b..0d487b58d2b6 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -118,7 +118,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in return nil, err } - var udpConn *net.UDPConn + var udpConn net.PacketConn var udpAddr *net.UDPAddr switch c := conn.(type) { @@ -139,7 +139,11 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in return nil, err } default: - return nil, errors.New("unsupported connection type: %T", conn) + udpConn = &internet.FakePacketConn{c} + udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String()) + if err != nil { + return nil, err + } } return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg) diff --git a/transport/internet/system_dialer.go b/transport/internet/system_dialer.go index 73379e4102b7..52b1e830d0b5 100644 --- a/transport/internet/system_dialer.go +++ b/transport/internet/system_dialer.go @@ -2,6 +2,7 @@ package internet import ( "context" + "math/rand" "syscall" "time" @@ -48,7 +49,7 @@ func hasBindAddr(sockopt *SocketConfig) bool { } func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) { - errors.LogDebug(ctx, "dialing to " + dest.String()) + errors.LogDebug(ctx, "dialing to "+dest.String()) if dest.Network == net.Network_UDP && !hasBindAddr(sockopt) { srcAddr := resolveSrcAddr(net.Network_UDP, src) @@ -221,3 +222,29 @@ func RegisterDialerController(ctl control.Func) error { dialer.controllers = append(dialer.controllers, ctl) return nil } + +type FakePacketConn struct { + net.Conn +} + +func (c *FakePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(p) + return n, c.RemoteAddr(), err +} + +func (c *FakePacketConn) WriteTo(p []byte, _ net.Addr) (n int, err error) { + return c.Write(p) +} + +func (c *FakePacketConn) LocalAddr() net.Addr { + return &net.TCPAddr{ + IP: net.IP{byte(rand.Intn(256)), byte(rand.Intn(256)), byte(rand.Intn(256)), byte(rand.Intn(256))}, + Port: rand.Intn(65536), + } +} + +func (c *FakePacketConn) SetReadBuffer(bytes int) error { + // do nothing, this function is only there to suppress quic-go printing + // random warnings about UDP buffers to stdout + return nil +}