Skip to content

Commit

Permalink
commandconn: don't return error if command closed successfully
Browse files Browse the repository at this point in the history
---
commandconn: fix race on `Close()`

During normal operation, if a `Read()` or `Write()` call results
in an EOF, we call `onEOF()` to handle the terminating command,
and store it's exit value.

However, if a Read/Write call was blocked while `Close()` is called
the in/out pipes are immediately closed which causes an EOF to be
returned. Here, we shouldn't call `onEOF()`, since the reason why
we got an EOF is because we're already terminating the connection.
This also prevents a race between two calls to the commands `Wait()`,
in the `Close()` call and `onEOF()`

---
Add CLI init timeout to SSH connections

---
connhelper: add 30s ssh default dialer timeout

(same as non-ssh dialer)

Signed-off-by: Laura Brehm <[email protected]>
  • Loading branch information
laurazard committed Jun 9, 2023
1 parent 20923df commit a5ebe22
Show file tree
Hide file tree
Showing 9 changed files with 446 additions and 110 deletions.
10 changes: 2 additions & 8 deletions cli/command/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -327,13 +326,8 @@ func (cli *DockerCli) getInitTimeout() time.Duration {

func (cli *DockerCli) initializeFromClient() {
ctx := context.Background()
if !strings.HasPrefix(cli.dockerEndpoint.Host, "ssh://") {
// @FIXME context.WithTimeout doesn't work with connhelper / ssh connections
// time="2020-04-10T10:16:26Z" level=warning msg="commandConn.CloseWrite: commandconn: failed to wait: signal: killed"
var cancel func()
ctx, cancel = context.WithTimeout(ctx, cli.getInitTimeout())
defer cancel()
}
ctx, cancel := context.WithTimeout(ctx, cli.getInitTimeout())
defer cancel()

ping, err := cli.client.Ping(ctx)
if err != nil {
Expand Down
208 changes: 106 additions & 102 deletions cli/connhelper/commandconn/commandconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"runtime"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"

Expand Down Expand Up @@ -64,100 +65,86 @@ func New(_ context.Context, cmd string, args ...string) (net.Conn, error) {

// commandConn implements net.Conn
type commandConn struct {
cmd *exec.Cmd
cmdExited bool
cmdWaitErr error
cmdMutex sync.Mutex
stdin io.WriteCloser
stdout io.ReadCloser
stderrMu sync.Mutex
stderr bytes.Buffer
stdioClosedMu sync.Mutex // for stdinClosed and stdoutClosed
stdinClosed bool
stdoutClosed bool
localAddr net.Addr
remoteAddr net.Addr
cmdMutex sync.Mutex // for cmd, cmdWaitErr
cmd *exec.Cmd
cmdWaitErr error
cmdExited atomic.Bool
stdin io.WriteCloser
stdout io.ReadCloser
stderrMu sync.Mutex // for stderr
stderr bytes.Buffer
stdinClosed atomic.Bool
stdoutClosed atomic.Bool
closing atomic.Bool
localAddr net.Addr
remoteAddr net.Addr
}

// killIfStdioClosed kills the cmd if both stdin and stdout are closed.
func (c *commandConn) killIfStdioClosed() error {
c.stdioClosedMu.Lock()
stdioClosed := c.stdoutClosed && c.stdinClosed
c.stdioClosedMu.Unlock()
if !stdioClosed {
return nil
// kill terminates the process. On Windows it kills the process directly,
// whereas on other platforms, a SIGTERM is sent, before forcefully terminating
// the process after 3 seconds.
func (c *commandConn) kill() {
if c.cmdExited.Load() {
return
}
return c.kill()
}

// killAndWait tries sending SIGTERM to the process before sending SIGKILL.
func killAndWait(cmd *exec.Cmd) error {
c.cmdMutex.Lock()
var werr error
if runtime.GOOS != "windows" {
werrCh := make(chan error)
go func() { werrCh <- cmd.Wait() }()
cmd.Process.Signal(syscall.SIGTERM)
go func() { werrCh <- c.cmd.Wait() }()
_ = c.cmd.Process.Signal(syscall.SIGTERM)
select {
case werr = <-werrCh:
case <-time.After(3 * time.Second):
cmd.Process.Kill()
_ = c.cmd.Process.Kill()
werr = <-werrCh
}
} else {
cmd.Process.Kill()
werr = cmd.Wait()
_ = c.cmd.Process.Kill()
werr = c.cmd.Wait()
}
return werr
c.cmdWaitErr = werr
c.cmdMutex.Unlock()
c.cmdExited.Store(true)
}

// kill returns nil if the command terminated, regardless to the exit status.
func (c *commandConn) kill() error {
var werr error
c.cmdMutex.Lock()
if c.cmdExited {
werr = c.cmdWaitErr
} else {
werr = killAndWait(c.cmd)
c.cmdWaitErr = werr
c.cmdExited = true
}
c.cmdMutex.Unlock()
if werr == nil {
return nil
}
wExitErr, ok := werr.(*exec.ExitError)
if ok {
if wExitErr.ProcessState.Exited() {
return nil
}
// handleEOF handles io.EOF errors while reading or writing from the underlying
// command pipes.
//
// When we've received an EOF we expect that the command will
// be terminated soon. As such, we call Wait() on the command
// and return EOF or the error depending on whether the command
// exited with an error.
//
// If Wait() does not return within 10s, an error is returned
func (c *commandConn) handleEOF(err error) error {
if err != io.EOF {
return err
}
return errors.Wrapf(werr, "commandconn: failed to wait")
}

func (c *commandConn) onEOF(eof error) error {
// when we got EOF, the command is going to be terminated
var werr error
c.cmdMutex.Lock()
if c.cmdExited {
defer c.cmdMutex.Unlock()

var werr error
if c.cmdExited.Load() {
werr = c.cmdWaitErr
} else {
werrCh := make(chan error)
go func() { werrCh <- c.cmd.Wait() }()
select {
case werr = <-werrCh:
c.cmdWaitErr = werr
c.cmdExited = true
c.cmdExited.Store(true)
case <-time.After(10 * time.Second):
c.cmdMutex.Unlock()
c.stderrMu.Lock()
stderr := c.stderr.String()
c.stderrMu.Unlock()
return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, eof, stderr)
return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, err, stderr)
}
}
c.cmdMutex.Unlock()

if werr == nil {
return eof
return err
}
c.stderrMu.Lock()
stderr := c.stderr.String()
Expand All @@ -166,71 +153,88 @@ func (c *commandConn) onEOF(eof error) error {
}

func ignorableCloseError(err error) bool {
errS := err.Error()
ss := []string{
os.ErrClosed.Error(),
return strings.Contains(err.Error(), os.ErrClosed.Error())
}

func (c *commandConn) Read(p []byte) (int, error) {
n, err := c.stdout.Read(p)
// check after the call to Read, since
// it is blocking, and while waiting on it
// Close might get called
if c.closing.Load() {
// If we're currently closing the connection
// we don't want to call onEOF, but we do want
// to return an io.EOF
return 0, io.EOF
}
for _, s := range ss {
if strings.Contains(errS, s) {
return true
}

return n, c.handleEOF(err)
}

func (c *commandConn) Write(p []byte) (int, error) {
n, err := c.stdin.Write(p)
// check after the call to Write, since
// it is blocking, and while waiting on it
// Close might get called
if c.closing.Load() {
// If we're currently closing the connection
// we don't want to call onEOF, but we do want
// to return an io.EOF
return 0, io.EOF
}
return false

return n, c.handleEOF(err)
}

// CloseRead allows commandConn to implement halfCloser
func (c *commandConn) CloseRead() error {
// NOTE: maybe already closed here
if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) {
logrus.Warnf("commandConn.CloseRead: %v", err)
return err
}
c.stdioClosedMu.Lock()
c.stdoutClosed = true
c.stdioClosedMu.Unlock()
if err := c.killIfStdioClosed(); err != nil {
logrus.Warnf("commandConn.CloseRead: %v", err)
}
return nil
}
c.stdoutClosed.Store(true)

func (c *commandConn) Read(p []byte) (int, error) {
n, err := c.stdout.Read(p)
if err == io.EOF {
err = c.onEOF(err)
if c.stdinClosed.Load() {
c.kill()
}
return n, err

return nil
}

// CloseWrite allows commandConn to implement halfCloser
func (c *commandConn) CloseWrite() error {
// NOTE: maybe already closed here
if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) {
logrus.Warnf("commandConn.CloseWrite: %v", err)
}
c.stdioClosedMu.Lock()
c.stdinClosed = true
c.stdioClosedMu.Unlock()
if err := c.killIfStdioClosed(); err != nil {
logrus.Warnf("commandConn.CloseWrite: %v", err)
return err
}
return nil
}
c.stdinClosed.Store(true)

func (c *commandConn) Write(p []byte) (int, error) {
n, err := c.stdin.Write(p)
if err == io.EOF {
err = c.onEOF(err)
if c.stdoutClosed.Load() {
c.kill()
}
return n, err
return nil
}

// Close is the net.Conn func that gets called
// by the transport when a dial is cancelled
// due to it's context timing out. Any blocked
// Read or Write calls will be unblocked and
// return errors. It will block until the underlying
// command has terminated.
func (c *commandConn) Close() error {
var err error
if err = c.CloseRead(); err != nil {
c.closing.Store(true)
defer c.closing.Store(false)

if err := c.CloseRead(); err != nil {
logrus.Warnf("commandConn.Close: CloseRead: %v", err)
return err
}
if err = c.CloseWrite(); err != nil {
if err := c.CloseWrite(); err != nil {
logrus.Warnf("commandConn.Close: CloseWrite: %v", err)
return err
}
return err

return nil
}

func (c *commandConn) LocalAddr() net.Addr {
Expand Down
Loading

0 comments on commit a5ebe22

Please sign in to comment.