diff --git a/client.go b/client.go index 93b5b5fe..901afc30 100644 --- a/client.go +++ b/client.go @@ -515,7 +515,7 @@ func (c *Client) Kill() { // If graceful exiting failed, just kill it c.logger.Warn("plugin failed to exit gracefully") - if err := runner.Kill(); err != nil { + if err := runner.Kill(context.Background()); err != nil { c.logger.Debug("error killing plugin", "error", err) } @@ -668,7 +668,9 @@ func (c *Client) Start() (addr net.Addr, err error) { } c.runner = runner - err = runner.Start() + startCtx, startCtxCancel := context.WithTimeout(context.Background(), c.config.StartTimeout) + defer startCtxCancel() + err = runner.Start(startCtx) if err != nil { return nil, err } @@ -678,7 +680,7 @@ func (c *Client) Start() (addr net.Addr, err error) { rErr := recover() if err != nil || rErr != nil { - runner.Kill() + runner.Kill(context.Background()) } if rErr != nil { @@ -707,7 +709,7 @@ func (c *Client) Start() (addr net.Addr, err error) { c.stderrWaitGroup.Wait() // Wait for the command to end. - err := runner.Wait() + err := runner.Wait(context.Background()) if err != nil { c.logger.Error("plugin process exited", "plugin", runner.Name(), "id", runner.ID(), "error", err.Error()) } else { @@ -899,7 +901,7 @@ func (c *Client) reattach() (net.Addr, error) { defer c.ctxCancel() // Wait for the process to die - r.Wait() + r.Wait(context.Background()) // Log so we can see it c.logger.Debug("reattached plugin process exited") diff --git a/client_test.go b/client_test.go index 8def4aa7..2f333c7b 100644 --- a/client_test.go +++ b/client_test.go @@ -5,6 +5,7 @@ package plugin import ( "bytes" + "context" "crypto/sha256" "fmt" "io" @@ -226,7 +227,7 @@ func TestClient_grpc_servercrash(t *testing.T) { t.Fatalf("bad: %#v", raw) } - c.runner.Kill() + c.runner.Kill(context.Background()) select { case <-c.doneCtx.Done(): @@ -1255,7 +1256,7 @@ func TestClient_versionedClient(t *testing.T) { t.Fatalf("bad: %#v", raw) } - c.runner.Kill() + c.runner.Kill(context.Background()) select { case <-c.doneCtx.Done(): @@ -1311,7 +1312,7 @@ func TestClient_mtlsClient(t *testing.T) { t.Fatal("invalid response", n) } - c.runner.Kill() + c.runner.Kill(context.Background()) select { case <-c.doneCtx.Done(): @@ -1357,7 +1358,7 @@ func TestClient_mtlsNetRPCClient(t *testing.T) { t.Fatal("invalid response", n) } - c.runner.Kill() + c.runner.Kill(context.Background()) select { case <-c.doneCtx.Done(): diff --git a/internal/cmdrunner/cmd_reattach.go b/internal/cmdrunner/cmd_reattach.go index 74456f0b..dce1a86a 100644 --- a/internal/cmdrunner/cmd_reattach.go +++ b/internal/cmdrunner/cmd_reattach.go @@ -4,6 +4,7 @@ package cmdrunner import ( + "context" "fmt" "net" "os" @@ -49,11 +50,11 @@ type CmdAttachedRunner struct { addrTranslator } -func (c *CmdAttachedRunner) Wait() error { +func (c *CmdAttachedRunner) Wait(_ context.Context) error { return pidWait(c.pid) } -func (c *CmdAttachedRunner) Kill() error { +func (c *CmdAttachedRunner) Kill(_ context.Context) error { return c.process.Kill() } diff --git a/internal/cmdrunner/cmd_runner.go b/internal/cmdrunner/cmd_runner.go index 722e44d0..0f2ff39f 100644 --- a/internal/cmdrunner/cmd_runner.go +++ b/internal/cmdrunner/cmd_runner.go @@ -4,6 +4,7 @@ package cmdrunner import ( + "context" "errors" "fmt" "io" @@ -61,7 +62,7 @@ func NewCmdRunner(logger hclog.Logger, cmd *exec.Cmd) (*CmdRunner, error) { }, nil } -func (c *CmdRunner) Start() error { +func (c *CmdRunner) Start(_ context.Context) error { c.logger.Debug("starting plugin", "path", c.cmd.Path, "args", c.cmd.Args) err := c.cmd.Start() if err != nil { @@ -73,11 +74,11 @@ func (c *CmdRunner) Start() error { return nil } -func (c *CmdRunner) Wait() error { +func (c *CmdRunner) Wait(_ context.Context) error { return c.cmd.Wait() } -func (c *CmdRunner) Kill() error { +func (c *CmdRunner) Kill(_ context.Context) error { if c.cmd.Process != nil { err := c.cmd.Process.Kill() // Swallow ErrProcessDone, we support calling Kill multiple times. diff --git a/runner/runner.go b/runner/runner.go index 86766f16..47e60df0 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -4,6 +4,7 @@ package runner import ( + "context" "io" ) @@ -11,9 +12,11 @@ import ( // of a plugin and attempt to negotiate a connection with it. Note that this // is orthogonal to the protocol and transport used, which is negotiated over stdout. type Runner interface { - // Start should start the plugin and ensure any context required for servicing - // other interface methods is set up. - Start() error + // Start should start the plugin and ensure any work required for servicing + // other interface methods is done. If the context is cancelled, it should + // only abort any attempts to _start_ the plugin. Waiting and shutdown are + // handled separately. + Start(ctx context.Context) error // Stdout is used to negotiate the go-plugin protocol. Stdout() io.ReadCloser @@ -34,10 +37,10 @@ type Runner interface { type AttachedRunner interface { // Wait should wait until the plugin stops running, whether in response to // an out of band signal or in response to calling Kill(). - Wait() error + Wait(ctx context.Context) error // Kill should stop the plugin and perform any cleanup required. - Kill() error + Kill(ctx context.Context) error // ID is a unique identifier to represent the running plugin. e.g. pid or // container ID.