Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support IMDSv2 HttpTokens #63067

Merged
merged 2 commits into from
Dec 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/63067.added
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
salt-cloud support IMDSv2 tokens when using 'use-instance-role-credentials'
87 changes: 63 additions & 24 deletions salt/utils/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,12 @@
:depends: requests
"""

import binascii
import hashlib
import hmac
import logging
import random
import re
import time
import urllib.parse
import xml.etree.ElementTree as ET
from datetime import datetime

import salt.config
Expand All @@ -30,6 +27,27 @@
except ImportError:
HAS_REQUESTS = False # pylint: disable=W0612

try:
import binascii

HAS_BINASCII = True # pylint: disable=W0612
except ImportError:
HAS_BINASCII = False # pylint: disable=W0612

try:
import urllib.parse

HAS_URLLIB = True # pylint: disable=W0612
except ImportError:
HAS_URLLIB = False # pylint: disable=W0612

try:
import xml.etree.ElementTree as ET

HAS_ETREE = True # pylint: disable=W0612
except ImportError:
HAS_ETREE = False # pylint: disable=W0612

# pylint: enable=import-error,redefined-builtin,no-name-in-module

log = logging.getLogger(__name__)
Expand All @@ -54,6 +72,7 @@
__Expiration__ = ""
__Location__ = ""
__AssumeCache__ = {}
__IMDS_Token__ = None


def sleep_exponential_backoff(attempts):
Expand All @@ -71,6 +90,44 @@ def sleep_exponential_backoff(attempts):
time.sleep(random.uniform(1, 2**attempts))


def get_metadata(path, refresh_token_if_needed=True):
"""
Get the instance metadata at the provided path
The path argument will be prepended by http://169.254.169.254/latest/
If using IMDSv2 with tokens required, the token will be fetched and used for subsequent requests
(unless refresh_token_if_needed is False, in which case this will fail if tokens are required
and no token was already cached)
"""
global __IMDS_Token__

headers = {}
if __IMDS_Token__ is not None:
headers["X-aws-ec2-metadata-token"] = __IMDS_Token__

# Connections to instance meta-data must fail fast and never be proxied
result = requests.get(
"http://169.254.169.254/latest/{}".format(path),
proxies={"http": ""},
headers=headers,
timeout=AWS_METADATA_TIMEOUT,
)

if result.status_code == 401 and refresh_token_if_needed:
Ch3LL marked this conversation as resolved.
Show resolved Hide resolved
# Probably using IMDSv2 with tokens required, so fetch token and retry
token_result = requests.put(
"http://169.254.169.254/latest/api/token",
headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
proxies={"http": ""},
timeout=AWS_METADATA_TIMEOUT,
)
__IMDS_Token__ = token_result.text
if token_result.ok:
return get_metadata(path, False)

result.raise_for_status()
return result


def creds(provider):
"""
Return the credentials for AWS signing. This could be just the id and key
Expand All @@ -95,27 +152,14 @@ def creds(provider):
return __AccessKeyId__, __SecretAccessKey__, __Token__
# We don't have any cached credentials, or they are expired, get them

# Connections to instance meta-data must fail fast and never be proxied
try:
result = requests.get(
"http://169.254.169.254/latest/meta-data/iam/security-credentials/",
proxies={"http": ""},
timeout=AWS_METADATA_TIMEOUT,
)
result.raise_for_status()
result = get_metadata("meta-data/iam/security-credentials/")
role = result.text
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError):
return provider["id"], provider["key"], ""

try:
result = requests.get(
"http://169.254.169.254/latest/meta-data/iam/security-credentials/{}".format(
role
),
proxies={"http": ""},
timeout=AWS_METADATA_TIMEOUT,
)
result.raise_for_status()
result = get_metadata("meta-data/iam/security-credentials/{}".format(role))
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError):
return provider["id"], provider["key"], ""

Expand Down Expand Up @@ -578,12 +622,7 @@ def get_region_from_metadata():
return __Location__

try:
# Connections to instance meta-data must fail fast and never be proxied
result = requests.get(
"http://169.254.169.254/latest/dynamic/instance-identity/document",
proxies={"http": ""},
timeout=AWS_METADATA_TIMEOUT,
)
result = get_metadata("dynamic/instance-identity/document")
except requests.exceptions.RequestException:
log.warning("Failed to get AWS region from instance metadata.", exc_info=True)
# Do not try again
Expand Down
52 changes: 52 additions & 0 deletions tests/pytests/unit/utils/test_aws.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
tests.unit.utils.aws_test
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Test the salt aws functions
"""

import io

import requests

from salt.utils.aws import get_metadata
from tests.support.mock import MagicMock, patch


def test_get_metadata_imdsv1():
response = requests.Response()
response.status_code = 200
response.reason = "OK"
response.raw = io.BytesIO(b"""test""")
with patch("requests.get", return_value=response):
result = get_metadata("/")
assert result.text == "test"


def test_get_metadata_imdsv2():
mock_token = "abc123"

def handle_get_mock(_, **args):
response = requests.Response()
if (
"X-aws-ec2-metadata-token" in args["headers"]
and args["headers"]["X-aws-ec2-metadata-token"] == mock_token
):
response.status_code = 200
response.reason = "OK"
response.raw = io.BytesIO(b"""test""")
else:
response.status_code = 401
response.reason = "Unauthorized"
return response

put_response = requests.Response()
put_response.status_code = 200
put_response.reason = "OK"
put_response.raw = io.BytesIO(mock_token.encode("utf-8"))

with patch("requests.get", MagicMock(side_effect=handle_get_mock)), patch(
"requests.put", return_value=put_response
):
result = get_metadata("/")
assert result.text == "test"