Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update IamAwsProvider as per minio-go implementation #1437

Merged
merged 7 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions minio/credentials/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,9 @@ def __init__(
self._full_uri = os.environ.get("AWS_CONTAINER_CREDENTIALS_FULL_URI")
self._credentials: Credentials | None = None

def fetch(self, url: str) -> Credentials:
"""Fetch credentials from EC2/ECS. """

res = _urlopen(self._http_client, "GET", url)
def fetch(self, url: str, headers: dict) -> Credentials:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @balamurugana . Can you please elaborate what that looks like? First time contributor here and didn't see a note about it in the contributing doc.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am saying we should use the same logic as in minio-go. Below is the actual implementation

diff --git a/minio/credentials/providers.py b/minio/credentials/providers.py
index 51dca1f..8790054 100644
--- a/minio/credentials/providers.py
+++ b/minio/credentials/providers.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# pylint: disable=too-many-branches
+
 """Credential providers."""
 
 from __future__ import annotations
@@ -29,7 +31,7 @@ from abc import ABCMeta, abstractmethod
 from datetime import timedelta
 from pathlib import Path
 from typing import Callable, cast
-from urllib.parse import urlencode, urlsplit
+from urllib.parse import urlencode, urlsplit, urlunsplit
 from xml.etree import ElementTree as ET
 
 import certifi
@@ -42,7 +44,7 @@ except ImportError:
 
 from urllib3.util import Retry, parse_url
 
-from minio.helpers import sha256_hash
+from minio.helpers import sha256_hash, url_replace
 from minio.signer import sign_v4_sts
 from minio.time import from_iso8601utc, to_amz_date, utcnow
 from minio.xml import find, findtext
@@ -381,6 +383,13 @@ class IamAwsProvider(Provider):
             self,
             custom_endpoint: str | None = None,
             http_client: PoolManager | None = None,
+            auth_token: str | None = None,
+            relative_uri: str | None = None,
+            full_uri: str | None = None,
+            token_file: str | None = None,
+            role_arn: str | None = None,
+            role_session_name: str | None = None,
+            region: str | None = None,
     ):
         self._custom_endpoint = custom_endpoint
         self._http_client = http_client or PoolManager(
@@ -390,22 +399,42 @@ class IamAwsProvider(Provider):
                 status_forcelist=[500, 502, 503, 504],
             ),
         )
-        self._token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
-        self._aws_region = os.environ.get("AWS_REGION")
-        self._role_arn = os.environ.get("AWS_ROLE_ARN")
-        self._role_session_name = os.environ.get("AWS_ROLE_SESSION_NAME")
-        self._relative_uri = os.environ.get(
-            "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
+        self._token = (
+            os.environ.get("AWS_CONTAINER_AUTHORIZATION_TOKEN") or
+            auth_token
+        )
+        self._token_file = (
+            os.environ.get("AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE") or
+            auth_token
+        )
+        self._identity_file = (
+            os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE") or token_file
+        )
+        self._aws_region = os.environ.get("AWS_REGION") or region
+        self._role_arn = os.environ.get("AWS_ROLE_ARN") or role_arn
+        self._role_session_name = (
+            os.environ.get("AWS_ROLE_SESSION_NAME") or role_session_name
+        )
+        self._relative_uri = (
+            os.environ.get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") or
+            relative_uri
         )
         if self._relative_uri and not self._relative_uri.startswith("/"):
             self._relative_uri = "/" + self._relative_uri
-        self._full_uri = os.environ.get("AWS_CONTAINER_CREDENTIALS_FULL_URI")
+        self._full_uri = (
+            os.environ.get("AWS_CONTAINER_CREDENTIALS_FULL_URI") or
+            full_uri
+        )
         self._credentials: Credentials | None = None
 
-    def fetch(self, url: str) -> Credentials:
+    def fetch(
+            self,
+            url: str,
+            headers: dict[str, str | list[str] | tuple[str]] | None = None,
+    ) -> Credentials:
         """Fetch credentials from EC2/ECS. """
 
-        res = _urlopen(self._http_client, "GET", url)
+        res = _urlopen(self._http_client, "GET", url, headers=headers)
         data = json.loads(res.data)
         if data.get("Code", "Success") != "Success":
             raise ValueError(
@@ -428,14 +457,17 @@ class IamAwsProvider(Provider):
             return self._credentials
 
         url = self._custom_endpoint
-        if self._token_file:
+
+        if self._identity_file:
             if not url:
                 url = "https://sts.amazonaws.com"
                 if self._aws_region:
                     url = f"https://sts.{self._aws_region}.amazonaws.com"
+                    if self._aws_region.startswith("cn-"):
+                        url += ".cn"
 
             provider = WebIdentityProvider(
-                lambda: _get_jwt_token(cast(str, self._token_file)),
+                lambda: _get_jwt_token(cast(str, self._identity_file)),
                 url,
                 role_arn=self._role_arn,
                 role_session_name=self._role_session_name,
@@ -444,21 +476,48 @@ class IamAwsProvider(Provider):
             self._credentials = provider.retrieve()
             return cast(Credentials, self._credentials)
 
+        headers: dict[str, str | list[str] | tuple[str]] | None = None
         if self._relative_uri:
             if not url:
                 url = "http://169.254.170.2" + self._relative_uri
+            headers = {"Authorization": self._token} if self._token else None
         elif self._full_uri:
-            if not url:
+            token = self._token
+            if self._token_file:
                 url = self._full_uri
-            _check_loopback_host(url)
+                with open(self._token_file, encoding="utf-8") as file:
+                    token = file.read()
+            else:
+                if not url:
+                    url = self._full_uri
+                    _check_loopback_host(url)
+            headers = {"Authorization": token} if token else None
         else:
             if not url:
-                url = (
-                    "http://169.254.169.254" +
-                    "/latest/meta-data/iam/security-credentials/"
-                )
-
-            res = _urlopen(self._http_client, "GET", url)
+                url = "http://169.254.169.254"
+
+            # Get IMDS Token
+            res = _urlopen(
+                self._http_client,
+                "GET",
+                url+"/latest/api/token",
+                headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
+            )
+            token = res.data.decode("utf-8")
+            headers = {"X-aws-ec2-metadata-token": token} if token else None
+
+            # Get role name
+            res = _urlopen(
+                self._http_client,
+                "GET",
+                urlunsplit(
+                    url_replace(
+                        urlsplit(url),
+                        path="/latest/meta-data/iam/security-credentials/",
+                    ),
+                ),
+                headers=headers,
+            )
             role_names = res.data.decode("utf-8").split("\n")
             if not role_names:
                 raise ValueError(f"no IAM roles attached to EC2 service {url}")
@@ -467,7 +526,7 @@ class IamAwsProvider(Provider):
         if not url:
             raise ValueError("url is empty; this should not happen")
 
-        self._credentials = self.fetch(url)
+        self._credentials = self.fetch(url, headers)
         return self._credentials
 
 

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, thanks for the suggestion and helping me finish the PR :).

I didn't realize this package had drifted so far away from the Go implementation.

"""Fetch credentials from EC2/ECS."""
res = _urlopen(self._http_client, "GET", url, headers=headers)
data = json.loads(res.data)
if data.get("Code", "Success") != "Success":
raise ValueError(
Expand Down Expand Up @@ -457,17 +456,27 @@ def retrieve(self) -> Credentials:
"http://169.254.169.254" +
"/latest/meta-data/iam/security-credentials/"
)

res = _urlopen(self._http_client, "GET", url)
# Step 1 of the IMDSv2 protocol: get a token from the metadata
# service with a 6-hour TTL.
response = _urlopen(
self._http_client, "PUT",
"http://169.254.169.254/latest/api/token",
headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"}
)
# Step 2: get the role name from the metadata service, with the
# token as a header.
token_header: dict = {
"X-aws-ec2-metadata-token":
response.data.decode("utf-8").strip()
}
res = _urlopen(self._http_client, "GET", url, headers=token_header)
role_names = res.data.decode("utf-8").split("\n")
if not role_names:
raise ValueError(f"no IAM roles attached to EC2 service {url}")
url += "/" + role_names[0].strip("\r")

if not url:
raise ValueError("url is empty; this should not happen")

self._credentials = self.fetch(url)
self._credentials = self.fetch(url, headers=token_header)
return self._credentials


Expand Down
11 changes: 9 additions & 2 deletions tests/unit/credentials_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,16 @@ def test_credentials_get(self):
self.assertEqual(creds.session_token, None)


class CredListResponse(object):
class TokenResponse(object):
harshavardhana marked this conversation as resolved.
Show resolved Hide resolved
status = 200
data = b"test-s3-full-access-for-minio-ec2"


class CredListResponse(object):
status = 200
data = b"test-s3-token-for-minio"


class CredsResponse(object):
status = 200
data = json.dumps({
Expand All @@ -66,7 +71,9 @@ class CredsResponse(object):
class IamAwsProviderTest(TestCase):
@mock.patch("urllib3.PoolManager.urlopen")
def test_iam(self, mock_connection):
mock_connection.side_effect = [CredListResponse(), CredsResponse()]
mock_connection.side_effect = [
TokenResponse(), CredListResponse(), CredsResponse()
]
provider = IamAwsProvider()
creds = provider.retrieve()
self.assertEqual(creds.access_key, "accessKey")
Expand Down
Loading