Skip to content

Commit

Permalink
Support IMDSv2 HttpTokens
Browse files Browse the repository at this point in the history
  • Loading branch information
Peter Steffey committed Nov 16, 2022
1 parent 87ac834 commit ddbe5d4
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 24 deletions.
1 change: 1 addition & 0 deletions changelog/63067.added
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
salt-cloud support IMDSv2 tokens when using 'use-instance-role-credentials'
87 changes: 63 additions & 24 deletions salt/utils/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,12 @@
:depends: requests
"""

import binascii
import hashlib
import hmac
import logging
import random
import re
import time
import urllib.parse
import xml.etree.ElementTree as ET
from datetime import datetime

import salt.config
Expand All @@ -30,6 +27,27 @@
except ImportError:
HAS_REQUESTS = False # pylint: disable=W0612

try:
import binascii

HAS_BINASCII = True # pylint: disable=W0612
except ImportError:
HAS_BINASCII = False # pylint: disable=W0612

try:
import urllib.parse

HAS_URLLIB = True # pylint: disable=W0612
except ImportError:
HAS_URLLIB = False # pylint: disable=W0612

try:
import xml.etree.ElementTree as ET

HAS_ETREE = True # pylint: disable=W0612
except ImportError:
HAS_ETREE = False # pylint: disable=W0612

# pylint: enable=import-error,redefined-builtin,no-name-in-module

log = logging.getLogger(__name__)
Expand All @@ -54,6 +72,7 @@
__Expiration__ = ""
__Location__ = ""
__AssumeCache__ = {}
__IMDS_Token__ = None


def sleep_exponential_backoff(attempts):
Expand All @@ -71,6 +90,44 @@ def sleep_exponential_backoff(attempts):
time.sleep(random.uniform(1, 2**attempts))


def get_metadata(path, refresh_token_if_needed=True):
"""
Get the instance metadata at the provided path
The path argument will be prepended by http://169.254.169.254/latest/
If using IMDSv2 with tokens required, the token will be fetched and used for subsequent requests
(unless refresh_token_if_needed is False, in which case this will fail if tokens are required
and no token was already cached)
"""
global __IMDS_Token__

headers = {}
if __IMDS_Token__ is not None:
headers["X-aws-ec2-metadata-token"] = __IMDS_Token__

# Connections to instance meta-data must fail fast and never be proxied
result = requests.get(
"http://169.254.169.254/latest/{}".format(path),
proxies={"http": ""},
headers=headers,
timeout=AWS_METADATA_TIMEOUT,
)

if result.status_code == 401 and refresh_token_if_needed:
# Probably using IMDSv2 with tokens required, so fetch token and retry
token_result = requests.put(
"http://169.254.169.254/latest/api/token",
headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
proxies={"http": ""},
timeout=AWS_METADATA_TIMEOUT,
)
__IMDS_Token__ = token_result.text
if token_result.ok:
return get_metadata(path, False)

result.raise_for_status()
return result


def creds(provider):
"""
Return the credentials for AWS signing. This could be just the id and key
Expand All @@ -95,27 +152,14 @@ def creds(provider):
return __AccessKeyId__, __SecretAccessKey__, __Token__
# We don't have any cached credentials, or they are expired, get them

# Connections to instance meta-data must fail fast and never be proxied
try:
result = requests.get(
"http://169.254.169.254/latest/meta-data/iam/security-credentials/",
proxies={"http": ""},
timeout=AWS_METADATA_TIMEOUT,
)
result.raise_for_status()
result = get_metadata("meta-data/iam/security-credentials/")
role = result.text
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError):
return provider["id"], provider["key"], ""

try:
result = requests.get(
"http://169.254.169.254/latest/meta-data/iam/security-credentials/{}".format(
role
),
proxies={"http": ""},
timeout=AWS_METADATA_TIMEOUT,
)
result.raise_for_status()
result = get_metadata("meta-data/iam/security-credentials/{}".format(role))
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError):
return provider["id"], provider["key"], ""

Expand Down Expand Up @@ -578,12 +622,7 @@ def get_region_from_metadata():
return __Location__

try:
# Connections to instance meta-data must fail fast and never be proxied
result = requests.get(
"http://169.254.169.254/latest/dynamic/instance-identity/document",
proxies={"http": ""},
timeout=AWS_METADATA_TIMEOUT,
)
result = get_metadata("dynamic/instance-identity/document")
except requests.exceptions.RequestException:
log.warning("Failed to get AWS region from instance metadata.", exc_info=True)
# Do not try again
Expand Down
52 changes: 52 additions & 0 deletions tests/pytests/unit/utils/test_aws.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
tests.unit.utils.aws_test
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Test the salt aws functions
"""

import io

import requests

from salt.utils.aws import get_metadata
from tests.support.mock import MagicMock, patch


def test_get_metadata_imdsv1():
response = requests.Response()
response.status_code = 200
response.reason = "OK"
response.raw = io.BytesIO(b"""test""")
with patch("requests.get", return_value=response):
result = get_metadata("/")
assert result.text == "test"


def test_get_metadata_imdsv2():
mock_token = "abc123"

def handle_get_mock(_, **args):
response = requests.Response()
if (
"X-aws-ec2-metadata-token" in args["headers"]
and args["headers"]["X-aws-ec2-metadata-token"] == mock_token
):
response.status_code = 200
response.reason = "OK"
response.raw = io.BytesIO(b"""test""")
else:
response.status_code = 401
response.reason = "Unauthorized"
return response

put_response = requests.Response()
put_response.status_code = 200
put_response.reason = "OK"
put_response.raw = io.BytesIO(mock_token.encode("utf-8"))

with patch("requests.get", MagicMock(side_effect=handle_get_mock)), patch(
"requests.put", return_value=put_response
):
result = get_metadata("/")
assert result.text == "test"

0 comments on commit ddbe5d4

Please sign in to comment.