diff --git a/lwrunner/runner.go b/lwrunner/runner.go index 231280dbb..1716e6a7e 100644 --- a/lwrunner/runner.go +++ b/lwrunner/runner.go @@ -28,6 +28,7 @@ import ( "os" "path" + "github.com/lacework/go-sdk/internal/file" homedir "github.com/mitchellh/go-homedir" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/knownhosts" @@ -128,6 +129,12 @@ func AddKnownHost(host string, remote net.Addr, key ssh.PublicKey, knownFile str knownFile = path } + if !file.FileExists(knownFile) { + if err := os.MkdirAll(path.Dir(knownFile), 0700); err != nil { + return err + } + } + f, err := os.OpenFile(knownFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) if err != nil { return err diff --git a/lwrunner/runner_test.go b/lwrunner/runner_test.go index 288bade26..cdebcbde3 100644 --- a/lwrunner/runner_test.go +++ b/lwrunner/runner_test.go @@ -19,6 +19,8 @@ package lwrunner_test import ( + "crypto/rand" + "crypto/rsa" "io/ioutil" "net" "os" @@ -111,6 +113,83 @@ func TestDefaultIdentityFilePathEnvVariable(t *testing.T) { } } +func TestLwRunnerAddKnownHostNoSSHDir(t *testing.T) { + mockHome, err := ioutil.TempDir("", "lwrunner") + if err != nil { + panic(err) + } + defer os.RemoveAll(mockHome) + + knownFile := path.Join(mockHome, ".ssh", "known_hosts") + netAddr := mockNetAddr{} + // generate test RSA keypair in SSH format + priv, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + rsaPub := priv.PublicKey + sshPub, err := ssh.NewPublicKey(&rsaPub) + assert.NoError(t, err) + + // Add known host to mocked home directory + subject := lwrunner.AddKnownHost("mock-test", netAddr, sshPub, knownFile) + assert.NoError(t, subject) + + // Check the known host file + content, err := ioutil.ReadFile(knownFile) + assert.NoError(t, err) + assert.Contains(t, string(content), "mock-test") + + // Try again, it should work + // Add known host to mocked home directory + subject = lwrunner.AddKnownHost("second-time", netAddr, sshPub, knownFile) + assert.NoError(t, subject) + + // Check the known host file + content, err = ioutil.ReadFile(knownFile) + assert.NoError(t, err) + assert.Contains(t, string(content), "second-time") +} + +func TestLwRunnerAddKnownWithSSHDir(t *testing.T) { + mockHome, err := ioutil.TempDir("", "lwrunner") + if err != nil { + panic(err) + } + defer os.RemoveAll(mockHome) + + // Mock that the ~/.ssh dir exists + err = os.Mkdir(path.Join(mockHome, ".ssh"), 0700) + if err != nil { + panic(err) + } + + knownFile := path.Join(mockHome, ".ssh", "known_hosts") + netAddr := mockNetAddr{} + // generate test RSA keypair in SSH format + priv, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + rsaPub := priv.PublicKey + sshPub, err := ssh.NewPublicKey(&rsaPub) + assert.NoError(t, err) + + // Add known host to mocked home directory + subject := lwrunner.AddKnownHost("mock-test", netAddr, sshPub, knownFile) + assert.NoError(t, subject) + + // Check the known host file + content, err := ioutil.ReadFile(knownFile) + assert.NoError(t, err) + assert.Contains(t, string(content), "mock-test") +} + +type mockNetAddr struct{} + +func (m mockNetAddr) Network() string { + return "tcp" +} +func (m mockNetAddr) String() string { + return "1.1.1.1" +} + func TestLwRunnerDefaultKnownHosts(t *testing.T) { mockHome, err := ioutil.TempDir("", "lwrunner") if err != nil {