From c4bcdb7e5d460665347193e6cb3e5837f01ee7f3 Mon Sep 17 00:00:00 2001 From: fselmo Date: Fri, 30 Jun 2023 17:09:08 -0600 Subject: [PATCH] Mirror sync logic for async name to address - There wasn't too much code copying in the end with this approach so this seems better. This provides proper recursion for nested address list types (address[][][], address[2][1], etc) and gives async and sync the same treatment as they basically follow the same approach. --- web3/_utils/abi.py | 79 ++++++++++++++++++++++++++++++++++++++++ web3/middleware/names.py | 32 +++++++--------- 2 files changed, 92 insertions(+), 19 deletions(-) diff --git a/web3/_utils/abi.py b/web3/_utils/abi.py index cfaf9cc5d7..4b0893d9cc 100644 --- a/web3/_utils/abi.py +++ b/web3/_utils/abi.py @@ -7,9 +7,11 @@ import itertools import re from typing import ( + TYPE_CHECKING, Any, Callable, Collection, + Coroutine, Dict, Iterable, List, @@ -53,6 +55,7 @@ decode_hex, is_bytes, is_list_like, + is_string, is_text, to_text, to_tuple, @@ -66,6 +69,9 @@ pipe, ) +from web3._utils.decorators import ( + reject_recursive_repeats, +) from web3._utils.ens import ( is_ens_name, ) @@ -82,11 +88,17 @@ ABIEventParams, ABIFunction, ABIFunctionParams, + TReturn, ) from web3.utils import ( # public utils module get_abi_input_names, ) +if TYPE_CHECKING: + from web3 import ( # noqa: F401 + AsyncWeb3, + ) + def filter_by_type(_type: str, contract_abi: ABI) -> List[Union[ABIFunction, ABIEvent]]: return [abi for abi in contract_abi if abi["type"] == _type] @@ -971,3 +983,70 @@ def __new__(self, args: Any) -> "ABIDecodedNamedTuple": return super().__new__(self, *args) return ABIDecodedNamedTuple + + +# -- async -- # + + +async def async_data_tree_map( + async_w3: "AsyncWeb3", + func: Callable[ + ["AsyncWeb3", TypeStr, Any], Coroutine[Any, Any, Tuple[TypeStr, Any]] + ], + data_tree: Any, +) -> "ABITypedData": + """ + Map an awaitable method to every ABITypedData element in the tree. + + The awaitable method should receive three positional args: + async_w3, abi_type, and data + """ + + async def async_map_to_typed_data(elements: Any) -> "ABITypedData": + if isinstance(elements, ABITypedData) and elements.abi_type is not None: + formatted = await func(async_w3, *elements) + return ABITypedData(formatted) + else: + return elements + + return await async_recursive_map(async_w3, async_map_to_typed_data, data_tree) + + +@reject_recursive_repeats +async def async_recursive_map( + async_w3: "AsyncWeb3", + func: Callable[[Any], Coroutine[Any, Any, TReturn]], + data: Any, +) -> TReturn: + """ + Apply an awaitable method to data and any collection items inside data + (using async_map_collection). + + Define the awaitable method so that it only applies to the type of value that you + want it to apply to. + """ + + async def async_recurse(item: Any) -> TReturn: + return await async_recursive_map(async_w3, func, item) + + items_mapped = await async_map_if_collection(async_recurse, data) + return await func(items_mapped) + + +async def async_map_if_collection( + func: Callable[[Any], Coroutine[Any, Any, Any]], value: Any +) -> Any: + """ + Apply an awaitable method to each element of a collection or value of a dictionary. + If the value is not a collection, return it unmodified. + """ + + datatype = type(value) + if isinstance(value, Mapping): + return datatype({key: await func(val) for key, val in value.values()}) + if is_string(value): + return value + elif isinstance(value, Iterable): + return datatype([await func(item) for item in value]) + else: + return value diff --git a/web3/middleware/names.py b/web3/middleware/names.py index f36cf3711f..b022f4b68a 100644 --- a/web3/middleware/names.py +++ b/web3/middleware/names.py @@ -23,6 +23,11 @@ from .._utils.abi import ( abi_data_tree, + async_data_tree_map, + strip_abi_type, +) +from .._utils.formatters import ( + recursive_map, ) from .formatting import ( construct_formatting_middleware, @@ -52,26 +57,14 @@ async def async_format_all_ens_names_to_address( abi_types_for_method: Sequence[Any], data: Sequence[Any], ) -> Sequence[Any]: + # provide a stepwise version of what the curried formatters do abi_typed_params = abi_data_tree(abi_types_for_method, data) - - formatted_params = [] - for param in abi_typed_params: - if param.abi_type == "address[]": - # handle name conversion in an address list - # Note: only supports single list atm, as is true the sync middleware - # TODO: handle address[][], etc... - formatted_data = await async_format_all_ens_names_to_address( - async_web3, - [param.abi_type[:-2]] * len(param.data), - [subparam.data for subparam in param.data], - ) - else: - _abi_type, formatted_data = await async_abi_ens_resolver( - async_web3, - param.abi_type, - param.data, - ) - formatted_params.append(formatted_data) + formatted_data_tree = await async_data_tree_map( + async_web3, + async_abi_ens_resolver, + abi_typed_params, + ) + formatted_params = recursive_map(strip_abi_type, formatted_data_tree) return formatted_params @@ -96,6 +89,7 @@ async def async_apply_ens_to_address_conversion( ) formatted_dict = dict(zip(fields, formatted_params)) return (formatted_dict,) + else: raise TypeError( f"ABI definitions must be a list or dictionary, "