diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_keys_async.py b/sdk/keyvault/azure-keyvault-keys/tests/test_keys_async.py index 666c5ed63db6..df7e3dabea84 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_keys_async.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_keys_async.py @@ -52,6 +52,7 @@ def emit(self, record): self.messages.append(record) +@pytest.mark.usefixtures("recorded_test", "variable_recorder") class TestKeyVaultKey(KeyVaultTestCase, KeysTestCase): def _assert_jwks_equal(self, jwk1, jwk2): @@ -175,7 +176,6 @@ def _to_bytes(hex): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",all_api_versions) @AsyncKeysClientPreparer() - @recorded_by_proxy_async async def test_key_crud_operations(self, client, is_hsm, **kwargs): assert client is not None @@ -242,7 +242,6 @@ async def test_key_crud_operations(self, client, is_hsm, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",only_hsm) @AsyncKeysClientPreparer() - @recorded_by_proxy_async async def test_rsa_public_exponent(self, client, **kwargs): """The public exponent of a Managed HSM RSA key can be specified during creation""" assert client is not None @@ -255,7 +254,6 @@ async def test_rsa_public_exponent(self, client, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",all_api_versions) @AsyncKeysClientPreparer() - @recorded_by_proxy_async async def test_backup_restore(self, client, is_hsm, **kwargs): assert client is not None @@ -283,7 +281,6 @@ async def test_backup_restore(self, client, is_hsm, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",all_api_versions) @AsyncKeysClientPreparer() - @recorded_by_proxy_async async def test_key_list(self, client, is_hsm, **kwargs): assert client is not None @@ -307,7 +304,6 @@ async def test_key_list(self, client, is_hsm, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",all_api_versions) @AsyncKeysClientPreparer() - @recorded_by_proxy_async async def test_list_versions(self, client, is_hsm, **kwargs): assert client is not None @@ -334,7 +330,6 @@ async def test_list_versions(self, client, is_hsm, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",all_api_versions) @AsyncKeysClientPreparer() - @recorded_by_proxy_async async def test_list_deleted_keys(self, client, is_hsm, **kwargs): assert client is not None @@ -366,7 +361,6 @@ async def test_list_deleted_keys(self, client, is_hsm, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",all_api_versions) @AsyncKeysClientPreparer() - @recorded_by_proxy_async async def test_recover(self, client, is_hsm, **kwargs): assert client is not None @@ -397,7 +391,6 @@ async def test_recover(self, client, is_hsm, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",all_api_versions) @AsyncKeysClientPreparer() - @recorded_by_proxy_async async def test_purge(self, client, is_hsm, **kwargs): assert client is not None @@ -425,7 +418,6 @@ async def test_purge(self, client, is_hsm, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",logging_enabled) @AsyncKeysClientPreparer(logging_enable = True) - @recorded_by_proxy_async async def test_logging_enabled(self, client, is_hsm, **kwargs): mock_handler = MockHandler() @@ -461,7 +453,6 @@ async def test_logging_enabled(self, client, is_hsm, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",logging_disabled) @AsyncKeysClientPreparer(logging_enable = False) - @recorded_by_proxy_async async def test_logging_disabled(self, client, is_hsm, **kwargs): mock_handler = MockHandler() @@ -496,7 +487,6 @@ async def test_logging_disabled(self, client, is_hsm, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",only_hsm_7_3) @AsyncKeysClientPreparer() - @recorded_by_proxy_async async def test_get_random_bytes(self, client, **kwargs): assert client @@ -513,7 +503,6 @@ async def test_get_random_bytes(self, client, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",only_7_3) @AsyncKeysClientPreparer() - @recorded_by_proxy_async async def test_key_release(self, client, **kwargs): set_bodiless_matcher() attestation_uri = self._get_attestation_uri() @@ -534,7 +523,6 @@ async def test_key_release(self, client, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",only_hsm_7_3) @AsyncKeysClientPreparer() - @recorded_by_proxy_async async def test_imported_key_release(self, client, **kwargs): set_bodiless_matcher() attestation_uri = self._get_attestation_uri() @@ -555,7 +543,6 @@ async def test_imported_key_release(self, client, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",only_7_3) @AsyncKeysClientPreparer() - @recorded_by_proxy_async async def test_update_release_policy(self, client, **kwargs): set_bodiless_matcher() attestation_uri = self._get_attestation_uri() @@ -598,7 +585,6 @@ async def test_update_release_policy(self, client, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",only_vault_7_3) @AsyncKeysClientPreparer() - @recorded_by_proxy_async async def test_immutable_release_policy(self, client, **kwargs): set_bodiless_matcher() attestation_uri = self._get_attestation_uri() @@ -633,7 +619,6 @@ async def test_immutable_release_policy(self, client, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",only_vault_7_3) @AsyncKeysClientPreparer() - @recorded_by_proxy_async async def test_key_rotation(self, client, **kwargs): set_bodiless_matcher() if (not is_public_cloud() and self.is_live): @@ -651,7 +636,6 @@ async def test_key_rotation(self, client, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",only_vault_7_3) @AsyncKeysClientPreparer() - @recorded_by_proxy_async async def test_key_rotation_policy(self, client, **kwargs): set_bodiless_matcher() if (not is_public_cloud() and self.is_live): @@ -724,7 +708,6 @@ async def test_key_rotation_policy(self, client, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("api_version,is_hsm",all_api_versions) @AsyncKeysClientPreparer() - @recorded_by_proxy_async async def test_get_cryptography_client(self, client, is_hsm, **kwargs): key_name = self.get_resource_name("key-name") key = await self._create_rsa_key(client, key_name, hardware_protected=is_hsm) diff --git a/tools/azure-sdk-tools/devtools_testutils/proxy_testcase.py b/tools/azure-sdk-tools/devtools_testutils/proxy_testcase.py index 3ed645a6a8db..f37a63b95b8b 100644 --- a/tools/azure-sdk-tools/devtools_testutils/proxy_testcase.py +++ b/tools/azure-sdk-tools/devtools_testutils/proxy_testcase.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +from inspect import iscoroutinefunction import logging import requests import six @@ -24,7 +25,7 @@ from .proxy_startup import test_proxy if TYPE_CHECKING: - from typing import Any, Dict, Optional, Tuple + from typing import Any, Callable, Dict, Optional, Tuple from azure.core.pipeline.transport import HttpRequest # To learn about how to migrate SDK tests to the test proxy, please refer to the migration guide at @@ -126,7 +127,7 @@ def transform_request(request: "HttpRequest", recording_id: str) -> None: request.url = updated_target -def recorded_by_proxy(test_func) -> None: +def recorded_by_proxy(test_func: "Callable") -> None: """Decorator that redirects network requests to target the azure-sdk-tools test proxy. Use with recorded tests. For more details and usage examples, refer to @@ -212,10 +213,10 @@ def start_proxy_session() -> "Optional[Tuple[str, str, Dict[str, str]]]": @pytest.fixture -def recorded_test(test_proxy, request) -> "Dict[str, Any]": - """Fixture that redirects network requests to target the azure-sdk-tools test proxy. Use with recorded tests. +async def recorded_test(test_proxy: None, request: pytest.FixtureRequest) -> "Dict[str, Any]": + """Fixture that redirects network requests to target the azure-sdk-tools test proxy. - For more details and usage examples, refer to + Use with recorded tests. For more details and usage examples, refer to https://github.com/Azure/azure-sdk-for-python/blob/main/doc/dev/test_proxy_migration_guide.md. :param function test_proxy: The fixture responsible for starting up the test proxy server. @@ -223,7 +224,67 @@ def recorded_test(test_proxy, request) -> "Dict[str, Any]": :yields: A dictionary containing information relevant to the currently executing test. """ + test_id, recording_id, variables = start_proxy_session() + + # True if the function requesting the fixture is an async test + if iscoroutinefunction(request._pyfuncitem.function): + original_transport_func = await redirect_async_traffic(recording_id) + yield {"variables": variables} # yield relevant test info and allow tests to run + restore_async_traffic(original_transport_func, request) + else: + original_transport_func = redirect_traffic(recording_id) + yield {"variables": variables} # yield relevant test info and allow tests to run + restore_traffic(original_transport_func, request) + + stop_record_or_playback(test_id, recording_id, variables) + + +async def redirect_async_traffic(recording_id: str) -> "Callable": + """Redirects asynchronous network requests to target the test proxy. + + :param str recording_id: Recording ID of the currently executing test. + + :returns: The original transport function used by the currently executing test. + """ + from azure.core.pipeline.transport import AioHttpTransport + + original_transport_func = AioHttpTransport.send + + def transform_args(*args, **kwargs): + copied_positional_args = list(args) + request = copied_positional_args[1] + + transform_request(request, recording_id) + + return tuple(copied_positional_args), kwargs + + async def combined_call(*args, **kwargs): + adjusted_args, adjusted_kwargs = transform_args(*args, **kwargs) + result = await original_transport_func(*adjusted_args, **adjusted_kwargs) + + # make the x-recording-upstream-base-uri the URL of the request + # this makes the request look like it was made to the original endpoint instead of to the proxy + # without this, things like LROPollers can get broken by polling the wrong endpoint + parsed_result = url_parse.urlparse(result.request.url) + upstream_uri = url_parse.urlparse(result.request.headers["x-recording-upstream-base-uri"]) + upstream_uri_dict = {"scheme": upstream_uri.scheme, "netloc": upstream_uri.netloc} + original_target = parsed_result._replace(**upstream_uri_dict).geturl() + + result.request.url = original_target + return result + + AioHttpTransport.send = combined_call + return original_transport_func + + +def redirect_traffic(recording_id: str) -> "Callable": + """Redirects network requests to target the test proxy. + + :param str recording_id: Recording ID of the currently executing test. + + :returns: The original transport function used by the currently executing test. + """ original_transport_func = RequestsTransport.send def transform_args(*args, **kwargs): @@ -250,12 +311,20 @@ def combined_call(*args, **kwargs): return result RequestsTransport.send = combined_call + return original_transport_func - # store info pertinent to the test in a dictionary that other fixtures can access - test_info = {"variables": variables} - yield test_info # yield and allow test to run - RequestsTransport.send = original_transport_func # test finished running -- tear down +def restore_async_traffic(original_transport_func: "Callable", request: pytest.FixtureRequest) -> None: + """Resets asynchronous network traffic to no longer target the test proxy. + + :param original_transport_func: The original transport function used by the currently executing test. + :type original_transport_func: Callable + :param request: The built-in `request` pytest fixture. + :type request: ~pytest.FixtureRequest + """ + from azure.core.pipeline.transport import AioHttpTransport + + AioHttpTransport.send = original_transport_func # test finished running -- tear down if hasattr(request.node, "test_error"): # Exceptions are logged here instead of being raised because of how pytest handles error raising from inside @@ -270,11 +339,33 @@ def combined_call(*args, **kwargs): logger = logging.getLogger() logger.error(f"\n\n-----Test proxy playback error:-----\n\n{message}") - stop_record_or_playback(test_id, recording_id, variables) + +def restore_traffic(original_transport_func: "Callable", request: pytest.FixtureRequest) -> None: + """Resets network traffic to no longer target the test proxy. + + :param original_transport_func: The original transport function used by the currently executing test. + :type original_transport_func: Callable + :param request: The built-in `request` pytest fixture. + :type request: ~pytest.FixtureRequest + """ + RequestsTransport.send = original_transport_func # test finished running -- tear down + + if hasattr(request.node, "test_error"): + # Exceptions are logged here instead of being raised because of how pytest handles error raising from inside + # fixtures and hooks. Raising from a fixture raises an error in addition to the test failure report, and the + # test proxy error is logged before the test failure output (making it difficult to find in pytest output). + # Raising from a hook isn't allowed, and produces an internal error that disrupts test execution. + # ResourceNotFoundErrors during playback indicate a recording mismatch + error = request.node.test_error + if isinstance(error, ResourceNotFoundError): + error_body = ContentDecodePolicy.deserialize_from_http_generics(error.response) + message = error_body.get("message") or error_body.get("Message") + logger = logging.getLogger() + logger.error(f"\n\n-----Test proxy playback error:-----\n\n{message}") @pytest.fixture -def variable_recorder(recorded_test) -> "Dict[str, str]": +def variable_recorder(recorded_test: "Dict[str, Any]") -> "Dict[str, str]": """Fixture that invokes the `recorded_test` fixture and returns a dictionary of recorded test variables. :param function recorded_test: The fixture responsible for redirecting network traffic to target the test proxy.