diff --git a/tests/integration/go_ethereum/test_goethereum_http.py b/tests/integration/go_ethereum/test_goethereum_http.py index 805df0a9d6..793daffd6c 100644 --- a/tests/integration/go_ethereum/test_goethereum_http.py +++ b/tests/integration/go_ethereum/test_goethereum_http.py @@ -88,7 +88,7 @@ async def async_w3(geth_process, endpoint_uri): async_gas_price_strategy_middleware, async_buffered_gas_estimate_middleware ], - modules={'eth': (AsyncEth,), 'net': (AsyncNet,)}) + modules={'eth': (AsyncEth,), 'async_net': (AsyncNet,)}) return _web3 diff --git a/web3/_utils/module_testing/net_module.py b/web3/_utils/module_testing/net_module.py index 515f445a23..57deed74cc 100644 --- a/web3/_utils/module_testing/net_module.py +++ b/web3/_utils/module_testing/net_module.py @@ -44,19 +44,19 @@ def test_net_chainId_deprecation(self, web3: "Web3") -> None: class AsyncNetModuleTest: @pytest.mark.asyncio async def test_net_version(self, async_w3: "Web3") -> None: - version = await async_w3.net.version # type: ignore + version = await async_w3.async_net.version assert is_string(version) assert version.isdigit() @pytest.mark.asyncio async def test_net_listening(self, async_w3: "Web3") -> None: - listening = await async_w3.net.listening + listening = await async_w3.async_net.listening assert is_boolean(listening) @pytest.mark.asyncio async def test_net_peer_count(self, async_w3: "Web3") -> None: - peer_count = await async_w3.net.peer_count # type: ignore + peer_count = await async_w3.async_net.peer_count assert is_integer(peer_count) diff --git a/web3/main.py b/web3/main.py index 6d90859e86..bc2a182c23 100644 --- a/web3/main.py +++ b/web3/main.py @@ -226,7 +226,7 @@ def toChecksumAddress(value: Union[AnyAddress, str, bytes]) -> ChecksumAddress: parity: Parity geth: Geth net: Net - # async_net: AsyncNet + async_net: AsyncNet def __init__( self, diff --git a/web3/net.py b/web3/net.py index 61b5d0dace..5593ecd1bc 100644 --- a/web3/net.py +++ b/web3/net.py @@ -1,4 +1,5 @@ from typing import ( + Awaitable, Callable, NoReturn, ) @@ -16,7 +17,7 @@ ) -class BaseNet(Module): +class Net(Module): _listening: Method[Callable[[], bool]] = Method( RPC.net_listening, mungers=[default_root_munger], @@ -33,52 +34,52 @@ class BaseNet(Module): ) @property - def peer_count(self) -> int: - return self._peer_count() - - @property - def version(self) -> str: - return self._version() + def chainId(self) -> NoReturn: + raise DeprecationWarning("This method has been deprecated in EIP 1474.") @property def listening(self) -> bool: return self._listening() - -class Net(BaseNet): @property - def chainId(self) -> NoReturn: - raise DeprecationWarning("This method has been deprecated in EIP 1474.") - - # @property - # def listening(self) -> bool: - # return self._listening() - - # @property - # def peer_count(self) -> int: - # return self._peer_count() - - # @property - # def version(self) -> str: - # return self._version() + def peer_count(self) -> int: + return self._peer_count() + @property + def version(self) -> str: + return self._version() # # Deprecated Methods # - peerCount = DeprecatedMethod(BaseNet.peer_count, 'peerCount', 'peer_count') # type: ignore + peerCount = DeprecatedMethod(peer_count, 'peerCount', 'peer_count') # type: ignore -class AsyncNet(BaseNet): +class AsyncNet(Module): is_async = True - # @property - # def listening(self) -> bool: - # return self._listening() + _listening: Method[Callable[[], Awaitable[bool]]] = Method( + RPC.net_listening, + mungers=[default_root_munger], + ) - # @property - # def peer_count(self) -> int: - # return self._peer_count() + _peer_count: Method[Callable[[], Awaitable[int]]] = Method( + RPC.net_peerCount, + mungers=[default_root_munger], + ) - # @property - # def version(self) -> str: - # return self._version() + _version: Method[Callable[[], Awaitable[str]]] = Method( + RPC.net_version, + mungers=[default_root_munger], + ) + + @property + async def listening(self) -> bool: + return await self._listening() + + @property + async def peer_count(self) -> int: + return await self._peer_count() + + @property + async def version(self) -> str: + return await self._version()