diff --git a/cli/cmd/agent_install.go b/cli/cmd/agent_install.go index 925113393..c2c35164b 100644 --- a/cli/cmd/agent_install.go +++ b/cli/cmd/agent_install.go @@ -24,11 +24,9 @@ import ( "io/ioutil" "net" "net/http" - "path" "strings" "github.com/AlecAivazis/survey/v2" - homedir "github.com/mitchellh/go-homedir" "github.com/olekukonko/tablewriter" "github.com/pkg/errors" "github.com/spf13/cobra" @@ -51,10 +49,7 @@ func installRemoteAgent(_ *cobra.Command, args []string) error { } cli.Log.Debugw("creating runner", "user", user, "host", host) - runner, err := lwrunner.New(user, host, verifyHostCallback) - if err != nil { - return errors.Wrap(err, "unable to initialize lwrunner") - } + runner := lwrunner.New(user, host, verifyHostCallback) if runner.User == "" { cli.Log.Debugw("ssh username not set") @@ -85,7 +80,7 @@ func installRemoteAgent(_ *cobra.Command, args []string) error { // if no authentication was set if !authSet { // try to use the default identity file - identityFile, err := defaultIdentityFile() + identityFile, err := lwrunner.DefaultIdentityFilePath() if err != nil { return err } @@ -287,15 +282,6 @@ func askForUsername() (string, error) { return user, nil } -func defaultIdentityFile() (string, error) { - home, err := homedir.Dir() - if err != nil { - return "", err - } - - return path.Join(home, ".ssh", "id_rsa"), nil -} - func verifyHostCallback(host string, remote net.Addr, key ssh.PublicKey) error { // error if key does not exist inside the default known_hosts file, // or if host in known_hosts file but key changed! diff --git a/cli/cmd/agent_install_test.go b/cli/cmd/agent_install_test.go index f13515437..4147292b4 100644 --- a/cli/cmd/agent_install_test.go +++ b/cli/cmd/agent_install_test.go @@ -46,14 +46,6 @@ func TestLatestAgentVersionSHA(t *testing.T) { } } -func TestDefaultIdentityFile(t *testing.T) { - subject, err := defaultIdentityFile() - if assert.Nil(t, err) { - assert.Contains(t, subject, ".ssh") - assert.Contains(t, subject, "id_rsa") - } -} - func TestFormatRunnerError(t *testing.T) { cases := []struct { expected error diff --git a/lwrunner/runner.go b/lwrunner/runner.go index cad91da42..231280dbb 100644 --- a/lwrunner/runner.go +++ b/lwrunner/runner.go @@ -39,28 +39,25 @@ type Runner struct { *ssh.ClientConfig } -func New(user, host string, callback ssh.HostKeyCallback) (*Runner, error) { - var err error - if callback == nil { - callback, err = DefaultKnownHosts() - if err != nil { - return nil, err - } - } - +func New(user, host string, callback ssh.HostKeyCallback) *Runner { if os.Getenv("LW_SSH_USER") != "" { user = os.Getenv("LW_SSH_USER") } + defaultCallback, err := DefaultKnownHosts() + if err == nil && callback == nil { + callback = defaultCallback + } + return &Runner{ host, 22, &ssh.ClientConfig{ User: user, - Auth: []ssh.AuthMethod{defaultAuthMethod()}, + Auth: []ssh.AuthMethod{}, HostKeyCallback: callback, }, - }, nil + } } func (run Runner) UseIdentityFile(file string) error { @@ -116,7 +113,7 @@ func DefaultKnownHostsPath() (string, error) { return "", err } - return path.Join(home, ".ssh", "known_hosts"), err + return path.Join(home, ".ssh", "known_hosts"), nil } // AddKnownHost adds a host to the provided known hosts file, if no known hosts @@ -194,6 +191,19 @@ func CheckKnownHost(host string, remote net.Addr, key ssh.PublicKey, knownFile s return false, nil } +func DefaultIdentityFilePath() (string, error) { + if os.Getenv("LW_SSH_IDENTITY_FILE") != "" { + return os.Getenv("LW_SSH_IDENTITY_FILE"), nil + } + + home, err := homedir.Dir() + if err != nil { + return "", err + } + + return path.Join(home, ".ssh", "id_rsa"), nil +} + func newSignerFromFile(keyname string) (ssh.Signer, error) { fp, err := os.Open(keyname) if err != nil { @@ -208,27 +218,3 @@ func newSignerFromFile(keyname string) (ssh.Signer, error) { return ssh.ParsePrivateKey(buf) } - -func defaultAuthMethod() ssh.AuthMethod { - var ( - signers = []ssh.Signer{} - keys = []string{} - ) - home, err := homedir.Dir() - if err == nil { - keys = append(keys, path.Join(home, ".ssh", "id_rsa")) - } - - if os.Getenv("LW_SSH_IDENTITY_FILE") != "" { - keys = append(keys, os.Getenv("LW_SSH_IDENTITY_FILE")) - } - - for _, keyname := range keys { - signer, err := newSignerFromFile(keyname) - if err == nil { - signers = append(signers, signer) - } - } - - return ssh.PublicKeys(signers...) -} diff --git a/lwrunner/runner_test.go b/lwrunner/runner_test.go index 260a28bd9..c62316107 100644 --- a/lwrunner/runner_test.go +++ b/lwrunner/runner_test.go @@ -32,51 +32,24 @@ import ( ) func TestLwRunnerNew(t *testing.T) { - // we use the default know host file inside the HOME directory - // of the current user, that is why we need to mock it - mockHome, err := ioutil.TempDir("", "lwrunner") - if err != nil { - panic(err) - } - defer os.RemoveAll(mockHome) - - err = os.Mkdir(path.Join(mockHome, ".ssh"), 0755) - if err != nil { - panic(err) - } - - err = ioutil.WriteFile(path.Join(mockHome, ".ssh", "known_hosts"), []byte(""), 0600) - if err != nil { - panic(err) - } - homeCache := os.Getenv("HOME") - os.Setenv("HOME", mockHome) - defer os.Setenv("HOME", homeCache) - - subject, err := lwrunner.New("root", "192.1.1.2", nil) - if assert.Nil(t, err) { - assert.Equal(t, 22, subject.Port) - assert.Equal(t, "root", subject.User) - assert.Equal(t, "192.1.1.2", subject.Hostname) - } + subject := lwrunner.New("root", "192.1.1.2", nil) + assert.Equal(t, 22, subject.Port) + assert.Equal(t, "root", subject.User) + assert.Equal(t, "192.1.1.2", subject.Hostname) } func TestLwRunnerNewIgnoreHostKey(t *testing.T) { - subject, err := lwrunner.New("ubuntu", "my-test-host", ssh.InsecureIgnoreHostKey()) - if assert.Nil(t, err) { - assert.Equal(t, 22, subject.Port) - assert.Equal(t, "ubuntu", subject.User) - assert.Equal(t, "my-test-host", subject.Hostname) - } + subject := lwrunner.New("ubuntu", "my-test-host", ssh.InsecureIgnoreHostKey()) + assert.Equal(t, 22, subject.Port) + assert.Equal(t, "ubuntu", subject.User) + assert.Equal(t, "my-test-host", subject.Hostname) } func TestLwRunnerNewCustomCallback(t *testing.T) { - subject, err := lwrunner.New("ec2-user", "host.example.com", customHostCallback) - if assert.Nil(t, err) { - assert.Equal(t, 22, subject.Port) - assert.Equal(t, subject.User, "ec2-user") - assert.Equal(t, subject.Hostname, "host.example.com") - } + subject := lwrunner.New("ec2-user", "host.example.com", customHostCallback) + assert.Equal(t, 22, subject.Port) + assert.Equal(t, subject.User, "ec2-user") + assert.Equal(t, subject.Hostname, "host.example.com") } // test function to mock host callback @@ -88,33 +61,26 @@ func TestLwRunnerNewUserEnvVariable(t *testing.T) { os.Setenv("LW_SSH_USER", "root") defer os.Setenv("LW_SSH_USER", "") - subject, err := lwrunner.New("ubuntu", "a-test-host", ssh.InsecureIgnoreHostKey()) - if assert.Nil(t, err) { - assert.Equal(t, subject.User, "root") - assert.Equal(t, subject.Hostname, "a-test-host") - } + subject := lwrunner.New("ubuntu", "a-test-host", ssh.InsecureIgnoreHostKey()) + assert.Equal(t, subject.User, "root") + assert.Equal(t, subject.Hostname, "a-test-host") } func TestLwRunnerUsePassword(t *testing.T) { - subject, err := lwrunner.New("ec2-user", "host.example.com", customHostCallback) - if assert.Nil(t, err) { - assert.Equal(t, "ec2-user", subject.User) - assert.Equal(t, "host.example.com:22", subject.Address()) - } + subject := lwrunner.New("ec2-user", "host.example.com", customHostCallback) + assert.Equal(t, "ec2-user", subject.User) + assert.Equal(t, "host.example.com:22", subject.Address()) subject.UsePassword("secret123") - assert.Equal(t, 1, len(subject.Auth)) } func TestLwRunnerUseIdentityFile(t *testing.T) { - subject, err := lwrunner.New("ec2-user", "host.example.com", customHostCallback) - if assert.Nil(t, err) { - assert.Equal(t, "ec2-user", subject.User) - assert.Equal(t, "host.example.com:22", subject.Address()) - } + subject := lwrunner.New("ec2-user", "host.example.com", customHostCallback) + assert.Equal(t, "ec2-user", subject.User) + assert.Equal(t, "host.example.com:22", subject.Address()) - err = subject.UseIdentityFile("file-not-found") + err := subject.UseIdentityFile("file-not-found") assert.NotNil(t, err) } @@ -124,3 +90,70 @@ func TestLwRunnerDefaultKnownHostsPath(t *testing.T) { assert.Contains(t, subject, ".ssh/known_hosts") } } + +func TestDefaultIdentityFilePath(t *testing.T) { + subject, err := lwrunner.DefaultIdentityFilePath() + if assert.Nil(t, err) { + assert.Contains(t, subject, ".ssh") + assert.Contains(t, subject, "id_rsa") + } +} + +func TestDefaultIdentityFilePathEnvVariable(t *testing.T) { + expected := "/pat/to/key" + os.Setenv("LW_SSH_IDENTITY_FILE", expected) + defer os.Setenv("LW_SSH_IDENTITY_FILE", "") + + subject, err := lwrunner.DefaultIdentityFilePath() + if assert.Nil(t, err) { + assert.Equal(t, subject, expected) + } +} + +func TestLwRunnerDefaultKnownHosts(t *testing.T) { + mockHome, err := ioutil.TempDir("", "lwrunner") + if err != nil { + panic(err) + } + defer os.RemoveAll(mockHome) + + err = os.Mkdir(path.Join(mockHome, ".ssh"), 0755) + if err != nil { + panic(err) + } + + err = ioutil.WriteFile(path.Join(mockHome, ".ssh", "known_hosts"), []byte(""), 0600) + if err != nil { + panic(err) + } + homeCache := os.Getenv("HOME") + os.Setenv("HOME", mockHome) + defer os.Setenv("HOME", homeCache) + + subject, err := lwrunner.DefaultKnownHosts() + assert.NotNil(t, subject) + if assert.Nil(t, err) { + assert.NotNil(t, subject("mock.hostname.example.com:22", mockAddr{}, mockPublicKey{})) + } +} + +type mockAddr struct{} + +func (m mockAddr) Network() string { + return "tcp" +} +func (m mockAddr) String() string { + return "mock.hostname.example.com:22" +} + +type mockPublicKey struct{} + +func (m mockPublicKey) Type() string { + return "ssh-rsa" +} +func (m mockPublicKey) Marshal() []byte { + return []byte{} +} +func (m mockPublicKey) Verify(_ []byte, _ *ssh.Signature) error { + return nil +}