Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update IamAwsProvider as per minio-go implementation #1437

Merged
merged 7 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 81 additions & 26 deletions minio/credentials/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=too-many-branches

"""Credential providers."""

from __future__ import annotations
Expand All @@ -29,7 +31,7 @@
from datetime import timedelta
from pathlib import Path
from typing import Callable, cast
from urllib.parse import urlencode, urlsplit
from urllib.parse import urlencode, urlsplit, urlunsplit
from xml.etree import ElementTree as ET

import certifi
Expand All @@ -42,7 +44,7 @@

from urllib3.util import Retry, parse_url

from minio.helpers import sha256_hash
from minio.helpers import sha256_hash, url_replace
from minio.signer import sign_v4_sts
from minio.time import from_iso8601utc, to_amz_date, utcnow
from minio.xml import find, findtext
Expand Down Expand Up @@ -381,6 +383,13 @@ def __init__(
self,
custom_endpoint: str | None = None,
http_client: PoolManager | None = None,
auth_token: str | None = None,
relative_uri: str | None = None,
full_uri: str | None = None,
token_file: str | None = None,
role_arn: str | None = None,
role_session_name: str | None = None,
region: str | None = None,
):
self._custom_endpoint = custom_endpoint
self._http_client = http_client or PoolManager(
Expand All @@ -390,22 +399,41 @@ def __init__(
status_forcelist=[500, 502, 503, 504],
),
)
self._token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
self._aws_region = os.environ.get("AWS_REGION")
self._role_arn = os.environ.get("AWS_ROLE_ARN")
self._role_session_name = os.environ.get("AWS_ROLE_SESSION_NAME")
self._relative_uri = os.environ.get(
"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
self._token = (
os.environ.get("AWS_CONTAINER_AUTHORIZATION_TOKEN") or
auth_token
)
self._token_file = (
os.environ.get("AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE") or
auth_token
)
self._identity_file = (
os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE") or token_file
)
self._aws_region = os.environ.get("AWS_REGION") or region
self._role_arn = os.environ.get("AWS_ROLE_ARN") or role_arn
self._role_session_name = (
os.environ.get("AWS_ROLE_SESSION_NAME") or role_session_name
)
self._relative_uri = (
os.environ.get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") or
relative_uri
)
if self._relative_uri and not self._relative_uri.startswith("/"):
self._relative_uri = "/" + self._relative_uri
self._full_uri = os.environ.get("AWS_CONTAINER_CREDENTIALS_FULL_URI")
self._full_uri = (
os.environ.get("AWS_CONTAINER_CREDENTIALS_FULL_URI") or
full_uri
)
self._credentials: Credentials | None = None

def fetch(self, url: str) -> Credentials:
"""Fetch credentials from EC2/ECS. """

res = _urlopen(self._http_client, "GET", url)
def fetch(
self,
url: str,
headers: dict[str, str | list[str] | tuple[str]] | None = None,
) -> Credentials:
"""Fetch credentials from EC2/ECS."""
res = _urlopen(self._http_client, "GET", url, headers=headers)
data = json.loads(res.data)
if data.get("Code", "Success") != "Success":
raise ValueError(
Expand All @@ -428,14 +456,16 @@ def retrieve(self) -> Credentials:
return self._credentials

url = self._custom_endpoint
if self._token_file:
if self._identity_file:
if not url:
url = "https://sts.amazonaws.com"
if self._aws_region:
url = f"https://sts.{self._aws_region}.amazonaws.com"
if self._aws_region.startswith("cn-"):
url += ".cn"

provider = WebIdentityProvider(
lambda: _get_jwt_token(cast(str, self._token_file)),
lambda: _get_jwt_token(cast(str, self._identity_file)),
url,
role_arn=self._role_arn,
role_session_name=self._role_session_name,
Expand All @@ -444,30 +474,55 @@ def retrieve(self) -> Credentials:
self._credentials = provider.retrieve()
return cast(Credentials, self._credentials)

headers: dict[str, str | list[str] | tuple[str]] | None = None
if self._relative_uri:
if not url:
url = "http://169.254.170.2" + self._relative_uri
headers = {"Authorization": self._token} if self._token else None
elif self._full_uri:
if not url:
token = self._token
if self._token_file:
url = self._full_uri
_check_loopback_host(url)
with open(self._token_file, encoding="utf-8") as file:
token = file.read()
else:
if not url:
url = self._full_uri
_check_loopback_host(url)
headers = {"Authorization": token} if token else None
else:
if not url:
url = (
"http://169.254.169.254" +
"/latest/meta-data/iam/security-credentials/"
)

res = _urlopen(self._http_client, "GET", url)
url = "http://169.254.169.254"

# Get IMDS Token
res = _urlopen(
self._http_client,
"PUT",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@balamurugana : This line differs from your suggestion because this needs to be a PUT request, not GET.

minio-go also has this as a PUT here.

url+"/latest/api/token",
headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
)
token = res.data.decode("utf-8")
headers = {"X-aws-ec2-metadata-token": token} if token else None

# Get role name
res = _urlopen(
self._http_client,
"GET",
urlunsplit(
url_replace(
urlsplit(url),
path="/latest/meta-data/iam/security-credentials/",
),
),
headers=headers,
)
role_names = res.data.decode("utf-8").split("\n")
if not role_names:
raise ValueError(f"no IAM roles attached to EC2 service {url}")
url += "/" + role_names[0].strip("\r")

if not url:
raise ValueError("url is empty; this should not happen")

self._credentials = self.fetch(url)
self._credentials = self.fetch(url, headers=headers)
return self._credentials


Expand Down
35 changes: 0 additions & 35 deletions tests/unit/credentials_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import unittest.mock as mock
from datetime import datetime, timedelta
from unittest import TestCase

from minio.credentials.credentials import Credentials
from minio.credentials.providers import (AWSConfigProvider, ChainedProvider,
EnvAWSProvider, EnvMinioProvider,
IamAwsProvider,
MinioClientConfigProvider,
StaticProvider)

Expand All @@ -45,36 +40,6 @@ def test_credentials_get(self):
self.assertEqual(creds.session_token, None)


class CredListResponse(object):
status = 200
data = b"test-s3-full-access-for-minio-ec2"


class CredsResponse(object):
status = 200
data = json.dumps({
"Code": "Success",
"Type": "AWS-HMAC",
"AccessKeyId": "accessKey",
"SecretAccessKey": "secret",
"Token": "token",
"Expiration": "2014-12-16T01:51:37Z",
"LastUpdated": "2009-11-23T0:00:00Z"
})


class IamAwsProviderTest(TestCase):
@mock.patch("urllib3.PoolManager.urlopen")
def test_iam(self, mock_connection):
mock_connection.side_effect = [CredListResponse(), CredsResponse()]
provider = IamAwsProvider()
creds = provider.retrieve()
self.assertEqual(creds.access_key, "accessKey")
self.assertEqual(creds.secret_key, "secret")
self.assertEqual(creds.session_token, "token")
self.assertEqual(creds._expiration, datetime(2014, 12, 16, 1, 51, 37))


class ChainedProviderTest(TestCase):
def test_chain_retrieve(self):
# clear environment
Expand Down
Loading