diff --git a/tests/core/middleware/test_name_to_address_middleware.py b/tests/core/middleware/test_name_to_address_middleware.py index da8406ef4b..c40cfc7deb 100644 --- a/tests/core/middleware/test_name_to_address_middleware.py +++ b/tests/core/middleware/test_name_to_address_middleware.py @@ -8,9 +8,14 @@ InvalidAddress, ) from web3.middleware import ( # noqa: F401 + async_construct_fixture_middleware, + async_name_to_address_middleware, construct_fixture_middleware, name_to_address_middleware, ) +from web3.providers.async_base import ( + AsyncBaseProvider, +) from web3.providers.base import ( BaseProvider, ) @@ -57,3 +62,37 @@ def test_fail_name_resolver(w3): w3.middleware_onion.inject(return_chain_on_mainnet, layer=0) with pytest.raises(InvalidAddress, match=r".*ethereum\.eth.*"): w3.eth.get_balance("ethereum.eth") + +# --- async --- # + + +@pytest.fixture +def async_w3(): + async_w3 = Web3(provider=AsyncBaseProvider(), middlewares=[]) + async_w3.ens = TempENS({NAME: ADDRESS}) + async_w3.middleware_onion.add(async_name_to_address_middleware(async_w3)) + return async_w3 + +@pytest.mark.asyncio +async def test_async_pass_name_resolver(async_w3): + return_chain_on_mainnet = await async_construct_fixture_middleware( + { + "net_version": "1", + } + ) + return_balance = await async_construct_fixture_middleware({"eth_getBalance": BALANCE}) + async_w3.middleware_onion.inject(return_chain_on_mainnet, layer=0) + async_w3.middleware_onion.inject(return_balance, layer=0) + assert await async_w3.eth.get_balance(NAME) == BALANCE + + +@pytest.mark.asyncio +async def test_async_fail_name_resolver(async_w3): + return_chain_on_mainnet = async_construct_fixture_middleware( + { + "net_version": "2", + } + ) + async_w3.middleware_onion.inject(return_chain_on_mainnet, layer=0) + with pytest.raises(InvalidAddress, match=r".*ethereum\.eth.*"): + await async_w3.eth.get_balance("ethereum.eth") diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index 7c0ddb8c1f..5a342adb39 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -42,6 +42,7 @@ local_filter_middleware, ) from .fixture import ( # noqa: F401 + async_construct_fixture_middleware, construct_error_generator_middleware, construct_fixture_middleware, construct_result_generator_middleware, @@ -58,6 +59,7 @@ geth_poa_middleware, ) from .names import ( # noqa: F401 + async_name_to_address_middleware, name_to_address_middleware, ) from .normalize_request_parameters import ( # noqa: F401 diff --git a/web3/middleware/fixture.py b/web3/middleware/fixture.py index 24a567e68e..48c451cd19 100644 --- a/web3/middleware/fixture.py +++ b/web3/middleware/fixture.py @@ -90,6 +90,27 @@ def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: # --- async --- # +async def async_construct_fixture_middleware(fixtures: Dict[RPCEndpoint, Any]) -> Middleware: + """ + Constructs a middleware which returns a static response for any method + which is found in the provided fixtures. + """ + + async def fixture_middleware( + make_request: Callable[[RPCEndpoint, Any], Any], _: "Web3" + ) -> Callable[[RPCEndpoint, Any], RPCResponse]: + async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + if method in fixtures: + result = fixtures[method] + return {"result": result} + else: + return await make_request(method, params) + + return middleware + + return fixture_middleware + + async def async_construct_result_generator_middleware( result_generators: Dict[RPCEndpoint, Any] ) -> Middleware: