From 2c53132ca9bda7cd128f3b094a79150440b81d33 Mon Sep 17 00:00:00 2001 From: Felipe Selmo Date: Fri, 16 Jul 2021 13:04:25 -0600 Subject: [PATCH] formatting and validation middleware async support --- .../go_ethereum/test_goethereum_http.py | 4 +- web3/_utils/module_testing/eth_module.py | 48 +++++++ web3/eth.py | 15 ++- web3/manager.py | 6 +- web3/middleware/__init__.py | 1 + web3/middleware/formatting.py | 118 ++++++++++++------ web3/middleware/validation.py | 65 +++++++--- web3/tools/benchmark/main.py | 14 ++- 8 files changed, 202 insertions(+), 69 deletions(-) diff --git a/tests/integration/go_ethereum/test_goethereum_http.py b/tests/integration/go_ethereum/test_goethereum_http.py index 793daffd6c..f20b74cb9e 100644 --- a/tests/integration/go_ethereum/test_goethereum_http.py +++ b/tests/integration/go_ethereum/test_goethereum_http.py @@ -10,6 +10,7 @@ from web3.middleware import ( async_buffered_gas_estimate_middleware, async_gas_price_strategy_middleware, + async_validation_middleware, ) from web3.net import ( AsyncNet, @@ -85,8 +86,9 @@ async def async_w3(geth_process, endpoint_uri): _web3 = Web3( AsyncHTTPProvider(endpoint_uri), middlewares=[ + async_buffered_gas_estimate_middleware, async_gas_price_strategy_middleware, - async_buffered_gas_estimate_middleware + await async_validation_middleware, ], modules={'eth': (AsyncEth,), 'async_net': (AsyncNet,)}) return _web3 diff --git a/web3/_utils/module_testing/eth_module.py b/web3/_utils/module_testing/eth_module.py index dbd0effac6..ce00973015 100644 --- a/web3/_utils/module_testing/eth_module.py +++ b/web3/_utils/module_testing/eth_module.py @@ -45,6 +45,7 @@ NameNotFound, TransactionNotFound, TransactionTypeMismatch, + ValidationError, ) from web3.types import ( # noqa: F401 BlockData, @@ -261,6 +262,30 @@ async def test_eth_send_transaction_max_fee_less_than_tip( ): await async_w3.eth.send_transaction(txn_params) # type: ignore + @pytest.mark.asyncio + async def test_validation_middleware_chain_id_mismatch( + self, async_w3: "Web3", unlocked_account_dual_type: ChecksumAddress + ) -> None: + wrong_chain_id = 1234567890 + actual_chain_id = await async_w3.eth.chain_id # type: ignore + + txn_params: TxParams = { + 'from': unlocked_account_dual_type, + 'to': unlocked_account_dual_type, + 'value': Wei(1), + 'gas': Wei(21000), + 'maxFeePerGas': async_w3.toWei(2, 'gwei'), + 'maxPriorityFeePerGas': async_w3.toWei(1, 'gwei'), + 'chainId': wrong_chain_id, + + } + with pytest.raises( + ValidationError, + match=f'The transaction declared chain ID {wrong_chain_id}, ' + f'but the connected node is on {actual_chain_id}' + ): + await async_w3.eth.send_transaction(txn_params) # type: ignore + @pytest.mark.asyncio async def test_eth_send_raw_transaction(self, async_w3: "Web3") -> None: # private key 0x3c2ab4e8f17a7dea191b8c991522660126d681039509dc3bb31af7c9bdb63518 @@ -1518,6 +1543,29 @@ def test_eth_send_transaction_max_fee_less_than_tip( ): web3.eth.send_transaction(txn_params) + def test_validation_middleware_chain_id_mismatch( + self, web3: "Web3", unlocked_account_dual_type: ChecksumAddress + ) -> None: + wrong_chain_id = 1234567890 + actual_chain_id = web3.eth.chain_id + + txn_params: TxParams = { + 'from': unlocked_account_dual_type, + 'to': unlocked_account_dual_type, + 'value': Wei(1), + 'gas': Wei(21000), + 'maxFeePerGas': web3.toWei(2, 'gwei'), + 'maxPriorityFeePerGas': web3.toWei(1, 'gwei'), + 'chainId': wrong_chain_id, + + } + with pytest.raises( + ValidationError, + match=f'The transaction declared chain ID {wrong_chain_id}, ' + f'but the connected node is on {actual_chain_id}' + ): + web3.eth.send_transaction(txn_params) + @pytest.mark.parametrize( "max_fee", (1000000000, None), diff --git a/web3/eth.py b/web3/eth.py index 1502facdd6..5e1a80b614 100644 --- a/web3/eth.py +++ b/web3/eth.py @@ -116,6 +116,11 @@ class BaseEth(Module): mungers=None, ) + _chain_id: Method[Callable[[], int]] = Method( + RPC.eth_chainId, + mungers=None, + ) + """ property default_block """ @property def default_block(self) -> BlockIdentifier: @@ -253,6 +258,11 @@ def call_munger( class AsyncEth(BaseEth): is_async = True + @property + async def chain_id(self) -> int: + # types ignored b/c mypy conflict with BlockingEth properties + return await self._chain_id() # type: ignore + @property async def gas_price(self) -> Wei: # types ignored b/c mypy conflict with BlockingEth properties @@ -462,11 +472,6 @@ def blockNumber(self) -> BlockNumber: ) return self.block_number - _chain_id: Method[Callable[[], int]] = Method( - RPC.eth_chainId, - mungers=None, - ) - @property def chain_id(self) -> int: return self._chain_id() diff --git a/web3/manager.py b/web3/manager.py index b731ee7a81..1cd4e1c241 100644 --- a/web3/manager.py +++ b/web3/manager.py @@ -128,11 +128,11 @@ def default_middlewares( """ return [ (request_parameter_normalizer, 'request_param_normalizer'), # Delete - (gas_price_strategy_middleware, 'gas_price_strategy'), # Add Async + (gas_price_strategy_middleware, 'gas_price_strategy'), (name_to_address_middleware(web3), 'name_to_address'), # Add Async (attrdict_middleware, 'attrdict'), # Delete (pythonic_middleware, 'pythonic'), # Delete - (validation_middleware, 'validation'), # Add async + (validation_middleware, 'validation'), (abi_middleware, 'abi'), # Delete (buffered_gas_estimate_middleware, 'gas_estimate'), ] @@ -159,8 +159,8 @@ async def _coro_make_request( self.logger.debug("Making request. Method: %s", method) return await request_func(method, params) + @staticmethod def formatted_response( - self, response: RPCResponse, params: Any, error_formatters: Optional[Callable[..., Any]] = None, diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index baad91b1b6..465b1af9ce 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -69,6 +69,7 @@ make_stalecheck_middleware, ) from .validation import ( # noqa: F401 + async_validation_middleware, validation_middleware, ) diff --git a/web3/middleware/formatting.py b/web3/middleware/formatting.py index 3542583a9a..252d3f8d2c 100644 --- a/web3/middleware/formatting.py +++ b/web3/middleware/formatting.py @@ -2,12 +2,13 @@ TYPE_CHECKING, Any, Callable, + Coroutine, + Literal, Optional, ) from eth_utils.toolz import ( assoc, - curry, merge, ) @@ -22,6 +23,12 @@ if TYPE_CHECKING: from web3 import Web3 # noqa: F401 +FORMATTER_DEFAULTS = { + "request_formatters": {}, + "result_formatters": {}, + "error_formatters": {}, +} + def construct_formatting_middleware( request_formatters: Optional[Formatters] = None, @@ -29,7 +36,7 @@ def construct_formatting_middleware( error_formatters: Optional[Formatters] = None ) -> Middleware: def ignore_web3_in_standard_formatters( - w3: "Web3", + w3: "Web3", method: RPCEndpoint, ) -> FormattersDict: return dict( request_formatters=request_formatters or {}, @@ -41,55 +48,88 @@ def ignore_web3_in_standard_formatters( def construct_web3_formatting_middleware( - web3_formatters_builder: Callable[["Web3"], FormattersDict] + web3_formatters_builder: Callable[["Web3", RPCEndpoint], FormattersDict], ) -> Middleware: def formatter_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" + make_request: Callable[[RPCEndpoint, Any], Any], + w3: "Web3", ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - formatters = merge( - { - "request_formatters": {}, - "result_formatters": {}, - "error_formatters": {}, - }, - web3_formatters_builder(w3), - ) - return apply_formatters(make_request=make_request, **formatters) + def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + formatters = merge( + FORMATTER_DEFAULTS, + web3_formatters_builder(w3, method), + ) + response = _make_request_with_formatters( + method=method, + params=params, + request_formatters=formatters.pop('request_formatters'), + ) + return _apply_response_formatters(method=method, response=response, **formatters) + def _make_request_with_formatters( + method: RPCEndpoint, params: Any, request_formatters: Formatters + ) -> RPCResponse: + if method in request_formatters: + formatter = request_formatters[method] + formatted_params = formatter(params) + return make_request(method, formatted_params) + return make_request(method, params) + + return middleware return formatter_middleware -@curry -def apply_formatters( +async def async_construct_web3_formatting_middleware( + async_web3_formatters_builder: + Callable[["Web3", RPCEndpoint], Coroutine[Any, Any, FormattersDict]] +) -> Callable[[Callable[[RPCEndpoint, Any], Any], "Web3"], + Coroutine[Any, Any, Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]]]]: + async def formatter_middleware( + make_request: Callable[[RPCEndpoint, Any], Any], + async_w3: "Web3", + ) -> Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]]: + async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + formatters = ( + FORMATTER_DEFAULTS, + await async_web3_formatters_builder(async_w3, method), + ) + response = await _make_async_request_with_formatters( + method=method, + params=params, + request_formatters=formatters.pop('request_formatters'), + ) + return _apply_response_formatters(method=method, response=response, **formatters) + + async def _make_async_request_with_formatters( + method: RPCEndpoint, params: Any, request_formatters: Formatters + ) -> RPCResponse: + if method in request_formatters: + formatter = request_formatters[method] + formatted_params = formatter(params) + return await make_request(method, formatted_params) + return await make_request(method, params) + + return middleware + return formatter_middleware + + +def _apply_response_formatters( method: RPCEndpoint, - params: Any, - make_request: Callable[[RPCEndpoint, Any], RPCResponse], - request_formatters: Formatters, + response: RPCResponse, result_formatters: Formatters, error_formatters: Formatters, ) -> RPCResponse: - if method in request_formatters: - formatter = request_formatters[method] - formatted_params = formatter(params) - response = make_request(method, formatted_params) - else: - response = make_request(method, params) - - if "result" in response and method in result_formatters: - formatter = result_formatters[method] - formatted_response = assoc( - response, - "result", - formatter(response["result"]), + def _format_response( + response_type: Literal["result", "error"], + method_response_formatter: Callable[..., Any] + ) -> RPCResponse: + appropriate_response = response[response_type] + return assoc( + response, response_type, method_response_formatter(appropriate_response) ) - return formatted_response + if "result" in response and method in result_formatters: + return _format_response("result", result_formatters[method]) elif "error" in response and method in error_formatters: - formatter = error_formatters[method] - formatted_response = assoc( - response, - "error", - formatter(response["error"]), - ) - return formatted_response + return _format_response("error", error_formatters[method]) else: return response diff --git a/web3/middleware/validation.py b/web3/middleware/validation.py index 2e956655f4..950c4cb0d2 100644 --- a/web3/middleware/validation.py +++ b/web3/middleware/validation.py @@ -2,6 +2,7 @@ TYPE_CHECKING, Any, Callable, + Dict, ) from eth_utils.curried import ( @@ -20,6 +21,7 @@ from hexbytes import ( HexBytes, ) +from toolz import assoc from web3._utils.formatters import ( hex_to_integer, @@ -32,10 +34,12 @@ ValidationError, ) from web3.middleware.formatting import ( + async_construct_web3_formatting_middleware, construct_web3_formatting_middleware, ) from web3.types import ( FormattersDict, + RPCEndpoint, TxParams, ) @@ -45,20 +49,19 @@ MAX_EXTRADATA_LENGTH = 32 is_not_null = complement(is_null) - to_integer_if_hex = apply_formatter_if(is_string, hex_to_integer) @curry -def validate_chain_id(web3: "Web3", chain_id: int) -> int: - if to_integer_if_hex(chain_id) == web3.eth.chain_id: +def validate_chain_id(web3_chain_id: int, chain_id: int) -> int: + if to_integer_if_hex(chain_id) == web3_chain_id: return chain_id else: raise ValidationError( "The transaction declared chain ID %r, " "but the connected node is on %r" % ( chain_id, - web3.eth.chain_id, + web3_chain_id, ) ) @@ -84,12 +87,12 @@ def transaction_normalizer(transaction: TxParams) -> TxParams: return dissoc(transaction, 'chainId') -def transaction_param_validator(web3: "Web3") -> Callable[..., Any]: +def transaction_param_validator(web3_chain_id: int) -> Callable[..., Any]: transactions_params_validators = { "chainId": apply_formatter_if( # Bypass `validate_chain_id` if chainId can't be determined - lambda _: is_not_null(web3.eth.chain_id), - validate_chain_id(web3), + lambda _: is_not_null(web3_chain_id), + validate_chain_id(web3_chain_id), ), } return apply_formatter_at_index( @@ -101,8 +104,6 @@ def transaction_param_validator(web3: "Web3") -> Callable[..., Any]: BLOCK_VALIDATORS = { 'extraData': check_extradata_length, } - - block_validator = apply_formatter_if( is_not_null, apply_formatters_to_dict(BLOCK_VALIDATORS) @@ -110,25 +111,51 @@ def transaction_param_validator(web3: "Web3") -> Callable[..., Any]: @curry -def chain_id_validator(web3: "Web3") -> Callable[..., Any]: +def chain_id_validator(web3_chain_id: int) -> Callable[..., Any]: return compose( apply_formatter_at_index(transaction_normalizer, 0), - transaction_param_validator(web3) + transaction_param_validator(web3_chain_id) ) -def build_validators_with_web3(w3: "Web3") -> FormattersDict: +def build_formatters_dict(request_formatters: Dict[RPCEndpoint, Any]) -> FormattersDict: return dict( - request_formatters={ - RPC.eth_sendTransaction: chain_id_validator(w3), - RPC.eth_estimateGas: chain_id_validator(w3), - RPC.eth_call: chain_id_validator(w3), - }, + request_formatters=request_formatters, result_formatters={ RPC.eth_getBlockByHash: block_validator, RPC.eth_getBlockByNumber: block_validator, - }, + } ) -validation_middleware = construct_web3_formatting_middleware(build_validators_with_web3) +METHODS_TO_VALIDATE = [ + RPC.eth_sendTransaction, + RPC.eth_estimateGas, + RPC.eth_call +] + + +def build_method_validators(w3: "Web3", method: RPCEndpoint) -> FormattersDict: + request_formatters = {} + if RPCEndpoint(method) in METHODS_TO_VALIDATE: + w3_chain_id = w3.eth.chain_id + for method in METHODS_TO_VALIDATE: + request_formatters[method] = chain_id_validator(w3_chain_id) + + return build_formatters_dict(request_formatters) + + +async def async_build_method_validators(async_w3: "Web3", method: RPCEndpoint) -> FormattersDict: + request_formatters = {} + if RPCEndpoint(method) in METHODS_TO_VALIDATE: + w3_chain_id = await async_w3.eth.chain_id # type: ignore + for method in METHODS_TO_VALIDATE: + request_formatters = assoc(request_formatters, method, chain_id_validator(w3_chain_id)) + + return build_formatters_dict(request_formatters) + + +validation_middleware = construct_web3_formatting_middleware(build_method_validators) +async_validation_middleware = async_construct_web3_formatting_middleware( + async_build_method_validators +) diff --git a/web3/tools/benchmark/main.py b/web3/tools/benchmark/main.py index 1f797a270f..3d63224d54 100644 --- a/web3/tools/benchmark/main.py +++ b/web3/tools/benchmark/main.py @@ -29,8 +29,10 @@ from web3.middleware import ( async_buffered_gas_estimate_middleware, async_gas_price_strategy_middleware, + async_validation_middleware, buffered_gas_estimate_middleware, gas_price_strategy_middleware, + validation_middleware, ) from web3.tools.benchmark.node import ( GethBenchmarkFixture, @@ -65,7 +67,11 @@ def build_web3_http(endpoint_uri: str) -> Web3: wait_for_http(endpoint_uri) _web3 = Web3( HTTPProvider(endpoint_uri), - middlewares=[gas_price_strategy_middleware, buffered_gas_estimate_middleware] + middlewares=[ + buffered_gas_estimate_middleware, + gas_price_strategy_middleware, + await validation_middleware, + ] ) return _web3 @@ -74,7 +80,11 @@ async def build_async_w3_http(endpoint_uri: str) -> Web3: await wait_for_aiohttp(endpoint_uri) _web3 = Web3( AsyncHTTPProvider(endpoint_uri), # type: ignore - middlewares=[async_gas_price_strategy_middleware, async_buffered_gas_estimate_middleware], + middlewares=[ + async_buffered_gas_estimate_middleware, + async_gas_price_strategy_middleware, + await async_validation_middleware(), + ], modules={"eth": (AsyncEth,)}, ) return _web3