Skip to content

Commit

Permalink
Mirror sync logic for async name to address
Browse files Browse the repository at this point in the history
- There wasn't too much code copying in the end with this approach so this seems better. This provides proper recursion for nested address list types (address[][][], address[2][1], etc) and gives async and sync the same treatment as they basically follow the same approach.
  • Loading branch information
fselmo committed Jun 30, 2023
1 parent 3ab1509 commit c4bcdb7
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 19 deletions.
79 changes: 79 additions & 0 deletions web3/_utils/abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import itertools
import re
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Coroutine,
Dict,
Iterable,
List,
Expand Down Expand Up @@ -53,6 +55,7 @@
decode_hex,
is_bytes,
is_list_like,
is_string,
is_text,
to_text,
to_tuple,
Expand All @@ -66,6 +69,9 @@
pipe,
)

from web3._utils.decorators import (
reject_recursive_repeats,
)
from web3._utils.ens import (
is_ens_name,
)
Expand All @@ -82,11 +88,17 @@
ABIEventParams,
ABIFunction,
ABIFunctionParams,
TReturn,
)
from web3.utils import ( # public utils module
get_abi_input_names,
)

if TYPE_CHECKING:
from web3 import ( # noqa: F401
AsyncWeb3,
)


def filter_by_type(_type: str, contract_abi: ABI) -> List[Union[ABIFunction, ABIEvent]]:
return [abi for abi in contract_abi if abi["type"] == _type]
Expand Down Expand Up @@ -971,3 +983,70 @@ def __new__(self, args: Any) -> "ABIDecodedNamedTuple":
return super().__new__(self, *args)

return ABIDecodedNamedTuple


# -- async -- #


async def async_data_tree_map(
async_w3: "AsyncWeb3",
func: Callable[
["AsyncWeb3", TypeStr, Any], Coroutine[Any, Any, Tuple[TypeStr, Any]]
],
data_tree: Any,
) -> "ABITypedData":
"""
Map an awaitable method to every ABITypedData element in the tree.
The awaitable method should receive three positional args:
async_w3, abi_type, and data
"""

async def async_map_to_typed_data(elements: Any) -> "ABITypedData":
if isinstance(elements, ABITypedData) and elements.abi_type is not None:
formatted = await func(async_w3, *elements)
return ABITypedData(formatted)
else:
return elements

return await async_recursive_map(async_w3, async_map_to_typed_data, data_tree)


@reject_recursive_repeats
async def async_recursive_map(
async_w3: "AsyncWeb3",
func: Callable[[Any], Coroutine[Any, Any, TReturn]],
data: Any,
) -> TReturn:
"""
Apply an awaitable method to data and any collection items inside data
(using async_map_collection).
Define the awaitable method so that it only applies to the type of value that you
want it to apply to.
"""

async def async_recurse(item: Any) -> TReturn:
return await async_recursive_map(async_w3, func, item)

items_mapped = await async_map_if_collection(async_recurse, data)
return await func(items_mapped)


async def async_map_if_collection(
func: Callable[[Any], Coroutine[Any, Any, Any]], value: Any
) -> Any:
"""
Apply an awaitable method to each element of a collection or value of a dictionary.
If the value is not a collection, return it unmodified.
"""

datatype = type(value)
if isinstance(value, Mapping):
return datatype({key: await func(val) for key, val in value.values()})
if is_string(value):
return value
elif isinstance(value, Iterable):
return datatype([await func(item) for item in value])
else:
return value
32 changes: 13 additions & 19 deletions web3/middleware/names.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@

from .._utils.abi import (
abi_data_tree,
async_data_tree_map,
strip_abi_type,
)
from .._utils.formatters import (
recursive_map,
)
from .formatting import (
construct_formatting_middleware,
Expand Down Expand Up @@ -52,26 +57,14 @@ async def async_format_all_ens_names_to_address(
abi_types_for_method: Sequence[Any],
data: Sequence[Any],
) -> Sequence[Any]:
# provide a stepwise version of what the curried formatters do
abi_typed_params = abi_data_tree(abi_types_for_method, data)

formatted_params = []
for param in abi_typed_params:
if param.abi_type == "address[]":
# handle name conversion in an address list
# Note: only supports single list atm, as is true the sync middleware
# TODO: handle address[][], etc...
formatted_data = await async_format_all_ens_names_to_address(
async_web3,
[param.abi_type[:-2]] * len(param.data),
[subparam.data for subparam in param.data],
)
else:
_abi_type, formatted_data = await async_abi_ens_resolver(
async_web3,
param.abi_type,
param.data,
)
formatted_params.append(formatted_data)
formatted_data_tree = await async_data_tree_map(
async_web3,
async_abi_ens_resolver,
abi_typed_params,
)
formatted_params = recursive_map(strip_abi_type, formatted_data_tree)
return formatted_params


Expand All @@ -96,6 +89,7 @@ async def async_apply_ens_to_address_conversion(
)
formatted_dict = dict(zip(fields, formatted_params))
return (formatted_dict,)

else:
raise TypeError(
f"ABI definitions must be a list or dictionary, "
Expand Down

0 comments on commit c4bcdb7

Please sign in to comment.