diff --git a/minio/credentials/providers.py b/minio/credentials/providers.py index 51dca1f1a..197d32987 100644 --- a/minio/credentials/providers.py +++ b/minio/credentials/providers.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# pylint: disable=too-many-branches + """Credential providers.""" from __future__ import annotations @@ -29,7 +31,7 @@ from datetime import timedelta from pathlib import Path from typing import Callable, cast -from urllib.parse import urlencode, urlsplit +from urllib.parse import urlencode, urlsplit, urlunsplit from xml.etree import ElementTree as ET import certifi @@ -42,7 +44,7 @@ from urllib3.util import Retry, parse_url -from minio.helpers import sha256_hash +from minio.helpers import sha256_hash, url_replace from minio.signer import sign_v4_sts from minio.time import from_iso8601utc, to_amz_date, utcnow from minio.xml import find, findtext @@ -381,6 +383,13 @@ def __init__( self, custom_endpoint: str | None = None, http_client: PoolManager | None = None, + auth_token: str | None = None, + relative_uri: str | None = None, + full_uri: str | None = None, + token_file: str | None = None, + role_arn: str | None = None, + role_session_name: str | None = None, + region: str | None = None, ): self._custom_endpoint = custom_endpoint self._http_client = http_client or PoolManager( @@ -390,22 +399,41 @@ def __init__( status_forcelist=[500, 502, 503, 504], ), ) - self._token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE") - self._aws_region = os.environ.get("AWS_REGION") - self._role_arn = os.environ.get("AWS_ROLE_ARN") - self._role_session_name = os.environ.get("AWS_ROLE_SESSION_NAME") - self._relative_uri = os.environ.get( - "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", + self._token = ( + os.environ.get("AWS_CONTAINER_AUTHORIZATION_TOKEN") or + auth_token + ) + self._token_file = ( + os.environ.get("AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE") or + auth_token + ) + self._identity_file = ( + os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE") or token_file + ) + self._aws_region = os.environ.get("AWS_REGION") or region + self._role_arn = os.environ.get("AWS_ROLE_ARN") or role_arn + self._role_session_name = ( + os.environ.get("AWS_ROLE_SESSION_NAME") or role_session_name + ) + self._relative_uri = ( + os.environ.get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") or + relative_uri ) if self._relative_uri and not self._relative_uri.startswith("/"): self._relative_uri = "/" + self._relative_uri - self._full_uri = os.environ.get("AWS_CONTAINER_CREDENTIALS_FULL_URI") + self._full_uri = ( + os.environ.get("AWS_CONTAINER_CREDENTIALS_FULL_URI") or + full_uri + ) self._credentials: Credentials | None = None - def fetch(self, url: str) -> Credentials: - """Fetch credentials from EC2/ECS. """ - - res = _urlopen(self._http_client, "GET", url) + def fetch( + self, + url: str, + headers: dict[str, str | list[str] | tuple[str]] | None = None, + ) -> Credentials: + """Fetch credentials from EC2/ECS.""" + res = _urlopen(self._http_client, "GET", url, headers=headers) data = json.loads(res.data) if data.get("Code", "Success") != "Success": raise ValueError( @@ -428,14 +456,16 @@ def retrieve(self) -> Credentials: return self._credentials url = self._custom_endpoint - if self._token_file: + if self._identity_file: if not url: url = "https://sts.amazonaws.com" if self._aws_region: url = f"https://sts.{self._aws_region}.amazonaws.com" + if self._aws_region.startswith("cn-"): + url += ".cn" provider = WebIdentityProvider( - lambda: _get_jwt_token(cast(str, self._token_file)), + lambda: _get_jwt_token(cast(str, self._identity_file)), url, role_arn=self._role_arn, role_session_name=self._role_session_name, @@ -444,30 +474,55 @@ def retrieve(self) -> Credentials: self._credentials = provider.retrieve() return cast(Credentials, self._credentials) + headers: dict[str, str | list[str] | tuple[str]] | None = None if self._relative_uri: if not url: url = "http://169.254.170.2" + self._relative_uri + headers = {"Authorization": self._token} if self._token else None elif self._full_uri: - if not url: + token = self._token + if self._token_file: url = self._full_uri - _check_loopback_host(url) + with open(self._token_file, encoding="utf-8") as file: + token = file.read() + else: + if not url: + url = self._full_uri + _check_loopback_host(url) + headers = {"Authorization": token} if token else None else: if not url: - url = ( - "http://169.254.169.254" + - "/latest/meta-data/iam/security-credentials/" - ) - - res = _urlopen(self._http_client, "GET", url) + url = "http://169.254.169.254" + + # Get IMDS Token + res = _urlopen( + self._http_client, + "PUT", + url+"/latest/api/token", + headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"}, + ) + token = res.data.decode("utf-8") + headers = {"X-aws-ec2-metadata-token": token} if token else None + + # Get role name + res = _urlopen( + self._http_client, + "GET", + urlunsplit( + url_replace( + urlsplit(url), + path="/latest/meta-data/iam/security-credentials/", + ), + ), + headers=headers, + ) role_names = res.data.decode("utf-8").split("\n") if not role_names: raise ValueError(f"no IAM roles attached to EC2 service {url}") url += "/" + role_names[0].strip("\r") - if not url: raise ValueError("url is empty; this should not happen") - - self._credentials = self.fetch(url) + self._credentials = self.fetch(url, headers=headers) return self._credentials diff --git a/tests/unit/credentials_test.py b/tests/unit/credentials_test.py index 4d3b1ee44..87d72d068 100644 --- a/tests/unit/credentials_test.py +++ b/tests/unit/credentials_test.py @@ -14,16 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import os -import unittest.mock as mock -from datetime import datetime, timedelta from unittest import TestCase -from minio.credentials.credentials import Credentials from minio.credentials.providers import (AWSConfigProvider, ChainedProvider, EnvAWSProvider, EnvMinioProvider, - IamAwsProvider, MinioClientConfigProvider, StaticProvider) @@ -45,36 +40,6 @@ def test_credentials_get(self): self.assertEqual(creds.session_token, None) -class CredListResponse(object): - status = 200 - data = b"test-s3-full-access-for-minio-ec2" - - -class CredsResponse(object): - status = 200 - data = json.dumps({ - "Code": "Success", - "Type": "AWS-HMAC", - "AccessKeyId": "accessKey", - "SecretAccessKey": "secret", - "Token": "token", - "Expiration": "2014-12-16T01:51:37Z", - "LastUpdated": "2009-11-23T0:00:00Z" - }) - - -class IamAwsProviderTest(TestCase): - @mock.patch("urllib3.PoolManager.urlopen") - def test_iam(self, mock_connection): - mock_connection.side_effect = [CredListResponse(), CredsResponse()] - provider = IamAwsProvider() - creds = provider.retrieve() - self.assertEqual(creds.access_key, "accessKey") - self.assertEqual(creds.secret_key, "secret") - self.assertEqual(creds.session_token, "token") - self.assertEqual(creds._expiration, datetime(2014, 12, 16, 1, 51, 37)) - - class ChainedProviderTest(TestCase): def test_chain_retrieve(self): # clear environment