Skip to content

Commit

Permalink
refactor(cli): fixup + test aws-install SSH user (#1009)
Browse files Browse the repository at this point in the history
This commit refactors and adds unit tests for the SSH username
assignment logic in the `aws-install` CLI subcommand.

Specifically, this commit:
* Breaks down the AMI image name fetch and the lookup table code into
  helper functions
* Adds unit tests for these helper functions
* Enforces using the username passed as the `--ssh_username` CLI arg

Fixes RAIN-42812

Signed-off-by: nschmeller <[email protected]>
  • Loading branch information
nschmeller authored Nov 9, 2022
1 parent 6dbcb23 commit d216771
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 37 deletions.
1 change: 1 addition & 0 deletions cli/cmd/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ func awsRegionDescribeInstances(region string) ([]*lwrunner.AWSRunner, error) {

runner, err := lwrunner.NewAWSRunner(
*threadInstance.ImageId,
agentCmdState.InstallSshUser,
*threadInstance.PublicIpAddress,
region,
*threadInstance.Placement.AvailabilityZone,
Expand Down
105 changes: 68 additions & 37 deletions lwrunner/awsrunner.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,56 +38,25 @@ type AWSRunner struct {
InstanceID string
}

func NewAWSRunner(amiImageId, host, region, availabilityZone, instanceID string, callback ssh.HostKeyCallback) (*AWSRunner, error) {
func NewAWSRunner(amiImageId, userFromCLIArg, host, region, availabilityZone, instanceID string, callback ssh.HostKeyCallback) (*AWSRunner, error) {
// Look up the AMI name of the runner
cfg, err := config.LoadDefaultConfig(context.Background())
imageName, err := getAMIName(amiImageId, region)
if err != nil {
return nil, err
}
cfg.Region = region
svc := ec2.NewFromConfig(cfg)
input := ec2.DescribeImagesInput{
ImageIds: []string{
amiImageId,
},
}
result, err := svc.DescribeImages(context.Background(), &input)
if err != nil {
return nil, err
}
if len(result.Images) != 1 {
return nil, fmt.Errorf("expected to find only one AMI")
}

// Lookup table for heuristically determining SSH username based on AMI
usernameLUT := []func(string) (bool, string){
func(_ string) (bool, string) { return os.Getenv("LW_SSH_USER") != "", os.Getenv("LW_SSH_USER") },
func(imageName string) (bool, string) { return strings.Contains(imageName, "ubuntu"), "ubuntu" },
func(imageName string) (bool, string) {
return strings.Contains(imageName, "amazon_linux"), "amazon_linux"
},
func(imageName string) (bool, string) { return strings.Contains(imageName, "amzn2-ami"), "amzn2-ami" },
}

// Heuristically assign SSH username based on AMI name
user := ""
imageName := *result.Images[0].Name // this array is guaranteed to have length 1
for _, matchFn := range usernameLUT {
if match, foundName := matchFn(imageName); match {
user = foundName
break
}
}
if user == "" { // no matching AMI found, return an error
return nil, fmt.Errorf("expected either Ubuntu or Amazon Linux 2 AMI, got AMI %s", imageName)
detectedUsername, err := getSSHUsername(userFromCLIArg, imageName)
if err != nil {
return nil, err
}

defaultCallback, err := DefaultKnownHosts()
if err == nil && callback == nil {
callback = defaultCallback
}

runner := New(user, host, callback)
runner := New(detectedUsername, host, callback)

return &AWSRunner{
*runner,
Expand Down Expand Up @@ -144,3 +113,65 @@ func (run AWSRunner) SendPublicKey(pubBytes []byte) error {

return nil
}

// getAMIName takes an AMI image ID and an AWS region name as input
// and calls the AWS API to get the name of the AMI. Returns the AMI
// name or an error if unsuccessful.
func getAMIName(amiImageId, region string) (string, error) {
cfg, err := config.LoadDefaultConfig(context.Background())
if err != nil {
return "", err
}
cfg.Region = region
svc := ec2.NewFromConfig(cfg)
input := ec2.DescribeImagesInput{
ImageIds: []string{
amiImageId,
},
}
result, err := svc.DescribeImages(context.Background(), &input)
if err != nil {
return "", err
}
if len(result.Images) != 1 {
return "", fmt.Errorf("expected to find only one AMI, instead found %v", result.Images)
}

return *result.Images[0].Name, nil
}

// getSSHUsername takes any username passed as a CLI arg,
// an AMI image name, a shell environment, and returns
// the username for SSHing into the AWS runner or the empty
// string and an error if the AMI is not supported.
// It first checks if `LW_SSH_USER` is set and returns it if so.
// Then it checks the AMI image name to heuristically determine the
// SSH username.
func getSSHUsername(userFromCLIArg, imageName string) (string, error) {
if userFromCLIArg != "" { // from CLI arg
return userFromCLIArg, nil
}
usernameLUT := getSSHUsernameLookupTable()
for _, matchFn := range usernameLUT {
if match, foundName := matchFn(imageName); match {
return foundName, nil
}
}
// No matching AMI found, return an error
return "", fmt.Errorf("no SSH username found for AMI %s, set as arg or shell env", imageName)
}

// getSSHUsernameLookupTable returns a lookup table for heuristically
// determining SSH username based on AMI.
// The first row of the table it returns is a function that checks
// `LW_SSH_USER` in the shell environment.
func getSSHUsernameLookupTable() []func(string) (bool, string) {
return []func(string) (bool, string){
func(_ string) (bool, string) { return os.Getenv("LW_SSH_USER") != "", os.Getenv("LW_SSH_USER") }, // THIS ROW MUST BE FIRST IN THE TABLE
func(imageName string) (bool, string) { return strings.Contains(imageName, "ubuntu"), "ubuntu" },
func(imageName string) (bool, string) {
return strings.Contains(imageName, "amazon_linux"), "amazon_linux"
},
func(imageName string) (bool, string) { return strings.Contains(imageName, "amzn2-ami"), "amzn2-ami" },
}
}
49 changes: 49 additions & 0 deletions lwrunner/awsrunner_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//
// Author:: Nicholas Schmeller (<[email protected]>)
// Copyright:: Copyright 2022, Lacework Inc.
// License:: Apache License, Version 2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

package lwrunner

import (
"testing"

"github.com/stretchr/testify/assert"
)

const GOOD_SSH_USERNAME = "customer_ssh_username"

func TestSSHUsernameLookupChecksForCLIArgUsernameFirst(t *testing.T) {
user, err := getSSHUsername(GOOD_SSH_USERNAME, "some_ubuntu_ami")
assert.NoError(t, err)
assert.Equal(t, user, GOOD_SSH_USERNAME)
assert.NotEqual(t, user, "ubuntu")
}

func TestSSHUsernameLookupChecksEnvBeforeAMI(t *testing.T) {
t.Setenv("LW_SSH_USER", GOOD_SSH_USERNAME)

user, err := getSSHUsername("", "some_ubuntu_ami")
assert.NoError(t, err)
assert.Equal(t, user, GOOD_SSH_USERNAME)
assert.NotEqual(t, user, "ubuntu")
}

func TestSSHUsernameLookupFailsOnBadImageName(t *testing.T) {
user, err := getSSHUsername("", "ami_bad_image_name")
assert.Error(t, err)
assert.Empty(t, user)
}

0 comments on commit d216771

Please sign in to comment.