Skip to content

Commit

Permalink
Changes since PR ethereum#3048 was put up for review
Browse files Browse the repository at this point in the history
- Add web3.providers.websocket.rst to .gitignore
- Put back formatters for eth_getCode / remove unnecessary ``compose()``
- Add read-friendly comment splitting Web3 from AsyncWeb3 in main.py
- Use correct class name in docstring
- Friendlier message when exception is raised connecting to websocket endpoint
- Friendlier message for websocket restricted_kwargs; use a merge of default + provided websocket_kwargs with the provided values taking precedence
- Validate "ws://" / "wss://" in websocket endpoint
- Use Dict[str, Any] for websocket_kwargs type
  • Loading branch information
fselmo committed Jul 20, 2023
1 parent 8b46197 commit 62366e9
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 39 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ docs/web3.gas_strategies.rst
docs/web3.middleware.rst
docs/web3.providers.eth_tester.rst
docs/web3.providers.rst
docs/web3.providers.websocket.rst
docs/web3.rst
docs/web3.scripts.release.rst
docs/web3.scripts.rst
Expand Down
29 changes: 0 additions & 29 deletions docs/web3.providers.websocket.rst

This file was deleted.

2 changes: 1 addition & 1 deletion web3/_utils/method_formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def apply_list_to_array_formatter(formatter: Any) -> Callable[..., Any]:
to_hex_if_integer,
0,
),
RPC.eth_getCode: compose(apply_formatter_at_index(to_hex_if_integer, 1)),
RPC.eth_getCode: apply_formatter_at_index(to_hex_if_integer, 1),
RPC.eth_getStorageAt: apply_formatter_at_index(to_hex_if_integer, 2),
RPC.eth_getTransactionByBlockNumberAndIndex: compose(
apply_formatter_at_index(to_hex_if_integer, 0),
Expand Down
5 changes: 4 additions & 1 deletion web3/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,9 @@ def ens(self, new_ens: Union[ENS, "Empty"]) -> None:
self._ens = new_ens


# -- async -- #


class AsyncWeb3(BaseWeb3):
# mypy Types
eth: AsyncEth
Expand Down Expand Up @@ -505,7 +508,7 @@ def persistent_websocket(
) -> "_PersistentConnectionWeb3":
"""
Establish a persistent connection via websockets to a websocket provider using
a WebsocketProviderV2 instance.
a ``PersistentConnectionProvider`` instance.
"""
return _PersistentConnectionWeb3(
provider,
Expand Down
34 changes: 26 additions & 8 deletions web3/providers/websocket/websocket_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
import os
from typing import (
Any,
Dict,
Optional,
Union,
)

from eth_typing import (
URI,
)
from toolz import (
merge,
)
from websockets.client import (
connect,
)
Expand All @@ -31,6 +35,7 @@
DEFAULT_PING_INTERVAL = 30 # 30 seconds
DEFAULT_PING_TIMEOUT = 300 # 5 minutes

VALID_WEBSOCKET_URI_PREFIXES = {"ws://", "wss://"}
RESTRICTED_WEBSOCKET_KWARGS = {"uri", "loop"}
DEFAULT_WEBSOCKET_KWARGS = {
# set how long to wait between pings from the server
Expand All @@ -51,26 +56,34 @@ class WebsocketProviderV2(PersistentConnectionProvider):
def __init__(
self,
endpoint_uri: Optional[Union[URI, str]] = None,
websocket_kwargs: Optional[Any] = None,
websocket_kwargs: Optional[Dict[str, Any]] = None,
call_timeout: Optional[int] = None,
) -> None:
self.endpoint_uri = URI(endpoint_uri)
if self.endpoint_uri is None:
self.endpoint_uri = get_default_endpoint()

if websocket_kwargs is None:
websocket_kwargs = DEFAULT_WEBSOCKET_KWARGS
else:
if not any(
self.endpoint_uri.startswith(prefix)
for prefix in VALID_WEBSOCKET_URI_PREFIXES
):
raise Web3ValidationError(
f"Websocket endpoint uri must begin with 'ws://' or 'wss://': "
f"{self.endpoint_uri}"
)

if websocket_kwargs is not None:
found_restricted_keys = set(websocket_kwargs).intersection(
RESTRICTED_WEBSOCKET_KWARGS
)
if found_restricted_keys:
raise Web3ValidationError(
f"{RESTRICTED_WEBSOCKET_KWARGS} are not allowed "
f"in websocket_kwargs, found: {found_restricted_keys}"
f"Found restricted keys for websocket_kwargs: "
f"{found_restricted_keys}."
)

self.websocket_kwargs = websocket_kwargs
self.websocket_kwargs = merge(DEFAULT_WEBSOCKET_KWARGS, websocket_kwargs or {})

super().__init__(endpoint_uri, call_timeout=call_timeout)

def __str__(self) -> str:
Expand All @@ -93,7 +106,12 @@ async def is_connected(self, show_traceback: bool = False) -> bool:
return False

async def connect(self) -> None:
self.ws = await connect(self.endpoint_uri, **self.websocket_kwargs)
try:
self.ws = await connect(self.endpoint_uri, **self.websocket_kwargs)
except Exception as e:
raise ProviderConnectionError(
f"Could not connect to endpoint: {self.endpoint_uri}"
) from e

async def disconnect(self) -> None:
await self.ws.close()
Expand Down

0 comments on commit 62366e9

Please sign in to comment.