Skip to content
This repository has been archived by the owner on Apr 10, 2024. It is now read-only.

MSI timeout #131

Merged
merged 10 commits into from
Jun 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions msrestazure/azure_active_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

from msrestazure.azure_cloud import AZURE_CHINA_CLOUD, AZURE_PUBLIC_CLOUD
from msrestazure.azure_configuration import AzureConfiguration
from msrestazure.azure_exceptions import MSIAuthenticationTimeoutError

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -544,6 +545,7 @@ class MSIAuthentication(BasicTokenAuthentication):

Optional kwargs may include:

- timeout: If provided, must be in seconds and indicates the maximum time we'll try to get a token before raising MSIAuthenticationTimeout
- client_id: Identifies, by Azure AD client id, a specific explicit identity to use when authenticating to Azure AD. Mutually exclusive with object_id and msi_res_id.
- object_id: Identifies, by Azure AD object id, a specific explicit identity to use when authenticating to Azure AD. Mutually exclusive with client_id and msi_res_id.
- msi_res_id: Identifies, by ARM resource id, a specific explicit identity to use when authenticating to Azure AD. Mutually exclusive with client_id and object_id.
Expand Down Expand Up @@ -571,7 +573,10 @@ def __init__(self, port=50342, **kwargs):
raise AuthenticationError("User Assigned Entity is not available on WebApp yet.")
elif "MSI_ENDPOINT" not in os.environ:
# Use IMDS if no MSI_ENDPOINT
self._vm_msi = _ImdsTokenProvider(self.msi_conf)
self._vm_msi = _ImdsTokenProvider(
self.msi_conf,
timeout=kwargs.get("timeout")
)
# Follow the same convention as all Credentials class to check for the token at creation time #106
self.set_token()

Expand Down Expand Up @@ -603,7 +608,7 @@ class _ImdsTokenProvider(object):
"""A help class handling token acquisitions through Azure IMDS plugin.
"""

def __init__(self, msi_conf=None):
def __init__(self, msi_conf=None, timeout=None):
self._user_agent = AzureConfiguration(None).user_agent
self.identity_type, self.identity_id = None, None
if msi_conf:
Expand All @@ -614,6 +619,7 @@ def __init__(self, msi_conf=None):
# default to system assigned identity on an empty configuration object

self.cache = {}
self.timeout = timeout

def get_token(self, resource):
import datetime
Expand All @@ -633,6 +639,21 @@ def get_token(self, resource):
self.cache[resource] = token_entry
return token_entry

def _sleep(self, time_to_wait, start_time):
"""Sleep for time_to_wait or time remaining until timeout reached.

:param float time: Time to sleep in seconds
:param float start_time: Absolute time where polling started
:rtype: bool
:returns: True if timeout was used
"""
if self.timeout is not None: # 0 is acceptable value, so we really want to test None
time_to_sleep = max(0, min(time_to_wait, start_time + self.timeout - time.time()))
else:
time_to_sleep = time_to_wait
time.sleep(time_to_sleep)
return time_to_sleep != time_to_wait

def _retrieve_token_from_imds_with_retry(self, resource):
import random
import json
Expand All @@ -648,21 +669,24 @@ def _retrieve_token_from_imds_with_retry(self, resource):
retry, max_retry, start_time = 1, 12, time.time()
# simplified version of https://en.wikipedia.org/wiki/Exponential_backoff
slots = [100 * ((2 << x) - 1) / 1000 for x in range(max_retry)]
has_timed_out = self.timeout == 0 # Assume a 0 timeout means "no more than one try"
while True:
result = requests.get(request_uri, params=payload, headers={'Metadata': 'true', 'User-Agent':self._user_agent})
_LOGGER.debug("MSI: Retrieving a token from %s, with payload %s", request_uri, payload)
if result.status_code in [404, 410, 429] or (499 < result.status_code < 600):
if retry <= max_retry:
if has_timed_out: # It was the last try, and we still don't get a good status code, die
raise MSIAuthenticationTimeoutError('MSI: Failed to acquired tokens before timeout {}'.format(self.timeout))
elif retry <= max_retry:
wait = random.choice(slots[:retry])
_LOGGER.warning("MSI: wait: %ss and retry: %s", wait, retry)
time.sleep(wait)
has_timed_out = self._sleep(wait, start_time)
retry += 1
else:
if result.status_code == 410: # For IMDS upgrading, we wait up to 70s
gap = 70 - (time.time() - start_time)
if gap > 0:
_LOGGER.warning("MSI: wait till 70 seconds when IMDS is upgrading")
time.sleep(gap)
has_timed_out = self._sleep(gap, start_time)
continue
break
elif result.status_code != 200:
Expand All @@ -671,7 +695,7 @@ def _retrieve_token_from_imds_with_retry(self, resource):
break

if result.status_code != 200:
raise TimeoutError('MSI: Failed to acquire tokens after {} times'.format(max_retry))
raise MSIAuthenticationTimeoutError('MSI: Failed to acquire tokens after {} times'.format(max_retry))

_LOGGER.debug('MSI: Token retrieved')
token_entry = json.loads(result.content.decode())
Expand Down
13 changes: 13 additions & 0 deletions msrestazure/azure_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@
from msrest.serialization import Deserializer
from msrest.exceptions import DeserializationError

# TimeoutError for backward compat since it was used by former MSI code.
# but this never worked on Python 2.7, so Python 2.7 users get the correct one now
try:
class MSIAuthenticationTimeoutError(TimeoutError, ClientException):
"""If the MSI authentication reached the timeout without getting a token.
"""
pass
except NameError:
class MSIAuthenticationTimeoutError(ClientException):
"""If the MSI authentication reached the timeout without getting a token.
"""
pass

class CloudErrorRoot(object):
"""Just match the "error" key at the root of a OdataV4 JSON.
"""
Expand Down
82 changes: 71 additions & 11 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
get_msi_token_webapp
)
from msrestazure.azure_cloud import AZURE_CHINA_CLOUD
from msrestazure.azure_exceptions import MSIAuthenticationTimeoutError
from msrest.exceptions import TokenExpiredError, AuthenticationError

import pytest
Expand Down Expand Up @@ -512,37 +513,96 @@ def test_msi_vm(self):

@httpretty.activate
def test_msi_vm_imds_retry(self):

json_payload = {
'token_type': "TokenTypeIMDS",
"access_token": "AccessToken"
}
httpretty.register_uri(httpretty.GET,
'http://169.254.169.254/metadata/identity/oauth2/token',
status=404)
httpretty.register_uri(httpretty.GET,
'http://169.254.169.254/metadata/identity/oauth2/token',
status=429)
responses=[
httpretty.Response('', status=404),
httpretty.Response('', status=429),
httpretty.Response('', status=599),
httpretty.Response(body=json.dumps(json_payload)),
],
content_type="application/json")

credentials = MSIAuthentication()
assert credentials.scheme == "TokenTypeIMDS"
assert credentials.token == json_payload

# Assert four requests made only
assert len(httpretty.httpretty.latest_requests) == 4


@httpretty.activate
def test_msi_vm_imds_no_retry_on_bad_error(self):
"""Check that 499 throws immediatly."""
httpretty.register_uri(httpretty.GET,
'http://169.254.169.254/metadata/identity/oauth2/token',
status=599)
status=499)
with self.assertRaises(HTTPError):
MSIAuthentication()

# Assert one request made only
assert len(httpretty.httpretty.latest_requests) == 1

@httpretty.activate
def test_msi_vm_imds_timeout_not_used(self):
"""Check that using timeout still allows a successfull scenario to pass."""
json_payload = {
'token_type': "TokenTypeIMDS",
"access_token": "AccessToken"
}
httpretty.register_uri(httpretty.GET,
'http://169.254.169.254/metadata/identity/oauth2/token',
body=json.dumps(json_payload),
content_type="application/json")
credentials = MSIAuthentication()

credentials = MSIAuthentication(timeout=15)
assert credentials.scheme == "TokenTypeIMDS"
assert credentials.token == json_payload

@httpretty.activate
def test_msi_vm_imds_timeout_used(self):
"""Will loop on 410 until timeout is reached."""
httpretty.register_uri(httpretty.GET,
'http://169.254.169.254/metadata/identity/oauth2/token',
status=410)

start_time = time.time()
with self.assertRaises(MSIAuthenticationTimeoutError):
MSIAuthentication(timeout=1)
# Test should take 1 second, but testing against 2 in case machine busy
assert time.time() - start_time < 2
# Assert at least two requests have been made
assert len(httpretty.httpretty.latest_requests) >= 2

@httpretty.activate
def test_msi_vm_imds_no_retry_on_bad_error(self):
def test_msi_vm_imds_timeout_zero_used(self):
"""If zero timeout, should do a try and fail immediatly."""
httpretty.register_uri(httpretty.GET,
'http://169.254.169.254/metadata/identity/oauth2/token',
status=410)

with self.assertRaises(MSIAuthenticationTimeoutError):
MSIAuthentication(timeout=0)
# Assert one request made only
assert len(httpretty.httpretty.latest_requests) == 1

@unittest.skipIf(sys.version_info != (2,7), "TimeoutError doesn't exist in Py 2.7")
@httpretty.activate
def test_msi_vm_imds_timeout_used_timeouterror(self):
"""Will loop on 410 until timeout is reached."""
httpretty.register_uri(httpretty.GET,
'http://169.254.169.254/metadata/identity/oauth2/token',
status=499)
with self.assertRaises(HTTPError) as cm:
credentials = MSIAuthentication()
status=410)

# Verify that I can catch TimeoutError as well
with self.assertRaises(TimeoutError):
MSIAuthentication(timeout=1)
# Assert at two requests made only
assert len(httpretty.httpretty.latest_requests) >= 2


@pytest.mark.slow
Expand Down