diff --git a/client.go b/client.go index 6dff73ad..c81258cd 100644 --- a/client.go +++ b/client.go @@ -26,10 +26,12 @@ import ( "time" "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-plugin/internal/cmdrunner" + "github.com/hashicorp/go-plugin/runner" "google.golang.org/grpc" ) -const unrecognizedRemotePluginMessage = `Unrecognized remote plugin message: %s +const unrecognizedRemotePluginMessage = `Unrecognized remote plugin message: %q This usually means the plugin was not compiled for this architecture, the plugin is missing dynamic-link libraries necessary to run, @@ -52,7 +54,7 @@ var managedClientsLock sync.Mutex var ( // ErrProcessNotFound is returned when a client is instantiated to // reattach to an existing process and it isn't found. - ErrProcessNotFound = errors.New("Reattachment process not found") + ErrProcessNotFound = cmdrunner.ErrProcessNotFound // ErrChecksumsDoNotMatch is returned when binary's checksum doesn't match // the one provided in the SecureConfig. @@ -87,7 +89,7 @@ type Client struct { exited bool l sync.Mutex address net.Addr - process *os.Process + runner runner.AttachedRunner client ClientProtocol protocol Protocol logger hclog.Logger @@ -106,6 +108,8 @@ type Client struct { // processKilled is used for testing only, to flag when the process was // forcefully killed. processKilled bool + + hostSocketDir string } // NegotiatedVersion returns the protocol version negotiated with the server. @@ -141,6 +145,13 @@ type ClientConfig struct { Cmd *exec.Cmd Reattach *ReattachConfig + // RunnerFunc allows consumers to provide their own implementation of + // runner.Runner and control the context within which a plugin is executed. + // The cmd argument will have been copied from the config and populated with + // environment variables that a go-plugin server expects to read such as + // AutoMTLS certs and the magic cookie key. + RunnerFunc func(l hclog.Logger, cmd *exec.Cmd, tmpDir string) (runner.Runner, error) + // SecureConfig is configuration for verifying the integrity of the // executable. It can not be used with Reattach. SecureConfig *SecureConfig @@ -220,6 +231,10 @@ type ClientConfig struct { // to create gRPC connections. This only affects plugins using the gRPC // protocol. GRPCDialOptions []grpc.DialOption + + // SkipHostEnv allows plugins to run without inheriting the parent process' + // environment variables. + SkipHostEnv bool } // ReattachConfig is used to configure a client to reattach to an @@ -231,6 +246,11 @@ type ReattachConfig struct { Addr net.Addr Pid int + // ReattachFunc allows consumers to provide their own implementation of + // runner.AttachedRunner and attach to something other than a plain process. + // At least one of Pid or ReattachFunc must be set. + ReattachFunc runner.ReattachFunc + // Test is set to true if this is reattaching to to a plugin in "test mode" // (see ServeConfig.Test). In this mode, client.Kill will NOT kill the // process and instead will rely on the plugin to terminate itself. This @@ -418,12 +438,13 @@ func (c *Client) killed() bool { func (c *Client) Kill() { // Grab a lock to read some private fields. c.l.Lock() - process := c.process + runner := c.runner addr := c.address + hostSocketDir := c.hostSocketDir c.l.Unlock() - // If there is no process, there is nothing to kill. - if process == nil { + // If there is no runner or ID, there is nothing to kill. + if runner == nil || runner.ID() == "" { return } @@ -431,10 +452,14 @@ func (c *Client) Kill() { // Wait for the all client goroutines to finish. c.clientWaitGroup.Wait() + if hostSocketDir != "" { + os.RemoveAll(hostSocketDir) + } + // Make sure there is no reference to the old process after it has been // killed. c.l.Lock() - c.process = nil + c.runner = nil c.l.Unlock() }() @@ -477,7 +502,9 @@ func (c *Client) Kill() { // If graceful exiting failed, just kill it c.logger.Warn("plugin failed to exit gracefully") - process.Kill() + if err := runner.Kill(); err != nil { + c.logger.Debug("error killing plugin", "error", err) + } c.l.Lock() c.processKilled = true @@ -516,7 +543,7 @@ func (c *Client) Start() (addr net.Addr, err error) { attachSet := c.config.Reattach != nil secureSet := c.config.SecureConfig != nil if cmdSet == attachSet { - return nil, fmt.Errorf("Only one of Cmd or Reattach must be set") + return nil, fmt.Errorf("exactly one of Cmd or Reattach must be set") } if secureSet && attachSet { @@ -555,19 +582,12 @@ func (c *Client) Start() (addr net.Addr, err error) { } cmd := c.config.Cmd - cmd.Env = append(cmd.Env, os.Environ()...) + if !c.config.SkipHostEnv { + cmd.Env = append(cmd.Env, os.Environ()...) + } cmd.Env = append(cmd.Env, env...) cmd.Stdin = os.Stdin - cmdStdout, err := cmd.StdoutPipe() - if err != nil { - return nil, err - } - cmdStderr, err := cmd.StderrPipe() - if err != nil { - return nil, err - } - if c.config.SecureConfig != nil { if ok, err := c.config.SecureConfig.Check(cmd.Path); err != nil { return nil, fmt.Errorf("error verifying checksum: %s", err) @@ -601,26 +621,42 @@ func (c *Client) Start() (addr net.Addr, err error) { } } - c.logger.Debug("starting plugin", "path", cmd.Path, "args", cmd.Args) - err = cmd.Start() - if err != nil { - return + var runner runner.Runner + switch { + case c.config.RunnerFunc != nil: + c.hostSocketDir, err = os.MkdirTemp("", "") + if err != nil { + return nil, err + } + c.logger.Trace("created temporary directory for unix sockets", "dir", c.hostSocketDir) + runner, err = c.config.RunnerFunc(c.logger, cmd, c.hostSocketDir) + if err != nil { + return nil, err + } + default: + runner, err = cmdrunner.NewCmdRunner(c.logger, cmd) + if err != nil { + return nil, err + } + } - // Set the process - c.process = cmd.Process - c.logger.Debug("plugin started", "path", cmd.Path, "pid", c.process.Pid) + c.runner = runner + err = runner.Start() + if err != nil { + return nil, err + } // Make sure the command is properly cleaned up if there is an error defer func() { - r := recover() + rErr := recover() - if err != nil || r != nil { - cmd.Process.Kill() + if err != nil || rErr != nil { + runner.Kill() } - if r != nil { - panic(r) + if rErr != nil { + panic(rErr) } }() @@ -631,7 +667,7 @@ func (c *Client) Start() (addr net.Addr, err error) { c.clientWaitGroup.Add(1) c.stderrWaitGroup.Add(1) // logStderr calls Done() - go c.logStderr(cmdStderr) + go c.logStderr(runner.Name(), runner.Stderr()) c.clientWaitGroup.Add(1) go func() { @@ -640,29 +676,17 @@ func (c *Client) Start() (addr net.Addr, err error) { defer c.clientWaitGroup.Done() - // get the cmd info early, since the process information will be removed - // in Kill. - pid := c.process.Pid - path := cmd.Path - // wait to finish reading from stderr since the stderr pipe reader // will be closed by the subsequent call to cmd.Wait(). c.stderrWaitGroup.Wait() // Wait for the command to end. - err := cmd.Wait() - - msgArgs := []interface{}{ - "path", path, - "pid", pid, - } + err := runner.Wait() if err != nil { - msgArgs = append(msgArgs, - []interface{}{"error", err.Error()}...) - c.logger.Error("plugin process exited", msgArgs...) + c.logger.Error("plugin process exited", "plugin", runner.Name(), "id", runner.ID(), "error", err.Error()) } else { // Log and make sure to flush the logs right away - c.logger.Info("plugin process exited", msgArgs...) + c.logger.Info("plugin process exited", "plugin", runner.Name(), "id", runner.ID()) } os.Stderr.Sync() @@ -681,10 +705,13 @@ func (c *Client) Start() (addr net.Addr, err error) { defer c.clientWaitGroup.Done() defer close(linesCh) - scanner := bufio.NewScanner(cmdStdout) + scanner := bufio.NewScanner(runner.Stdout()) for scanner.Scan() { linesCh <- scanner.Text() } + if scanner.Err() != nil { + c.logger.Error("error encountered while scanning stdout", "error", scanner.Err()) + } }() // Make sure after we exit we read the lines from stdout forever @@ -751,13 +778,18 @@ func (c *Client) Start() (addr net.Addr, err error) { c.negotiatedVersion = version c.logger.Debug("using plugin", "version", version) - switch parts[2] { + network, address, err := runner.PluginToHost(parts[2], parts[3]) + if err != nil { + return addr, err + } + + switch network { case "tcp": - addr, err = net.ResolveTCPAddr("tcp", parts[3]) + addr, err = net.ResolveTCPAddr("tcp", address) case "unix": - addr, err = net.ResolveUnixAddr("unix", parts[3]) + addr, err = net.ResolveUnixAddr("unix", address) default: - err = fmt.Errorf("Unknown address type: %s", parts[3]) + err = fmt.Errorf("Unknown address type: %s", address) } // If we have a server type, then record that. We default to net/rpc @@ -818,39 +850,30 @@ func (c *Client) loadServerCert(cert string) error { } func (c *Client) reattach() (net.Addr, error) { - // Verify the process still exists. If not, then it is an error - p, err := os.FindProcess(c.config.Reattach.Pid) - if err != nil { - // On Unix systems, FindProcess never returns an error. - // On Windows, for non-existent pids it returns: - // os.SyscallError - 'OpenProcess: the paremter is incorrect' - return nil, ErrProcessNotFound + reattachFunc := c.config.Reattach.ReattachFunc + // For backwards compatibility default to cmdrunner.ReattachFunc + if reattachFunc == nil { + reattachFunc = cmdrunner.ReattachFunc(c.config.Reattach.Pid, c.config.Reattach.Addr) } - // Attempt to connect to the addr since on Unix systems FindProcess - // doesn't actually return an error if it can't find the process. - conn, err := net.Dial( - c.config.Reattach.Addr.Network(), - c.config.Reattach.Addr.String()) + r, err := reattachFunc() if err != nil { - p.Kill() - return nil, ErrProcessNotFound + return nil, err } - conn.Close() // Create a context for when we kill c.doneCtx, c.ctxCancel = context.WithCancel(context.Background()) c.clientWaitGroup.Add(1) // Goroutine to mark exit status - go func(pid int) { + go func(r runner.AttachedRunner) { defer c.clientWaitGroup.Done() // ensure the context is cancelled when we're done defer c.ctxCancel() // Wait for the process to die - pidWait(pid) + r.Wait() // Log so we can see it c.logger.Debug("reattached plugin process exited") @@ -859,7 +882,7 @@ func (c *Client) reattach() (net.Addr, error) { c.l.Lock() defer c.l.Unlock() c.exited = true - }(p.Pid) + }(r) // Set the address and protocol c.address = c.config.Reattach.Addr @@ -877,7 +900,7 @@ func (c *Client) reattach() (net.Addr, error) { // process being killed (the only purpose we have for c.process), since // in test mode the process is responsible for exiting on its own. if !c.config.Reattach.Test { - c.process = p + c.runner = r } return c.address, nil @@ -989,10 +1012,10 @@ func (c *Client) dialer(_ string, timeout time.Duration) (net.Conn, error) { var stdErrBufferSize = 64 * 1024 -func (c *Client) logStderr(r io.Reader) { +func (c *Client) logStderr(name string, r io.Reader) { defer c.clientWaitGroup.Done() defer c.stderrWaitGroup.Done() - l := c.logger.Named(filepath.Base(c.config.Cmd.Path)) + l := c.logger.Named(filepath.Base(name)) reader := bufio.NewReaderSize(r, stdErrBufferSize) // continuation indicates the previous line was a prefix diff --git a/client_test.go b/client_test.go index 38dadfe3..cde083c8 100644 --- a/client_test.go +++ b/client_test.go @@ -21,6 +21,8 @@ import ( "time" "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-plugin/internal/cmdrunner" + "github.com/hashicorp/go-plugin/runner" ) func TestClient(t *testing.T) { @@ -224,7 +226,7 @@ func TestClient_grpc_servercrash(t *testing.T) { t.Fatalf("bad: %#v", raw) } - c.process.Kill() + c.runner.Kill() select { case <-c.doneCtx.Done(): @@ -301,17 +303,37 @@ func TestClient_grpcNotAllowed(t *testing.T) { } func TestClient_grpcSyncStdio(t *testing.T) { + for name, tc := range map[string]struct { + useRunnerFunc bool + }{ + "default": {false}, + "use RunnerFunc": {true}, + } { + t.Run(name, func(t *testing.T) { + testClient_grpcSyncStdio(t, tc.useRunnerFunc) + }) + } +} + +func testClient_grpcSyncStdio(t *testing.T, useRunnerFunc bool) { var syncOut, syncErr safeBuffer process := helperProcess("test-grpc") - c := NewClient(&ClientConfig{ + cfg := &ClientConfig{ Cmd: process, HandshakeConfig: testHandshake, Plugins: testGRPCPluginMap, AllowedProtocols: []Protocol{ProtocolGRPC}, SyncStdout: &syncOut, SyncStderr: &syncErr, - }) + } + + if useRunnerFunc { + cfg.RunnerFunc = func(l hclog.Logger, cmd *exec.Cmd, _ string) (runner.Runner, error) { + return cmdrunner.NewCmdRunner(l, cmd) + } + } + c := NewClient(cfg) defer c.Kill() if _, err := c.Start(); err != nil { @@ -495,6 +517,19 @@ func TestClient_reattachNoProtocol(t *testing.T) { } func TestClient_reattachGRPC(t *testing.T) { + for name, tc := range map[string]struct { + useReattachFunc bool + }{ + "default": {false}, + "use ReattachFunc": {true}, + } { + t.Run(name, func(t *testing.T) { + testClient_reattachGRPC(t, tc.useReattachFunc) + }) + } +} + +func testClient_reattachGRPC(t *testing.T, useReattachFunc bool) { process := helperProcess("test-grpc") c := NewClient(&ClientConfig{ Cmd: process, @@ -513,6 +548,12 @@ func TestClient_reattachGRPC(t *testing.T) { // Get the reattach configuration reattach := c.ReattachConfig() + if useReattachFunc { + pid := reattach.Pid + reattach.Pid = 0 + reattach.ReattachFunc = cmdrunner.ReattachFunc(pid, reattach.Addr) + } + // Create a new client c = NewClient(&ClientConfig{ Reattach: reattach, @@ -584,7 +625,6 @@ func TestClient_reattachNotFound(t *testing.T) { Plugins: testPluginMap, }) - // Start shouldn't error if _, err := c.Start(); err == nil { t.Fatal("should error") } else if err != ErrProcessNotFound { @@ -820,6 +860,45 @@ func TestClient_Stdin(t *testing.T) { } } +func TestClient_SkipHostEnv(t *testing.T) { + for _, tc := range []struct { + helper string + skip bool + }{ + {"test-skip-host-env-true", true}, + {"test-skip-host-env-false", false}, + } { + t.Run(tc.helper, func(t *testing.T) { + process := helperProcess(tc.helper) + t.Setenv("PLUGIN_TEST_SKIP_HOST_ENV", "foo") + c := NewClient(&ClientConfig{ + Cmd: process, + HandshakeConfig: testHandshake, + Plugins: testPluginMap, + SkipHostEnv: tc.skip, + }) + defer c.Kill() + + _, err := c.Start() + if err != nil { + t.Fatalf("error: %s", err) + } + + for { + if c.Exited() { + break + } + + time.Sleep(50 * time.Millisecond) + } + + if !process.ProcessState.Success() { + t.Fatal("process didn't exit cleanly") + } + }) + } +} + func TestClient_SecureConfig(t *testing.T) { // Test failure case secureConfig := &SecureConfig{ @@ -1161,7 +1240,7 @@ func TestClient_versionedClient(t *testing.T) { t.Fatalf("bad: %#v", raw) } - c.process.Kill() + c.runner.Kill() select { case <-c.doneCtx.Done(): @@ -1217,7 +1296,7 @@ func TestClient_mtlsClient(t *testing.T) { t.Fatal("invalid response", n) } - c.process.Kill() + c.runner.Kill() select { case <-c.doneCtx.Done(): @@ -1263,7 +1342,7 @@ func TestClient_mtlsNetRPCClient(t *testing.T) { t.Fatal("invalid response", n) } - c.process.Kill() + c.runner.Kill() select { case <-c.doneCtx.Done(): @@ -1392,7 +1471,7 @@ this line is short reader := strings.NewReader(msg) c.stderrWaitGroup.Add(1) - c.logStderr(reader) + c.logStderr(c.config.Cmd.Path, reader) read := stderr.String() if read != msg { diff --git a/constants.go b/constants.go new file mode 100644 index 00000000..b66fa799 --- /dev/null +++ b/constants.go @@ -0,0 +1,9 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package plugin + +const ( + EnvUnixSocketDir = "PLUGIN_UNIX_SOCKET_DIR" + EnvUnixSocketGroup = "PLUGIN_UNIX_SOCKET_GROUP" +) diff --git a/grpc_broker.go b/grpc_broker.go index 9bf56776..91eee6e6 100644 --- a/grpc_broker.go +++ b/grpc_broker.go @@ -15,6 +15,7 @@ import ( "time" "github.com/hashicorp/go-plugin/internal/plugin" + "github.com/hashicorp/go-plugin/runner" "github.com/oklog/run" "google.golang.org/grpc" @@ -267,6 +268,9 @@ type GRPCBroker struct { doneCh chan struct{} o sync.Once + socketDir string + addrTranslator runner.AddrTranslator + sync.Mutex } @@ -275,12 +279,15 @@ type gRPCBrokerPending struct { doneCh chan struct{} } -func newGRPCBroker(s streamer, tls *tls.Config) *GRPCBroker { +func newGRPCBroker(s streamer, tls *tls.Config, socketDir string, addrTranslator runner.AddrTranslator) *GRPCBroker { return &GRPCBroker{ streamer: s, streams: make(map[uint32]*gRPCBrokerPending), tls: tls, doneCh: make(chan struct{}), + + socketDir: socketDir, + addrTranslator: addrTranslator, } } @@ -288,15 +295,23 @@ func newGRPCBroker(s streamer, tls *tls.Config) *GRPCBroker { // // This should not be called multiple times with the same ID at one time. func (b *GRPCBroker) Accept(id uint32) (net.Listener, error) { - listener, err := serverListener() + listener, err := serverListener(b.socketDir) if err != nil { return nil, err } + advertiseNet := listener.Addr().Network() + advertiseAddr := listener.Addr().String() + if b.addrTranslator != nil { + advertiseNet, advertiseAddr, err = b.addrTranslator.HostToPlugin(advertiseNet, advertiseAddr) + if err != nil { + return nil, err + } + } err = b.streamer.Send(&plugin.ConnInfo{ ServiceId: id, - Network: listener.Addr().Network(), - Address: listener.Addr().String(), + Network: advertiseNet, + Address: advertiseAddr, }) if err != nil { return nil, err @@ -379,12 +394,20 @@ func (b *GRPCBroker) Dial(id uint32) (conn *grpc.ClientConn, err error) { return nil, fmt.Errorf("timeout waiting for connection info") } + network, address := c.Network, c.Address + if b.addrTranslator != nil { + network, address, err = b.addrTranslator.PluginToHost(network, address) + if err != nil { + return nil, err + } + } + var addr net.Addr - switch c.Network { + switch network { case "tcp": - addr, err = net.ResolveTCPAddr("tcp", c.Address) + addr, err = net.ResolveTCPAddr("tcp", address) case "unix": - addr, err = net.ResolveUnixAddr("unix", c.Address) + addr, err = net.ResolveUnixAddr("unix", address) default: err = fmt.Errorf("Unknown address type: %s", c.Address) } diff --git a/grpc_client.go b/grpc_client.go index 6454d426..f11dd0da 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -63,7 +63,7 @@ func newGRPCClient(doneCtx context.Context, c *Client) (*GRPCClient, error) { // Start the broker. brokerGRPCClient := newGRPCBrokerClient(conn) - broker := newGRPCBroker(brokerGRPCClient, c.config.TLSConfig) + broker := newGRPCBroker(brokerGRPCClient, c.config.TLSConfig, c.hostSocketDir, c.runner) go broker.Run() go brokerGRPCClient.StartStream() diff --git a/grpc_server.go b/grpc_server.go index 7203a2cf..303d650a 100644 --- a/grpc_server.go +++ b/grpc_server.go @@ -84,7 +84,7 @@ func (s *GRPCServer) Init() error { // Register the broker service brokerServer := newGRPCBrokerServer() plugin.RegisterGRPCBrokerServer(s.server, brokerServer) - s.broker = newGRPCBroker(brokerServer, s.TLS) + s.broker = newGRPCBroker(brokerServer, s.TLS, "", nil) go s.broker.Run() // Register the controller diff --git a/internal/cmdrunner/addr_translator.go b/internal/cmdrunner/addr_translator.go new file mode 100644 index 00000000..1854d2dd --- /dev/null +++ b/internal/cmdrunner/addr_translator.go @@ -0,0 +1,16 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package cmdrunner + +// addrTranslator implements stateless identity functions, as the host and plugin +// run in the same context wrt Unix and network addresses. +type addrTranslator struct{} + +func (*addrTranslator) PluginToHost(pluginNet, pluginAddr string) (string, string, error) { + return pluginNet, pluginAddr, nil +} + +func (*addrTranslator) HostToPlugin(hostNet, hostAddr string) (string, string, error) { + return hostNet, hostAddr, nil +} diff --git a/internal/cmdrunner/cmd_reattach.go b/internal/cmdrunner/cmd_reattach.go new file mode 100644 index 00000000..74456f0b --- /dev/null +++ b/internal/cmdrunner/cmd_reattach.go @@ -0,0 +1,62 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package cmdrunner + +import ( + "fmt" + "net" + "os" + + "github.com/hashicorp/go-plugin/runner" +) + +// ReattachFunc returns a function that allows reattaching to a plugin running +// as a plain process. The process may or may not be a child process. +func ReattachFunc(pid int, addr net.Addr) runner.ReattachFunc { + return func() (runner.AttachedRunner, error) { + p, err := os.FindProcess(pid) + if err != nil { + // On Unix systems, FindProcess never returns an error. + // On Windows, for non-existent pids it returns: + // os.SyscallError - 'OpenProcess: the paremter is incorrect' + return nil, ErrProcessNotFound + } + + // Attempt to connect to the addr since on Unix systems FindProcess + // doesn't actually return an error if it can't find the process. + conn, err := net.Dial(addr.Network(), addr.String()) + if err != nil { + p.Kill() + return nil, ErrProcessNotFound + } + conn.Close() + + return &CmdAttachedRunner{ + pid: pid, + process: p, + }, nil + } +} + +// CmdAttachedRunner is mostly a subset of CmdRunner, except the Wait function +// does not assume the process is a child of the host process, and so uses a +// different implementation to wait on the process. +type CmdAttachedRunner struct { + pid int + process *os.Process + + addrTranslator +} + +func (c *CmdAttachedRunner) Wait() error { + return pidWait(c.pid) +} + +func (c *CmdAttachedRunner) Kill() error { + return c.process.Kill() +} + +func (c *CmdAttachedRunner) ID() string { + return fmt.Sprintf("%d", c.pid) +} diff --git a/internal/cmdrunner/cmd_runner.go b/internal/cmdrunner/cmd_runner.go new file mode 100644 index 00000000..722e44d0 --- /dev/null +++ b/internal/cmdrunner/cmd_runner.go @@ -0,0 +1,107 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package cmdrunner + +import ( + "errors" + "fmt" + "io" + "os" + "os/exec" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-plugin/runner" +) + +var ( + _ runner.Runner = (*CmdRunner)(nil) + + // ErrProcessNotFound is returned when a client is instantiated to + // reattach to an existing process and it isn't found. + ErrProcessNotFound = errors.New("Reattachment process not found") +) + +// CmdRunner implements the Executor interface. It mostly just passes through +// to exec.Cmd methods. +type CmdRunner struct { + logger hclog.Logger + cmd *exec.Cmd + + stdout io.ReadCloser + stderr io.ReadCloser + + // Cmd info is persisted early, since the process information will be removed + // after Kill is called. + path string + pid int + + addrTranslator +} + +// NewCmdRunner returns an implementation of runner.Runner for running a plugin +// as a subprocess. It must be passed a cmd that hasn't yet been started. +func NewCmdRunner(logger hclog.Logger, cmd *exec.Cmd) (*CmdRunner, error) { + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + + stderr, err := cmd.StderrPipe() + if err != nil { + return nil, err + } + + return &CmdRunner{ + logger: logger, + cmd: cmd, + stdout: stdout, + stderr: stderr, + path: cmd.Path, + }, nil +} + +func (c *CmdRunner) Start() error { + c.logger.Debug("starting plugin", "path", c.cmd.Path, "args", c.cmd.Args) + err := c.cmd.Start() + if err != nil { + return err + } + + c.pid = c.cmd.Process.Pid + c.logger.Debug("plugin started", "path", c.path, "pid", c.pid) + return nil +} + +func (c *CmdRunner) Wait() error { + return c.cmd.Wait() +} + +func (c *CmdRunner) Kill() error { + if c.cmd.Process != nil { + err := c.cmd.Process.Kill() + // Swallow ErrProcessDone, we support calling Kill multiple times. + if !errors.Is(err, os.ErrProcessDone) { + return err + } + return nil + } + + return nil +} + +func (c *CmdRunner) Stdout() io.ReadCloser { + return c.stdout +} + +func (c *CmdRunner) Stderr() io.ReadCloser { + return c.stderr +} + +func (c *CmdRunner) Name() string { + return c.path +} + +func (c *CmdRunner) ID() string { + return fmt.Sprintf("%d", c.pid) +} diff --git a/internal/cmdrunner/process.go b/internal/cmdrunner/process.go new file mode 100644 index 00000000..6c34dc77 --- /dev/null +++ b/internal/cmdrunner/process.go @@ -0,0 +1,25 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package cmdrunner + +import "time" + +// pidAlive checks whether a pid is alive. +func pidAlive(pid int) bool { + return _pidAlive(pid) +} + +// pidWait blocks for a process to exit. +func pidWait(pid int) error { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for range ticker.C { + if !pidAlive(pid) { + break + } + } + + return nil +} diff --git a/process_posix.go b/internal/cmdrunner/process_posix.go similarity index 95% rename from process_posix.go rename to internal/cmdrunner/process_posix.go index b73a3607..bf3fc5b6 100644 --- a/process_posix.go +++ b/internal/cmdrunner/process_posix.go @@ -4,7 +4,7 @@ //go:build !windows // +build !windows -package plugin +package cmdrunner import ( "os" diff --git a/process_windows.go b/internal/cmdrunner/process_windows.go similarity index 97% rename from process_windows.go rename to internal/cmdrunner/process_windows.go index ffa9b9e0..6c39df28 100644 --- a/process_windows.go +++ b/internal/cmdrunner/process_windows.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 -package plugin +package cmdrunner import ( "syscall" diff --git a/plugin_test.go b/plugin_test.go index d7cc2739..45fdf23c 100644 --- a/plugin_test.go +++ b/plugin_test.go @@ -665,6 +665,20 @@ func TestHelperProcess(*testing.T) { // Shouldn't reach here but make sure we exit anyways os.Exit(0) + case "test-skip-host-env-true": + fmt.Printf("%d|%d|tcp|:1234\n", CoreProtocolVersion, testHandshake.ProtocolVersion) + if os.Getenv("PLUGIN_TEST_SKIP_HOST_ENV") == "" { + os.Exit(0) + } + + os.Exit(1) + case "test-skip-host-env-false": + fmt.Printf("%d|%d|tcp|:1234\n", CoreProtocolVersion, testHandshake.ProtocolVersion) + if os.Getenv("PLUGIN_TEST_SKIP_HOST_ENV") != "" { + os.Exit(0) + } + + os.Exit(1) default: fmt.Fprintf(os.Stderr, "Unknown command: %q\n", cmd) os.Exit(2) diff --git a/process.go b/process.go index 68b028c6..b8844636 100644 --- a/process.go +++ b/process.go @@ -2,26 +2,3 @@ // SPDX-License-Identifier: MPL-2.0 package plugin - -import ( - "time" -) - -// pidAlive checks whether a pid is alive. -func pidAlive(pid int) bool { - return _pidAlive(pid) -} - -// pidWait blocks for a process to exit. -func pidWait(pid int) error { - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - - for range ticker.C { - if !pidAlive(pid) { - break - } - } - - return nil -} diff --git a/runner/runner.go b/runner/runner.go new file mode 100644 index 00000000..86766f16 --- /dev/null +++ b/runner/runner.go @@ -0,0 +1,64 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package runner + +import ( + "io" +) + +// Runner defines the interface required by go-plugin to manage the lifecycle of +// of a plugin and attempt to negotiate a connection with it. Note that this +// is orthogonal to the protocol and transport used, which is negotiated over stdout. +type Runner interface { + // Start should start the plugin and ensure any context required for servicing + // other interface methods is set up. + Start() error + + // Stdout is used to negotiate the go-plugin protocol. + Stdout() io.ReadCloser + + // Stderr is used for forwarding plugin logs to the host process logger. + Stderr() io.ReadCloser + + // Name is a human-friendly name for the plugin, such as the path to the + // executable. It does not have to be unique. + Name() string + + AttachedRunner +} + +// AttachedRunner defines a limited subset of Runner's interface to represent the +// reduced responsibility for plugin lifecycle when attaching to an already running +// plugin. +type AttachedRunner interface { + // Wait should wait until the plugin stops running, whether in response to + // an out of band signal or in response to calling Kill(). + Wait() error + + // Kill should stop the plugin and perform any cleanup required. + Kill() error + + // ID is a unique identifier to represent the running plugin. e.g. pid or + // container ID. + ID() string + + AddrTranslator +} + +// AddrTranslator translates addresses between the execution context of the host +// process and the plugin. For example, if the plugin is in a container, the file +// path for a Unix socket may be different between the host and the container. +// +// It is only intended to be used by the host process. +type AddrTranslator interface { + // Called before connecting on any addresses received back from the plugin. + PluginToHost(pluginNet, pluginAddr string) (hostNet string, hostAddr string, err error) + + // Called on any host process addresses before they are sent to the plugin. + HostToPlugin(hostNet, hostAddr string) (pluginNet string, pluginAddr string, err error) +} + +// ReattachFunc can be passed to a client's reattach config to reattach to an +// already running plugin instead of starting it ourselves. +type ReattachFunc func() (AttachedRunner, error) diff --git a/server.go b/server.go index 3f4a017d..4e9a22c0 100644 --- a/server.go +++ b/server.go @@ -11,10 +11,10 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "os" "os/signal" + "os/user" "runtime" "sort" "strconv" @@ -273,7 +273,7 @@ func Serve(opts *ServeConfig) { } // Register a listener so we can accept a connection - listener, err := serverListener() + listener, err := serverListener(os.Getenv(EnvUnixSocketDir)) if err != nil { logger.Error("plugin init error", "error", err) return @@ -496,12 +496,12 @@ func Serve(opts *ServeConfig) { } } -func serverListener() (net.Listener, error) { +func serverListener(dir string) (net.Listener, error) { if runtime.GOOS == "windows" { return serverListener_tcp() } - return serverListener_unix() + return serverListener_unix(dir) } func serverListener_tcp() (net.Listener, error) { @@ -546,8 +546,8 @@ func serverListener_tcp() (net.Listener, error) { return nil, errors.New("Couldn't bind plugin TCP listener") } -func serverListener_unix() (net.Listener, error) { - tf, err := ioutil.TempFile("", "plugin") +func serverListener_unix(dir string) (net.Listener, error) { + tf, err := os.CreateTemp(dir, "plugin") if err != nil { return nil, err } @@ -567,6 +567,32 @@ func serverListener_unix() (net.Listener, error) { return nil, err } + // By default, unix sockets are only writable by the owner. Set up a custom + // group owner and group write permissions if configured. + if groupString := os.Getenv(EnvUnixSocketGroup); groupString != "" { + groupID, err := strconv.Atoi(groupString) + if err != nil { + group, err := user.LookupGroup(groupString) + if err != nil { + return nil, fmt.Errorf("failed to find group ID from %s=%s environment variable: %w", EnvUnixSocketGroup, groupString, err) + } + groupID, err = strconv.Atoi(group.Gid) + if err != nil { + return nil, fmt.Errorf("failed to parse %q group's Gid as an integer: %w", groupString, err) + } + } + + err = os.Chown(path, -1, groupID) + if err != nil { + return nil, err + } + + err = os.Chmod(path, 0o660) + if err != nil { + return nil, err + } + } + // Wrap the listener in rmListener so that the Unix domain socket file // is removed on close. return &rmListener{ diff --git a/server_test.go b/server_test.go index d446dfdf..68161ecc 100644 --- a/server_test.go +++ b/server_test.go @@ -6,10 +6,11 @@ package plugin import ( "bytes" "context" - "io/ioutil" "log" "net" "os" + "path" + "runtime" "strings" "testing" "time" @@ -195,7 +196,7 @@ func TestRmListener(t *testing.T) { t.Fatalf("err: %s", err) } - tf, err := ioutil.TempFile("", "plugin") + tf, err := os.CreateTemp("", "plugin") if err != nil { t.Fatalf("err: %s", err) } @@ -308,3 +309,50 @@ func TestServer_testStdLogger(t *testing.T) { t.Fatalf("expected: %q\ngot: %q", "test log", logOut.String()) } } + +func TestUnixSocketDir(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("go-plugin doesn't support unix sockets on Windows") + } + + tmpDir := t.TempDir() + t.Setenv(EnvUnixSocketDir, tmpDir) + + closeCh := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // make a server, but we don't need to attach to it + ch := make(chan *ReattachConfig, 1) + go Serve(&ServeConfig{ + HandshakeConfig: testHandshake, + Plugins: testGRPCPluginMap, + GRPCServer: DefaultGRPCServer, + Logger: hclog.NewNullLogger(), + Test: &ServeTestConfig{ + Context: ctx, + CloseCh: closeCh, + ReattachConfigCh: ch, + }, + }) + + // Wait for the server + var cfg *ReattachConfig + select { + case cfg = <-ch: + if cfg == nil { + t.Fatal("attach config should not be nil") + } + case <-time.After(2000 * time.Millisecond): + t.Fatal("should've received reattach") + } + + actualDir := path.Clean(path.Dir(cfg.Addr.String())) + expectedDir := path.Clean(tmpDir) + if actualDir != expectedDir { + t.Fatalf("Expected socket in dir: %s, but was in %s", expectedDir, actualDir) + } + + cancel() + <-closeCh +} diff --git a/server_unix_test.go b/server_unix_test.go new file mode 100644 index 00000000..1de10a1f --- /dev/null +++ b/server_unix_test.go @@ -0,0 +1,58 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +//go:build !windows +// +build !windows + +package plugin + +import ( + "fmt" + "os" + "os/user" + "runtime" + "syscall" + "testing" +) + +func TestUnixSocketGroupPermissions(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("go-plugin doesn't support unix sockets on Windows") + } + + group, err := user.LookupGroupId(fmt.Sprintf("%d", os.Getgid())) + if err != nil { + t.Fatal(err) + } + for name, tc := range map[string]struct { + gid string + }{ + "as integer": {fmt.Sprintf("%d", os.Getgid())}, + "as name": {group.Name}, + } { + t.Run(name, func(t *testing.T) { + t.Setenv(EnvUnixSocketGroup, tc.gid) + + ln, err := serverListener_unix("") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + info, err := os.Lstat(ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + if info.Mode()&os.ModePerm != 0o660 { + t.Fatal(info.Mode()) + } + stat, ok := info.Sys().(*syscall.Stat_t) + if !ok { + t.Fatal() + } + if stat.Gid != uint32(os.Getgid()) { + t.Fatalf("Expected %d, but got %d", os.Getgid(), stat.Gid) + } + }) + } +} diff --git a/testing.go b/testing.go index ffe6fa46..27e05f01 100644 --- a/testing.go +++ b/testing.go @@ -166,7 +166,7 @@ func TestPluginGRPCConn(t testing.T, ps map[string]Plugin) (*GRPCClient, *GRPCSe } brokerGRPCClient := newGRPCBrokerClient(conn) - broker := newGRPCBroker(brokerGRPCClient, nil) + broker := newGRPCBroker(brokerGRPCClient, nil, "", nil) go broker.Run() go brokerGRPCClient.StartStream()