Skip to content

Commit

Permalink
Changes since PR #3048 was put up for review
Browse files Browse the repository at this point in the history
- Add web3.providers.websocket.rst to .gitignore
- Put back formatters for eth_getCode / remove unnecessary ``compose()``
- Add read-friendly comment splitting Web3 from AsyncWeb3 in main.py
- Use correct class name in docstring
- Friendlier message when exception is raised connecting to websocket endpoint
- Friendlier message for websocket restricted_kwargs; use a merge of default + provided websocket_kwargs with the provided values taking precedence
- Validate "ws://" / "wss://" in websocket endpoint
  • Loading branch information
fselmo committed Jul 20, 2023
1 parent 8b46197 commit 054bcf1
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 36 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ docs/web3.gas_strategies.rst
docs/web3.middleware.rst
docs/web3.providers.eth_tester.rst
docs/web3.providers.rst
docs/web3.providers.websocket.rst
docs/web3.rst
docs/web3.scripts.release.rst
docs/web3.scripts.rst
Expand Down
29 changes: 0 additions & 29 deletions docs/web3.providers.websocket.rst

This file was deleted.

2 changes: 1 addition & 1 deletion web3/_utils/method_formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def apply_list_to_array_formatter(formatter: Any) -> Callable[..., Any]:
to_hex_if_integer,
0,
),
RPC.eth_getCode: compose(apply_formatter_at_index(to_hex_if_integer, 1)),
RPC.eth_getCode: apply_formatter_at_index(to_hex_if_integer, 1),
RPC.eth_getStorageAt: apply_formatter_at_index(to_hex_if_integer, 2),
RPC.eth_getTransactionByBlockNumberAndIndex: compose(
apply_formatter_at_index(to_hex_if_integer, 0),
Expand Down
5 changes: 4 additions & 1 deletion web3/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,9 @@ def ens(self, new_ens: Union[ENS, "Empty"]) -> None:
self._ens = new_ens


# -- async -- #


class AsyncWeb3(BaseWeb3):
# mypy Types
eth: AsyncEth
Expand Down Expand Up @@ -505,7 +508,7 @@ def persistent_websocket(
) -> "_PersistentConnectionWeb3":
"""
Establish a persistent connection via websockets to a websocket provider using
a WebsocketProviderV2 instance.
a ``PersistentConnectionProvider`` instance.
"""
return _PersistentConnectionWeb3(
provider,
Expand Down
28 changes: 23 additions & 5 deletions web3/providers/websocket/websocket_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from eth_typing import (
URI,
)
from toolz import (
merge,
)
from websockets.client import (
connect,
)
Expand All @@ -31,6 +34,7 @@
DEFAULT_PING_INTERVAL = 30 # 30 seconds
DEFAULT_PING_TIMEOUT = 300 # 5 minutes

VALID_WEBSOCKET_URI_PREFIXES = {"ws://", "wss://"}
RESTRICTED_WEBSOCKET_KWARGS = {"uri", "loop"}
DEFAULT_WEBSOCKET_KWARGS = {
# set how long to wait between pings from the server
Expand Down Expand Up @@ -58,19 +62,28 @@ def __init__(
if self.endpoint_uri is None:
self.endpoint_uri = get_default_endpoint()

if not any(
self.endpoint_uri.startswith(prefix)
for prefix in VALID_WEBSOCKET_URI_PREFIXES
):
raise Web3ValidationError(
f"Websocket endpoint uri must begin with 'ws://' or 'wss://': "
f"{self.endpoint_uri}"
)

if websocket_kwargs is None:
websocket_kwargs = DEFAULT_WEBSOCKET_KWARGS
self.websocket_kwargs = DEFAULT_WEBSOCKET_KWARGS
else:
found_restricted_keys = set(websocket_kwargs).intersection(
RESTRICTED_WEBSOCKET_KWARGS
)
if found_restricted_keys:
raise Web3ValidationError(
f"{RESTRICTED_WEBSOCKET_KWARGS} are not allowed "
f"in websocket_kwargs, found: {found_restricted_keys}"
f"Found restricted keys for websocket_kwargs: "
f"{found_restricted_keys}."
)
self.websocket_kwargs = merge(DEFAULT_WEBSOCKET_KWARGS, websocket_kwargs)

self.websocket_kwargs = websocket_kwargs
super().__init__(endpoint_uri, call_timeout=call_timeout)

def __str__(self) -> str:
Expand All @@ -93,7 +106,12 @@ async def is_connected(self, show_traceback: bool = False) -> bool:
return False

async def connect(self) -> None:
self.ws = await connect(self.endpoint_uri, **self.websocket_kwargs)
try:
self.ws = await connect(self.endpoint_uri, **self.websocket_kwargs)
except Exception as e:
raise ProviderConnectionError(
f"Could not connect to endpoint: {self.endpoint_uri}"
) from e

async def disconnect(self) -> None:
await self.ws.close()
Expand Down

0 comments on commit 054bcf1

Please sign in to comment.