Skip to content

Commit

Permalink
test(cli): increase agent install test coverage (#276)
Browse files Browse the repository at this point in the history
Increasing the test coverage of our `lwrunner` package.

Signed-off-by: Salim Afiune Maya <[email protected]>
  • Loading branch information
afiune committed Jan 4, 2021
1 parent 296be65 commit da5b4ae
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 116 deletions.
18 changes: 2 additions & 16 deletions cli/cmd/agent_install.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@ import (
"io/ioutil"
"net"
"net/http"
"path"
"strings"

"github.com/AlecAivazis/survey/v2"
homedir "github.com/mitchellh/go-homedir"
"github.com/olekukonko/tablewriter"
"github.com/pkg/errors"
"github.com/spf13/cobra"
Expand All @@ -51,10 +49,7 @@ func installRemoteAgent(_ *cobra.Command, args []string) error {
}

cli.Log.Debugw("creating runner", "user", user, "host", host)
runner, err := lwrunner.New(user, host, verifyHostCallback)
if err != nil {
return errors.Wrap(err, "unable to initialize lwrunner")
}
runner := lwrunner.New(user, host, verifyHostCallback)

if runner.User == "" {
cli.Log.Debugw("ssh username not set")
Expand Down Expand Up @@ -85,7 +80,7 @@ func installRemoteAgent(_ *cobra.Command, args []string) error {
// if no authentication was set
if !authSet {
// try to use the default identity file
identityFile, err := defaultIdentityFile()
identityFile, err := lwrunner.DefaultIdentityFilePath()
if err != nil {
return err
}
Expand Down Expand Up @@ -287,15 +282,6 @@ func askForUsername() (string, error) {
return user, nil
}

func defaultIdentityFile() (string, error) {
home, err := homedir.Dir()
if err != nil {
return "", err
}

return path.Join(home, ".ssh", "id_rsa"), nil
}

func verifyHostCallback(host string, remote net.Addr, key ssh.PublicKey) error {
// error if key does not exist inside the default known_hosts file,
// or if host in known_hosts file but key changed!
Expand Down
8 changes: 0 additions & 8 deletions cli/cmd/agent_install_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,6 @@ func TestLatestAgentVersionSHA(t *testing.T) {
}
}

func TestDefaultIdentityFile(t *testing.T) {
subject, err := defaultIdentityFile()
if assert.Nil(t, err) {
assert.Contains(t, subject, ".ssh")
assert.Contains(t, subject, "id_rsa")
}
}

func TestFormatRunnerError(t *testing.T) {
cases := []struct {
expected error
Expand Down
58 changes: 22 additions & 36 deletions lwrunner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,25 @@ type Runner struct {
*ssh.ClientConfig
}

func New(user, host string, callback ssh.HostKeyCallback) (*Runner, error) {
var err error
if callback == nil {
callback, err = DefaultKnownHosts()
if err != nil {
return nil, err
}
}

func New(user, host string, callback ssh.HostKeyCallback) *Runner {
if os.Getenv("LW_SSH_USER") != "" {
user = os.Getenv("LW_SSH_USER")
}

defaultCallback, err := DefaultKnownHosts()
if err == nil && callback == nil {
callback = defaultCallback
}

return &Runner{
host,
22,
&ssh.ClientConfig{
User: user,
Auth: []ssh.AuthMethod{defaultAuthMethod()},
Auth: []ssh.AuthMethod{},
HostKeyCallback: callback,
},
}, nil
}
}

func (run Runner) UseIdentityFile(file string) error {
Expand Down Expand Up @@ -116,7 +113,7 @@ func DefaultKnownHostsPath() (string, error) {
return "", err
}

return path.Join(home, ".ssh", "known_hosts"), err
return path.Join(home, ".ssh", "known_hosts"), nil
}

// AddKnownHost adds a host to the provided known hosts file, if no known hosts
Expand Down Expand Up @@ -194,6 +191,19 @@ func CheckKnownHost(host string, remote net.Addr, key ssh.PublicKey, knownFile s
return false, nil
}

func DefaultIdentityFilePath() (string, error) {
if os.Getenv("LW_SSH_IDENTITY_FILE") != "" {
return os.Getenv("LW_SSH_IDENTITY_FILE"), nil
}

home, err := homedir.Dir()
if err != nil {
return "", err
}

return path.Join(home, ".ssh", "id_rsa"), nil
}

func newSignerFromFile(keyname string) (ssh.Signer, error) {
fp, err := os.Open(keyname)
if err != nil {
Expand All @@ -208,27 +218,3 @@ func newSignerFromFile(keyname string) (ssh.Signer, error) {

return ssh.ParsePrivateKey(buf)
}

func defaultAuthMethod() ssh.AuthMethod {
var (
signers = []ssh.Signer{}
keys = []string{}
)
home, err := homedir.Dir()
if err == nil {
keys = append(keys, path.Join(home, ".ssh", "id_rsa"))
}

if os.Getenv("LW_SSH_IDENTITY_FILE") != "" {
keys = append(keys, os.Getenv("LW_SSH_IDENTITY_FILE"))
}

for _, keyname := range keys {
signer, err := newSignerFromFile(keyname)
if err == nil {
signers = append(signers, signer)
}
}

return ssh.PublicKeys(signers...)
}
145 changes: 89 additions & 56 deletions lwrunner/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,51 +32,24 @@ import (
)

func TestLwRunnerNew(t *testing.T) {
// we use the default know host file inside the HOME directory
// of the current user, that is why we need to mock it
mockHome, err := ioutil.TempDir("", "lwrunner")
if err != nil {
panic(err)
}
defer os.RemoveAll(mockHome)

err = os.Mkdir(path.Join(mockHome, ".ssh"), 0755)
if err != nil {
panic(err)
}

err = ioutil.WriteFile(path.Join(mockHome, ".ssh", "known_hosts"), []byte(""), 0600)
if err != nil {
panic(err)
}
homeCache := os.Getenv("HOME")
os.Setenv("HOME", mockHome)
defer os.Setenv("HOME", homeCache)

subject, err := lwrunner.New("root", "192.1.1.2", nil)
if assert.Nil(t, err) {
assert.Equal(t, 22, subject.Port)
assert.Equal(t, "root", subject.User)
assert.Equal(t, "192.1.1.2", subject.Hostname)
}
subject := lwrunner.New("root", "192.1.1.2", nil)
assert.Equal(t, 22, subject.Port)
assert.Equal(t, "root", subject.User)
assert.Equal(t, "192.1.1.2", subject.Hostname)
}

func TestLwRunnerNewIgnoreHostKey(t *testing.T) {
subject, err := lwrunner.New("ubuntu", "my-test-host", ssh.InsecureIgnoreHostKey())
if assert.Nil(t, err) {
assert.Equal(t, 22, subject.Port)
assert.Equal(t, "ubuntu", subject.User)
assert.Equal(t, "my-test-host", subject.Hostname)
}
subject := lwrunner.New("ubuntu", "my-test-host", ssh.InsecureIgnoreHostKey())
assert.Equal(t, 22, subject.Port)
assert.Equal(t, "ubuntu", subject.User)
assert.Equal(t, "my-test-host", subject.Hostname)
}

func TestLwRunnerNewCustomCallback(t *testing.T) {
subject, err := lwrunner.New("ec2-user", "host.example.com", customHostCallback)
if assert.Nil(t, err) {
assert.Equal(t, 22, subject.Port)
assert.Equal(t, subject.User, "ec2-user")
assert.Equal(t, subject.Hostname, "host.example.com")
}
subject := lwrunner.New("ec2-user", "host.example.com", customHostCallback)
assert.Equal(t, 22, subject.Port)
assert.Equal(t, subject.User, "ec2-user")
assert.Equal(t, subject.Hostname, "host.example.com")
}

// test function to mock host callback
Expand All @@ -88,33 +61,26 @@ func TestLwRunnerNewUserEnvVariable(t *testing.T) {
os.Setenv("LW_SSH_USER", "root")
defer os.Setenv("LW_SSH_USER", "")

subject, err := lwrunner.New("ubuntu", "a-test-host", ssh.InsecureIgnoreHostKey())
if assert.Nil(t, err) {
assert.Equal(t, subject.User, "root")
assert.Equal(t, subject.Hostname, "a-test-host")
}
subject := lwrunner.New("ubuntu", "a-test-host", ssh.InsecureIgnoreHostKey())
assert.Equal(t, subject.User, "root")
assert.Equal(t, subject.Hostname, "a-test-host")
}

func TestLwRunnerUsePassword(t *testing.T) {
subject, err := lwrunner.New("ec2-user", "host.example.com", customHostCallback)
if assert.Nil(t, err) {
assert.Equal(t, "ec2-user", subject.User)
assert.Equal(t, "host.example.com:22", subject.Address())
}
subject := lwrunner.New("ec2-user", "host.example.com", customHostCallback)
assert.Equal(t, "ec2-user", subject.User)
assert.Equal(t, "host.example.com:22", subject.Address())

subject.UsePassword("secret123")

assert.Equal(t, 1, len(subject.Auth))
}

func TestLwRunnerUseIdentityFile(t *testing.T) {
subject, err := lwrunner.New("ec2-user", "host.example.com", customHostCallback)
if assert.Nil(t, err) {
assert.Equal(t, "ec2-user", subject.User)
assert.Equal(t, "host.example.com:22", subject.Address())
}
subject := lwrunner.New("ec2-user", "host.example.com", customHostCallback)
assert.Equal(t, "ec2-user", subject.User)
assert.Equal(t, "host.example.com:22", subject.Address())

err = subject.UseIdentityFile("file-not-found")
err := subject.UseIdentityFile("file-not-found")
assert.NotNil(t, err)
}

Expand All @@ -124,3 +90,70 @@ func TestLwRunnerDefaultKnownHostsPath(t *testing.T) {
assert.Contains(t, subject, ".ssh/known_hosts")
}
}

func TestDefaultIdentityFilePath(t *testing.T) {
subject, err := lwrunner.DefaultIdentityFilePath()
if assert.Nil(t, err) {
assert.Contains(t, subject, ".ssh")
assert.Contains(t, subject, "id_rsa")
}
}

func TestDefaultIdentityFilePathEnvVariable(t *testing.T) {
expected := "/pat/to/key"
os.Setenv("LW_SSH_IDENTITY_FILE", expected)
defer os.Setenv("LW_SSH_IDENTITY_FILE", "")

subject, err := lwrunner.DefaultIdentityFilePath()
if assert.Nil(t, err) {
assert.Equal(t, subject, expected)
}
}

func TestLwRunnerDefaultKnownHosts(t *testing.T) {
mockHome, err := ioutil.TempDir("", "lwrunner")
if err != nil {
panic(err)
}
defer os.RemoveAll(mockHome)

err = os.Mkdir(path.Join(mockHome, ".ssh"), 0755)
if err != nil {
panic(err)
}

err = ioutil.WriteFile(path.Join(mockHome, ".ssh", "known_hosts"), []byte(""), 0600)
if err != nil {
panic(err)
}
homeCache := os.Getenv("HOME")
os.Setenv("HOME", mockHome)
defer os.Setenv("HOME", homeCache)

subject, err := lwrunner.DefaultKnownHosts()
assert.NotNil(t, subject)
if assert.Nil(t, err) {
assert.NotNil(t, subject("mock.hostname.example.com:22", mockAddr{}, mockPublicKey{}))
}
}

type mockAddr struct{}

func (m mockAddr) Network() string {
return "tcp"
}
func (m mockAddr) String() string {
return "mock.hostname.example.com:22"
}

type mockPublicKey struct{}

func (m mockPublicKey) Type() string {
return "ssh-rsa"
}
func (m mockPublicKey) Marshal() []byte {
return []byte{}
}
func (m mockPublicKey) Verify(_ []byte, _ *ssh.Signature) error {
return nil
}

0 comments on commit da5b4ae

Please sign in to comment.