Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update IamAwsProvider as per minio-go implementation #1437

Merged
merged 7 commits into from
Aug 22, 2024

Conversation

setu4993
Copy link
Contributor

@setu4993 setu4993 commented Aug 20, 2024

This PR updates the IamAwsProvider to use IMDSv2, which is automatically enabled for all AWS EC2 / ECS instances.

I'm using this on an internal project successfully since ~a week, so contributing it back upstream.

Closes #1411.

"""Fetch credentials from EC2/ECS. """

res = _urlopen(self._http_client, "GET", url)
def fetch(self, url: str, headers: dict) -> Credentials:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @balamurugana . Can you please elaborate what that looks like? First time contributor here and didn't see a note about it in the contributing doc.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am saying we should use the same logic as in minio-go. Below is the actual implementation

diff --git a/minio/credentials/providers.py b/minio/credentials/providers.py
index 51dca1f..8790054 100644
--- a/minio/credentials/providers.py
+++ b/minio/credentials/providers.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# pylint: disable=too-many-branches
+
 """Credential providers."""
 
 from __future__ import annotations
@@ -29,7 +31,7 @@ from abc import ABCMeta, abstractmethod
 from datetime import timedelta
 from pathlib import Path
 from typing import Callable, cast
-from urllib.parse import urlencode, urlsplit
+from urllib.parse import urlencode, urlsplit, urlunsplit
 from xml.etree import ElementTree as ET
 
 import certifi
@@ -42,7 +44,7 @@ except ImportError:
 
 from urllib3.util import Retry, parse_url
 
-from minio.helpers import sha256_hash
+from minio.helpers import sha256_hash, url_replace
 from minio.signer import sign_v4_sts
 from minio.time import from_iso8601utc, to_amz_date, utcnow
 from minio.xml import find, findtext
@@ -381,6 +383,13 @@ class IamAwsProvider(Provider):
             self,
             custom_endpoint: str | None = None,
             http_client: PoolManager | None = None,
+            auth_token: str | None = None,
+            relative_uri: str | None = None,
+            full_uri: str | None = None,
+            token_file: str | None = None,
+            role_arn: str | None = None,
+            role_session_name: str | None = None,
+            region: str | None = None,
     ):
         self._custom_endpoint = custom_endpoint
         self._http_client = http_client or PoolManager(
@@ -390,22 +399,42 @@ class IamAwsProvider(Provider):
                 status_forcelist=[500, 502, 503, 504],
             ),
         )
-        self._token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
-        self._aws_region = os.environ.get("AWS_REGION")
-        self._role_arn = os.environ.get("AWS_ROLE_ARN")
-        self._role_session_name = os.environ.get("AWS_ROLE_SESSION_NAME")
-        self._relative_uri = os.environ.get(
-            "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
+        self._token = (
+            os.environ.get("AWS_CONTAINER_AUTHORIZATION_TOKEN") or
+            auth_token
+        )
+        self._token_file = (
+            os.environ.get("AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE") or
+            auth_token
+        )
+        self._identity_file = (
+            os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE") or token_file
+        )
+        self._aws_region = os.environ.get("AWS_REGION") or region
+        self._role_arn = os.environ.get("AWS_ROLE_ARN") or role_arn
+        self._role_session_name = (
+            os.environ.get("AWS_ROLE_SESSION_NAME") or role_session_name
+        )
+        self._relative_uri = (
+            os.environ.get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") or
+            relative_uri
         )
         if self._relative_uri and not self._relative_uri.startswith("/"):
             self._relative_uri = "/" + self._relative_uri
-        self._full_uri = os.environ.get("AWS_CONTAINER_CREDENTIALS_FULL_URI")
+        self._full_uri = (
+            os.environ.get("AWS_CONTAINER_CREDENTIALS_FULL_URI") or
+            full_uri
+        )
         self._credentials: Credentials | None = None
 
-    def fetch(self, url: str) -> Credentials:
+    def fetch(
+            self,
+            url: str,
+            headers: dict[str, str | list[str] | tuple[str]] | None = None,
+    ) -> Credentials:
         """Fetch credentials from EC2/ECS. """
 
-        res = _urlopen(self._http_client, "GET", url)
+        res = _urlopen(self._http_client, "GET", url, headers=headers)
         data = json.loads(res.data)
         if data.get("Code", "Success") != "Success":
             raise ValueError(
@@ -428,14 +457,17 @@ class IamAwsProvider(Provider):
             return self._credentials
 
         url = self._custom_endpoint
-        if self._token_file:
+
+        if self._identity_file:
             if not url:
                 url = "https://sts.amazonaws.com"
                 if self._aws_region:
                     url = f"https://sts.{self._aws_region}.amazonaws.com"
+                    if self._aws_region.startswith("cn-"):
+                        url += ".cn"
 
             provider = WebIdentityProvider(
-                lambda: _get_jwt_token(cast(str, self._token_file)),
+                lambda: _get_jwt_token(cast(str, self._identity_file)),
                 url,
                 role_arn=self._role_arn,
                 role_session_name=self._role_session_name,
@@ -444,21 +476,48 @@ class IamAwsProvider(Provider):
             self._credentials = provider.retrieve()
             return cast(Credentials, self._credentials)
 
+        headers: dict[str, str | list[str] | tuple[str]] | None = None
         if self._relative_uri:
             if not url:
                 url = "http://169.254.170.2" + self._relative_uri
+            headers = {"Authorization": self._token} if self._token else None
         elif self._full_uri:
-            if not url:
+            token = self._token
+            if self._token_file:
                 url = self._full_uri
-            _check_loopback_host(url)
+                with open(self._token_file, encoding="utf-8") as file:
+                    token = file.read()
+            else:
+                if not url:
+                    url = self._full_uri
+                    _check_loopback_host(url)
+            headers = {"Authorization": token} if token else None
         else:
             if not url:
-                url = (
-                    "http://169.254.169.254" +
-                    "/latest/meta-data/iam/security-credentials/"
-                )
-
-            res = _urlopen(self._http_client, "GET", url)
+                url = "http://169.254.169.254"
+
+            # Get IMDS Token
+            res = _urlopen(
+                self._http_client,
+                "GET",
+                url+"/latest/api/token",
+                headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
+            )
+            token = res.data.decode("utf-8")
+            headers = {"X-aws-ec2-metadata-token": token} if token else None
+
+            # Get role name
+            res = _urlopen(
+                self._http_client,
+                "GET",
+                urlunsplit(
+                    url_replace(
+                        urlsplit(url),
+                        path="/latest/meta-data/iam/security-credentials/",
+                    ),
+                ),
+                headers=headers,
+            )
             role_names = res.data.decode("utf-8").split("\n")
             if not role_names:
                 raise ValueError(f"no IAM roles attached to EC2 service {url}")
@@ -467,7 +526,7 @@ class IamAwsProvider(Provider):
         if not url:
             raise ValueError("url is empty; this should not happen")
 
-        self._credentials = self.fetch(url)
+        self._credentials = self.fetch(url, headers)
         return self._credentials
 
 

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, thanks for the suggestion and helping me finish the PR :).

I didn't realize this package had drifted so far away from the Go implementation.

"""Fetch credentials from EC2/ECS. """

res = _urlopen(self._http_client, "GET", url)
def fetch(self, url: str, headers: dict) -> Credentials:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am saying we should use the same logic as in minio-go. Below is the actual implementation

diff --git a/minio/credentials/providers.py b/minio/credentials/providers.py
index 51dca1f..8790054 100644
--- a/minio/credentials/providers.py
+++ b/minio/credentials/providers.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# pylint: disable=too-many-branches
+
 """Credential providers."""
 
 from __future__ import annotations
@@ -29,7 +31,7 @@ from abc import ABCMeta, abstractmethod
 from datetime import timedelta
 from pathlib import Path
 from typing import Callable, cast
-from urllib.parse import urlencode, urlsplit
+from urllib.parse import urlencode, urlsplit, urlunsplit
 from xml.etree import ElementTree as ET
 
 import certifi
@@ -42,7 +44,7 @@ except ImportError:
 
 from urllib3.util import Retry, parse_url
 
-from minio.helpers import sha256_hash
+from minio.helpers import sha256_hash, url_replace
 from minio.signer import sign_v4_sts
 from minio.time import from_iso8601utc, to_amz_date, utcnow
 from minio.xml import find, findtext
@@ -381,6 +383,13 @@ class IamAwsProvider(Provider):
             self,
             custom_endpoint: str | None = None,
             http_client: PoolManager | None = None,
+            auth_token: str | None = None,
+            relative_uri: str | None = None,
+            full_uri: str | None = None,
+            token_file: str | None = None,
+            role_arn: str | None = None,
+            role_session_name: str | None = None,
+            region: str | None = None,
     ):
         self._custom_endpoint = custom_endpoint
         self._http_client = http_client or PoolManager(
@@ -390,22 +399,42 @@ class IamAwsProvider(Provider):
                 status_forcelist=[500, 502, 503, 504],
             ),
         )
-        self._token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
-        self._aws_region = os.environ.get("AWS_REGION")
-        self._role_arn = os.environ.get("AWS_ROLE_ARN")
-        self._role_session_name = os.environ.get("AWS_ROLE_SESSION_NAME")
-        self._relative_uri = os.environ.get(
-            "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
+        self._token = (
+            os.environ.get("AWS_CONTAINER_AUTHORIZATION_TOKEN") or
+            auth_token
+        )
+        self._token_file = (
+            os.environ.get("AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE") or
+            auth_token
+        )
+        self._identity_file = (
+            os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE") or token_file
+        )
+        self._aws_region = os.environ.get("AWS_REGION") or region
+        self._role_arn = os.environ.get("AWS_ROLE_ARN") or role_arn
+        self._role_session_name = (
+            os.environ.get("AWS_ROLE_SESSION_NAME") or role_session_name
+        )
+        self._relative_uri = (
+            os.environ.get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") or
+            relative_uri
         )
         if self._relative_uri and not self._relative_uri.startswith("/"):
             self._relative_uri = "/" + self._relative_uri
-        self._full_uri = os.environ.get("AWS_CONTAINER_CREDENTIALS_FULL_URI")
+        self._full_uri = (
+            os.environ.get("AWS_CONTAINER_CREDENTIALS_FULL_URI") or
+            full_uri
+        )
         self._credentials: Credentials | None = None
 
-    def fetch(self, url: str) -> Credentials:
+    def fetch(
+            self,
+            url: str,
+            headers: dict[str, str | list[str] | tuple[str]] | None = None,
+    ) -> Credentials:
         """Fetch credentials from EC2/ECS. """
 
-        res = _urlopen(self._http_client, "GET", url)
+        res = _urlopen(self._http_client, "GET", url, headers=headers)
         data = json.loads(res.data)
         if data.get("Code", "Success") != "Success":
             raise ValueError(
@@ -428,14 +457,17 @@ class IamAwsProvider(Provider):
             return self._credentials
 
         url = self._custom_endpoint
-        if self._token_file:
+
+        if self._identity_file:
             if not url:
                 url = "https://sts.amazonaws.com"
                 if self._aws_region:
                     url = f"https://sts.{self._aws_region}.amazonaws.com"
+                    if self._aws_region.startswith("cn-"):
+                        url += ".cn"
 
             provider = WebIdentityProvider(
-                lambda: _get_jwt_token(cast(str, self._token_file)),
+                lambda: _get_jwt_token(cast(str, self._identity_file)),
                 url,
                 role_arn=self._role_arn,
                 role_session_name=self._role_session_name,
@@ -444,21 +476,48 @@ class IamAwsProvider(Provider):
             self._credentials = provider.retrieve()
             return cast(Credentials, self._credentials)
 
+        headers: dict[str, str | list[str] | tuple[str]] | None = None
         if self._relative_uri:
             if not url:
                 url = "http://169.254.170.2" + self._relative_uri
+            headers = {"Authorization": self._token} if self._token else None
         elif self._full_uri:
-            if not url:
+            token = self._token
+            if self._token_file:
                 url = self._full_uri
-            _check_loopback_host(url)
+                with open(self._token_file, encoding="utf-8") as file:
+                    token = file.read()
+            else:
+                if not url:
+                    url = self._full_uri
+                    _check_loopback_host(url)
+            headers = {"Authorization": token} if token else None
         else:
             if not url:
-                url = (
-                    "http://169.254.169.254" +
-                    "/latest/meta-data/iam/security-credentials/"
-                )
-
-            res = _urlopen(self._http_client, "GET", url)
+                url = "http://169.254.169.254"
+
+            # Get IMDS Token
+            res = _urlopen(
+                self._http_client,
+                "GET",
+                url+"/latest/api/token",
+                headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
+            )
+            token = res.data.decode("utf-8")
+            headers = {"X-aws-ec2-metadata-token": token} if token else None
+
+            # Get role name
+            res = _urlopen(
+                self._http_client,
+                "GET",
+                urlunsplit(
+                    url_replace(
+                        urlsplit(url),
+                        path="/latest/meta-data/iam/security-credentials/",
+                    ),
+                ),
+                headers=headers,
+            )
             role_names = res.data.decode("utf-8").split("\n")
             if not role_names:
                 raise ValueError(f"no IAM roles attached to EC2 service {url}")
@@ -467,7 +526,7 @@ class IamAwsProvider(Provider):
         if not url:
             raise ValueError("url is empty; this should not happen")
 
-        self._credentials = self.fetch(url)
+        self._credentials = self.fetch(url, headers)
         return self._credentials
 
 

# Get IMDS Token
res = _urlopen(
self._http_client,
"PUT",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@balamurugana : This line differs from your suggestion because this needs to be a PUT request, not GET.

minio-go also has this as a PUT here.

@setu4993
Copy link
Contributor Author

Just pushed up the updates.

tests/unit/credentials_test.py Outdated Show resolved Hide resolved
@balamurugana balamurugana changed the title Support IMDSv2 credential fetch with AWS IAM provider update IamAwsProvider as per minio-go implementation Aug 22, 2024
@harshavardhana harshavardhana merged commit f673f09 into minio:master Aug 22, 2024
15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

IMDS V2 support for AWS
3 participants