forked from remind101/assume-role
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
173 lines (142 loc) · 3.61 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
package main
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"os"
"os/exec"
"strings"
"syscall"
"gopkg.in/yaml.v2"
)
// configFilePath is the location of the role definitions file, built from
// $HOME at program start (an unset HOME yields "/.aws/roles").
var configFilePath = os.Getenv("HOME") + "/.aws/roles"
// usage writes the one-line command synopsis to standard output.
func usage() {
	fmt.Println("Usage: assume-role <role> [<command> <args...>]")
}
// main implements `assume-role <role> [<command> <args...>]`.
//
// It looks up <role> in the roles file and obtains temporary STS credentials
// for it. With no command, it prints shell export statements for those
// credentials. Otherwise it execs the command with the credentials in its
// environment. Any failure terminates the process with exit status 1.
func main() {
	if len(os.Args) < 2 {
		usage()
		os.Exit(1)
	}

	role := os.Args[1]
	args := os.Args[2:]

	cfg, err := loadConfig()
	must(err)

	rc, ok := cfg[role]
	if !ok {
		must(fmt.Errorf("%s not in ~/.aws/roles", role))
	}

	if os.Getenv("ASSUMED_ROLE") != "" {
		// Clear out any previously set AWS_ environment variables so
		// they aren't used with the assumeRole command.
		cleanEnv()
	}

	creds, err := assumeRole(rc.Role, rc.MFA)
	must(err)

	if len(args) == 0 {
		printCredentials(role, creds)
		return
	}

	// Mark the exec'd environment as holding assumed-role credentials. This
	// mirrors the ASSUMED_ROLE export that printCredentials emits for the
	// eval path. Without it, an assume-role run started from inside the
	// command would skip cleanEnv above and call sts with these temporary
	// credentials. It also replaces any stale ASSUMED_ROLE inherited from the
	// parent, which cleanEnv does not unset.
	must(os.Setenv("ASSUMED_ROLE", role))

	err = execWithCredentials(args, creds)
	must(err)
}
// cleanEnv unsets the AWS credential variables that assume-role itself sets,
// so leftover credentials from an earlier assumption are not used for the
// next `aws sts assume-role` call. Other AWS_* variables are left alone.
func cleanEnv() {
	for _, name := range []string{
		"AWS_ACCESS_KEY_ID",
		"AWS_SECRET_ACCESS_KEY",
		"AWS_SESSION_TOKEN",
		"AWS_SECURITY_TOKEN",
	} {
		os.Unsetenv(name)
	}
}
// execWithCredentials replaces the current process with the command argv.
//
// It resolves argv[0] against PATH first; if that lookup fails, the error is
// returned and the environment is not modified. Otherwise it exports creds as:
//   - AWS_ACCESS_KEY_ID
//   - AWS_SECRET_ACCESS_KEY
//   - AWS_SESSION_TOKEN and AWS_SECURITY_TOKEN (both carry the session token)
//
// It then execs with the full resulting environment. On success syscall.Exec
// does not return; a non-nil error means the exec itself failed. argv must be
// non-empty.
func execWithCredentials(argv []string, creds *credentials) error {
	path, lookErr := exec.LookPath(argv[0])
	if lookErr != nil {
		return lookErr
	}

	vars := [...]struct{ key, value string }{
		{"AWS_ACCESS_KEY_ID", creds.AccessKeyID},
		{"AWS_SECRET_ACCESS_KEY", creds.SecretAccessKey},
		{"AWS_SESSION_TOKEN", creds.SessionToken},
		{"AWS_SECURITY_TOKEN", creds.SessionToken},
	}
	for _, v := range vars {
		// Setenv errors are ignored; the keys are constant, valid names.
		os.Setenv(v.key, v.value)
	}

	return syscall.Exec(path, argv, os.Environ())
}
// credentials holds the temporary security credentials returned by
// `aws sts assume-role`. They are decoded from the "Credentials" object of
// its JSON output.
//
// The json tags spell the AWS field names exactly (note "AccessKeyId"). This
// avoids relying on encoding/json's case-insensitive fallback to match the
// AccessKeyID field.
type credentials struct {
	AccessKeyID     string `json:"AccessKeyId"`
	SecretAccessKey string `json:"SecretAccessKey"`
	SessionToken    string `json:"SessionToken"`
}
// shellQuoteReplacer backslash-escapes the characters that remain special
// inside a double-quoted POSIX shell string: \ " $ and `.
var shellQuoteReplacer = strings.NewReplacer(
	`\`, `\\`,
	`"`, `\"`,
	"$", `\$`,
	"`", "\\`",
)

// shellQuote returns s wrapped in double quotes and escaped so that a shell
// reads it back as the literal string s.
func shellQuote(s string) string {
	return `"` + shellQuoteReplacer.Replace(s) + `"`
}

// printCredentials prints the credentials in a way that can easily be sourced
// with bash (e.g. via `eval $(assume-role <role>)`).
//
// It writes one `export NAME="value"` line for each of:
//   - AWS_ACCESS_KEY_ID
//   - AWS_SECRET_ACCESS_KEY
//   - AWS_SESSION_TOKEN
//   - AWS_SECURITY_TOKEN
//   - ASSUMED_ROLE
//
// These are followed by two shell comments showing how to eval the output.
// Values are escaped so that unusual role names cannot break out of the
// quotes when the output is eval'd. Ordinary values print unchanged.
func printCredentials(role string, creds *credentials) {
	exports := []struct{ name, value string }{
		{"AWS_ACCESS_KEY_ID", creds.AccessKeyID},
		{"AWS_SECRET_ACCESS_KEY", creds.SecretAccessKey},
		{"AWS_SESSION_TOKEN", creds.SessionToken},
		{"AWS_SECURITY_TOKEN", creds.SessionToken},
		{"ASSUMED_ROLE", role},
	}
	for _, e := range exports {
		fmt.Printf("export %s=%s\n", e.name, shellQuote(e.value))
	}
	fmt.Printf("# Run this to configure your shell:\n")
	fmt.Printf("# eval $(%s)\n", strings.Join(os.Args, " "))
}
// assumeRole assumes the given role and returns the temporary STS credentials.
//
// It runs `aws sts assume-role` with JSON output and the fixed session name
// "cli". When mfa is non-empty, it is passed as --serial-number, and a token
// code is first read from stdin via readTokenCode.
//
// The aws CLI's stderr is forwarded to ours. If the CLI cannot be started or
// exits non-zero, the error from os/exec is returned unwrapped. It returns an
// error if the output cannot be decoded or contains no credentials.
func assumeRole(role, mfa string) (*credentials, error) {
	args := []string{
		"sts",
		"assume-role",
		"--output", "json",
		"--role-arn", role,
		"--role-session-name", "cli",
	}
	if mfa != "" {
		args = append(args,
			"--serial-number", mfa,
			"--token-code", readTokenCode(),
		)
	}

	var stdout bytes.Buffer
	cmd := exec.Command("aws", args...)
	cmd.Stdout = &stdout
	cmd.Stderr = os.Stderr
	// Keep this error unwrapped: must() type-asserts *exec.ExitError so it
	// does not re-print a failure the aws CLI already reported on stderr.
	if err := cmd.Run(); err != nil {
		return nil, err
	}

	var resp struct{ Credentials credentials }
	if err := json.NewDecoder(&stdout).Decode(&resp); err != nil {
		return nil, fmt.Errorf("decoding aws sts assume-role output: %w", err)
	}
	// Guard against a well-formed response lacking a Credentials object,
	// which would otherwise yield empty credentials silently.
	if resp.Credentials.AccessKeyID == "" || resp.Credentials.SecretAccessKey == "" {
		return nil, fmt.Errorf("aws sts assume-role returned no credentials for %q", role)
	}
	return &resp.Credentials, nil
}
type (
	// roleConfig is one entry of the roles file. Role is the ARN passed to
	// `aws sts assume-role --role-arn`. MFA is optional; when set, it is
	// passed as --serial-number and an MFA token code is requested.
	roleConfig struct {
		Role string `yaml:"role"`
		MFA  string `yaml:"mfa"`
	}

	// config maps a role name, as given on the command line, to its settings.
	config map[string]roleConfig
)
// readTokenCode prompts "MFA code: " on stderr and reads one line from stdin.
// It returns that line with surrounding whitespace trimmed. A read error
// (such as EOF before a newline) is ignored, and whatever text was read
// before it, possibly "", is used.
func readTokenCode() string {
	fmt.Fprintf(os.Stderr, "MFA code: ")
	line, _ := bufio.NewReader(os.Stdin).ReadString('\n')
	return strings.TrimSpace(line)
}
// loadConfig loads the ~/.aws/roles file (configFilePath).
//
// The file is YAML mapping role names to roleConfig entries, for example:
//
//	prod:
//	  role: arn:aws:iam::123456789012:role/admin
//	  mfa: arn:aws:iam::123456789012:mfa/me
//
// Read errors are returned unchanged, since they already name the path. YAML
// errors are wrapped with the path. On any error the returned config is nil.
func loadConfig() (config, error) {
	raw, err := ioutil.ReadFile(configFilePath)
	if err != nil {
		return nil, err
	}
	cfg := make(config)
	if err := yaml.Unmarshal(raw, &cfg); err != nil {
		return nil, fmt.Errorf("parsing %s: %w", configFilePath, err)
	}
	return cfg, nil
}
// must exits the process with status 1 if err is non-nil.
//
// For a *exec.ExitError (a child command that exited non-zero), nothing more
// is printed because the child already wrote its error to stderr. Any other
// error is reported on stderr as "error: <err>" first.
func must(err error) {
	if err == nil {
		return
	}
	switch err.(type) {
	case *exec.ExitError:
		// Errors are already on Stderr.
	default:
		fmt.Fprintf(os.Stderr, "error: %v\n", err)
	}
	os.Exit(1)
}