Skip to content

Commit

Permalink
fix(lwrunner): create ~/.ssh directory if not exist (#933)
Browse files Browse the repository at this point in the history
**Summary**

This change fixes the issue where we try to add a known host to the
known_hosts file but the directory ~/.ssh does not exist.

Now we are creating that directory first.

**How did you test this change?**

Replicated the issue inside our unit tests:
```
    === Failed
    === FAIL: lwrunner TestLwRunnerAddKnownHost (0.19s)
        runner_test.go:132:
                    Error Trace:    /Users/afiune/github/go-sdk/lwrunner/runner_test.go:132
                    Error:          Received unexpected error:
                                    open /var/folders/pn/rtky3yx17fx60dc03njhm_fw0000gn/T/lwrunner1052124622/.ssh/known_hosts: no such file or directory
                    Test:           TestLwRunnerAddKnownHost
```


**Issue**

Closes #915
Jira https://lacework.atlassian.net/browse/ALLY-1195

Signed-off-by: Salim Afiune Maya <[email protected]>
  • Loading branch information
afiune committed Oct 7, 2022
1 parent 56cbdcd commit bfc9099
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 0 deletions.
7 changes: 7 additions & 0 deletions lwrunner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"os"
"path"

"github.com/lacework/go-sdk/internal/file"
homedir "github.com/mitchellh/go-homedir"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
Expand Down Expand Up @@ -128,6 +129,12 @@ func AddKnownHost(host string, remote net.Addr, key ssh.PublicKey, knownFile str
knownFile = path
}

if !file.FileExists(knownFile) {
if err := os.MkdirAll(path.Dir(knownFile), 0700); err != nil {
return err
}
}

f, err := os.OpenFile(knownFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err != nil {
return err
Expand Down
79 changes: 79 additions & 0 deletions lwrunner/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
package lwrunner_test

import (
"crypto/rand"
"crypto/rsa"
"io/ioutil"
"net"
"os"
Expand Down Expand Up @@ -111,6 +113,83 @@ func TestDefaultIdentityFilePathEnvVariable(t *testing.T) {
}
}

func TestLwRunnerAddKnownHostNoSSHDir(t *testing.T) {
mockHome, err := ioutil.TempDir("", "lwrunner")
if err != nil {
panic(err)
}
defer os.RemoveAll(mockHome)

knownFile := path.Join(mockHome, ".ssh", "known_hosts")
netAddr := mockNetAddr{}
// generate test RSA keypair in SSH format
priv, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
rsaPub := priv.PublicKey
sshPub, err := ssh.NewPublicKey(&rsaPub)
assert.NoError(t, err)

// Add known host to mocked home directory
subject := lwrunner.AddKnownHost("mock-test", netAddr, sshPub, knownFile)
assert.NoError(t, subject)

// Check the known host file
content, err := ioutil.ReadFile(knownFile)
assert.NoError(t, err)
assert.Contains(t, string(content), "mock-test")

// Try again, it should work
// Add known host to mocked home directory
subject = lwrunner.AddKnownHost("second-time", netAddr, sshPub, knownFile)
assert.NoError(t, subject)

// Check the known host file
content, err = ioutil.ReadFile(knownFile)
assert.NoError(t, err)
assert.Contains(t, string(content), "second-time")
}

func TestLwRunnerAddKnownWithSSHDir(t *testing.T) {
mockHome, err := ioutil.TempDir("", "lwrunner")
if err != nil {
panic(err)
}
defer os.RemoveAll(mockHome)

// Mock that the ~/.ssh dir exists
err = os.Mkdir(path.Join(mockHome, ".ssh"), 0700)
if err != nil {
panic(err)
}

knownFile := path.Join(mockHome, ".ssh", "known_hosts")
netAddr := mockNetAddr{}
// generate test RSA keypair in SSH format
priv, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
rsaPub := priv.PublicKey
sshPub, err := ssh.NewPublicKey(&rsaPub)
assert.NoError(t, err)

// Add known host to mocked home directory
subject := lwrunner.AddKnownHost("mock-test", netAddr, sshPub, knownFile)
assert.NoError(t, subject)

// Check the known host file
content, err := ioutil.ReadFile(knownFile)
assert.NoError(t, err)
assert.Contains(t, string(content), "mock-test")
}

type mockNetAddr struct{}

func (m mockNetAddr) Network() string {
return "tcp"
}
func (m mockNetAddr) String() string {
return "1.1.1.1"
}

func TestLwRunnerDefaultKnownHosts(t *testing.T) {
mockHome, err := ioutil.TempDir("", "lwrunner")
if err != nil {
Expand Down

0 comments on commit bfc9099

Please sign in to comment.