Skip to content

Commit

Permalink
Dialer: Set TimeoutOnly for gctx and hctx
Browse files Browse the repository at this point in the history
#2232 (comment)

Thank @cty123 for testing

Fixes #2232

BTW: Use `uConn.HandshakeContext(ctx)` in REALITY
  • Loading branch information
RPRX committed Aug 27, 2023
1 parent b24a402 commit d92002a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
12 changes: 7 additions & 5 deletions transport/internet/grpc/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,13 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in
MinConnectTimeout: 5 * time.Second,
}),
grpc.WithContextDialer(func(gctx context.Context, s string) (gonet.Conn, error) {
gctx = session.ContextWithID(gctx, session.IDFromContext(ctx))
gctx = session.ContextWithOutbound(gctx, session.OutboundFromContext(ctx))

rawHost, rawPort, err := net.SplitHostPort(s)
select {
case <-gctx.Done():
return nil, gctx.Err()
default:
}

rawHost, rawPort, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
Expand All @@ -119,9 +116,14 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in
return nil, err
}
address := net.ParseAddress(rawHost)

gctx = session.ContextWithID(gctx, session.IDFromContext(ctx))
gctx = session.ContextWithOutbound(gctx, session.OutboundFromContext(ctx))
gctx = session.ContextWithTimeoutOnly(gctx, true)

c, err := internet.DialSystem(gctx, net.TCPDestination(address, port), sockopt)
if err == nil && realityConfig != nil {
return reality.UClient(c, realityConfig, ctx, dest)
return reality.UClient(c, realityConfig, gctx, dest)
}
return c, err
}),
Expand Down
12 changes: 6 additions & 6 deletions transport/internet/http/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
}

transport := &http2.Transport{
DialTLS: func(network string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
rawHost, rawPort, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
Expand All @@ -67,18 +67,18 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
}
address := net.ParseAddress(rawHost)

dctx := context.Background()
dctx = session.ContextWithID(dctx, session.IDFromContext(ctx))
dctx = session.ContextWithOutbound(dctx, session.OutboundFromContext(ctx))
hctx = session.ContextWithID(hctx, session.IDFromContext(ctx))
hctx = session.ContextWithOutbound(hctx, session.OutboundFromContext(ctx))
hctx = session.ContextWithTimeoutOnly(hctx, true)

pconn, err := internet.DialSystem(dctx, net.TCPDestination(address, port), sockopt)
pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt)
if err != nil {
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
return nil, err
}

if realityConfigs != nil {
return reality.UClient(pconn, realityConfigs, ctx, dest)
return reality.UClient(pconn, realityConfigs, hctx, dest)
}

var cn tls.Interface
Expand Down
2 changes: 1 addition & 1 deletion transport/internet/reality/reality.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func UClient(c net.Conn, config *Config, ctx context.Context, dest net.Destinati
aead.Seal(hello.SessionId[:0], hello.Random[20:], hello.SessionId[:16], hello.Raw)
copy(hello.Raw[39:], hello.SessionId)
}
if err := uConn.Handshake(); err != nil {
if err := uConn.HandshakeContext(ctx); err != nil {
return nil, err
}
if config.Show {
Expand Down

0 comments on commit d92002a

Please sign in to comment.