Skip to content

Commit

Permalink
ssh: Extend the Dial context to cover ssh handshake (#54)
Browse files Browse the repository at this point in the history
`ssh.Dial()` took in a context that was used to establish the tcp
connection, however that context doesn't cover the ssh handshake which
can easily block indefinitely. This approximates context support for
ssh.NewClientConn() by having a go routine listen for context
cancellation and closing the connection. We can then check for ctx.Err()
and return that (i.e if the context was canceled).

Note that there is a `Timeout` field in `ssh.ClientConfig` but that also
only covers the TCP connection. See
golang/go#51926

Fixes: #53
  • Loading branch information
nemith authored Jun 15, 2023
1 parent 8ca892b commit f2ecb06
Showing 1 changed file with 36 additions and 9 deletions.
45 changes: 36 additions & 9 deletions transport/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ type Transport struct {
sess *ssh.Session
stdin io.WriteCloser

// indicate that we "own" the client and should close it with the session
// when the transport is closed.
ownedClient bool
// set to true if the transport is managing the underlying ssh connection
// and should close it when the transport is closed. This is is set to true
// when used with `Dial`.
managed bool

*framer
}
Expand All @@ -40,10 +41,36 @@ func Dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) (
if err != nil {
return nil, err
}

// Setup a go routine to monitor the context and close the connection. This
// is needed as the underlying ssh library doesn't support contexts so this
// approximates a context based cancelation/timeout for the ssh handshake.
//
// An alternative would be timeout based with conn.SetDeadline(), but then we
// would manage two timeouts. One for tcp connection and one for ssh
// handshake and wouldn't support any other event based cancelation.
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
// context is canceled so close the underlying connection. Will
// will catch ctx.Err() later.
conn.Close()
case <-done:
}
}()

sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
if err != nil {
// if there is a context timeout return that error instead of the actual
// error from ssh.NewClientConn.
if ctx.Err() != nil {
return nil, ctx.Err()
}
return nil, err
}
close(done) // make sure we cleanup the context monitor routine

client := ssh.NewClient(sshConn, chans, reqs)
return newTransport(client, true)
}
Expand All @@ -56,7 +83,7 @@ func NewTransport(client *ssh.Client) (*Transport, error) {
return newTransport(client, false)
}

func newTransport(client *ssh.Client, owned bool) (*Transport, error) {
func newTransport(client *ssh.Client, managed bool) (*Transport, error) {
sess, err := client.NewSession()
if err != nil {
return nil, fmt.Errorf("failed to create ssh session: %w", err)
Expand All @@ -78,10 +105,10 @@ func newTransport(client *ssh.Client, owned bool) (*Transport, error) {
}

return &Transport{
c: client,
ownedClient: owned,
sess: sess,
stdin: w,
c: client,
managed: managed,
sess: sess,
stdin: w,

framer: transport.NewFramer(r, w),
}, nil
Expand All @@ -104,7 +131,7 @@ func (t *Transport) Close() error {
retErr = fmt.Errorf("failed to close ssh channel: %w", err)
}

if t.ownedClient {
if t.managed {
if err := t.c.Close(); err != nil {
return fmt.Errorf("failed to close ssh connnection: %w", t.c.Close())
}
Expand Down

0 comments on commit f2ecb06

Please sign in to comment.