diff --git a/docs/providers.rst b/docs/providers.rst
index 04e161967d..020bc0d663 100644
--- a/docs/providers.rst
+++ b/docs/providers.rst
@@ -272,7 +272,7 @@ asynchronous context manager, can be found in the `websockets connection`_ docs.
 ...
 ...         unsubscribed = False
 ...         while not unsubscribed:
-...             async for response in w3.listen_to_websocket():
+...             async for response in w3.ws.listen_to_websocket():
 ...                 print(f"{response}\n")
 ...                 # handle responses here
 ...
@@ -320,6 +320,52 @@ and reconnect automatically if the connection is lost. A similar example, using

 >>> asyncio.run(ws_v2_subscription_iterator_example())

+_PersistentConnectionWeb3 via AsyncWeb3.persistent_websocket()
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+When an ``AsyncWeb3`` instance is connected to a persistent websocket connection, via
+the ``persistent_websocket()`` method, it becomes an instance of the
+``_PersistentConnectionWeb3`` class. This class has a few additional methods and
+attributes that are not available on the ``AsyncWeb3`` class.
+
+.. py:attribute:: _PersistentConnectionWeb3.ws
+
+    Listening to websocket responses, and sending raw requests, can be done using the
+    ``ws`` attribute of the ``_PersistentConnectionWeb3`` class. The ``ws`` attribute
+    houses a public API, a :class:`~web3.providers.websocket.WebsocketConnection` class,
+    for sending and receiving websocket messages.
+
+    .. py:class:: web3.providers.websocket.WebsocketConnection()
+
+        This class handles interactions with a websocket connection. It is available
+        via the ``ws`` attribute of the ``_PersistentConnectionWeb3`` class. The
+        ``WebsocketConnection`` class has the following methods:
+
+        .. py:method:: listen_to_websocket()
+
+            This method is available for listening to websocket responses indefinitely.
+            It is an asynchronous generator that yields responses from the websocket
+            connection. The responses from this method are formatted by web3.py
+            formatters and run through the middlewares before being yielded.
+            An example of its use can be seen above in the `Usage`_ section.
+
+        .. py:method:: recv()
+
+            The ``recv()`` method can be used to receive the next message from the
+            websocket. The response from this method is formatted by web3.py formatters
+            and run through the middlewares before being returned. This is useful for
+            receiving single responses to one-to-many requests, such as receiving the
+            next ``eth_subscribe`` subscription response.
+
+        .. py:method:: send(method: RPCEndpoint, params: Sequence[Any])
+
+            This method is available strictly for sending raw requests to the websocket,
+            if desired. It is not recommended to use this method directly, as the
+            responses will not be formatted by web3.py formatters or run through the
+            middlewares. Instead, use the methods available on the respective web3
+            module. For example, use ``w3.eth.get_block("latest")`` instead of
+            ``w3.ws.send("eth_getBlockByNumber", ["latest", True])``.
+
 AutoProvider
 ~~~~~~~~~~~~

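The documentation changes above introduce the public ``w3.ws`` API. For orientation, here is a minimal usage sketch rather than an excerpt from the docs; it assumes a locally running websocket endpoint at ``ws://127.0.0.1:8546`` and the ``w3.eth.subscribe`` / ``w3.eth.unsubscribe`` convenience methods for managing the subscription:

    import asyncio

    from web3 import AsyncWeb3
    from web3.providers.websocket import WebsocketProviderV2


    async def ws_recv_example():
        async with AsyncWeb3.persistent_websocket(
            WebsocketProviderV2("ws://127.0.0.1:8546")
        ) as w3:
            # returns the subscription id for the new "newHeads" subscription
            subscription_id = await w3.eth.subscribe("newHeads")

            # receive the next formatted message from the open connection via the
            # public WebsocketConnection API exposed as w3.ws
            message = await w3.ws.recv()
            print(message["subscription"], message["result"])

            await w3.eth.unsubscribe(subscription_id)


    asyncio.run(ws_recv_example())

``recv()`` pulls a single formatted message, while ``listen_to_websocket()`` (shown in the diff above) yields formatted messages indefinitely.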
diff --git a/newsfragments/3096.breaking.rst b/newsfragments/3096.breaking.rst
new file mode 100644
index 0000000000..4731e4e59e
--- /dev/null
+++ b/newsfragments/3096.breaking.rst
@@ -0,0 +1 @@
+Breaking change to the API for interacting with a persistent websocket connection via ``AsyncWeb3`` and ``WebsocketProviderV2``. This change internalizes the ``provider.ws`` property and opts for a ``w3.ws`` API achieved via a new ``WebsocketConnection`` class. With these changes, ``eth_subscription`` messages now return the subscription id as the ``subscription`` param and the formatted message as the ``result`` param.
diff --git a/newsfragments/3096.docs.rst b/newsfragments/3096.docs.rst
new file mode 100644
index 0000000000..14a50066ad
--- /dev/null
+++ b/newsfragments/3096.docs.rst
@@ -0,0 +1 @@
+Update ``WebsocketProviderV2`` documentation to reflect the new public websocket API via the ``WebsocketConnection`` class.
diff --git a/newsfragments/3096.feature.rst b/newsfragments/3096.feature.rst
new file mode 100644
index 0000000000..f86246798e
--- /dev/null
+++ b/newsfragments/3096.feature.rst
@@ -0,0 +1 @@
+Sync responses for ``WebsocketProviderV2`` open connections with requests via matching RPC ``id`` values.
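To make the newsfragments above concrete: an ``eth_subscription`` message that the request manager previously surfaced as just the formatted result is now returned as a two-key mapping of subscription id and formatted result. A rough sketch of the two shapes, using hypothetical placeholder values:

    from eth_typing import HexStr
    from hexbytes import HexBytes
    from web3.datastructures import AttributeDict

    # previously: only the formatted result of the subscription message was returned
    old_shape = AttributeDict(
        {"number": 18000000, "parentHash": HexBytes(b"\x00" * 32)}
    )

    # now: the subscription id accompanies the formatted result
    new_shape = {
        "subscription": HexStr("0xd67da23f62a01f58042bc73d3f1c8936"),  # hypothetical id
        "result": AttributeDict(
            {"number": 18000000, "parentHash": HexBytes(b"\x00" * 32)}
        ),
    }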
diff --git a/web3/_utils/method_formatters.py b/web3/_utils/method_formatters.py
index 156be338d4..630cc00e1a 100644
--- a/web3/_utils/method_formatters.py
+++ b/web3/_utils/method_formatters.py
@@ -16,6 +16,9 @@
 from eth_typing import (
     HexStr,
 )
+from eth_utils import (
+    is_hexstr,
+)
 from eth_utils.curried import (
     apply_formatter_at_index,
     apply_formatter_if,
@@ -201,6 +204,7 @@ def type_aware_apply_formatters_to_dict_keys_and_values(
     "to": apply_formatter_if(is_address, to_checksum_address),
     "hash": to_hexbytes(32),
     "v": apply_formatter_if(is_not_null, to_integer_if_hex),
+    "yParity": apply_formatter_if(is_not_null, to_integer_if_hex),
     "standardV": apply_formatter_if(is_not_null, to_integer_if_hex),
     "type": apply_formatter_if(is_not_null, to_integer_if_hex),
     "chainId": apply_formatter_if(is_not_null, to_integer_if_hex),
@@ -612,40 +616,61 @@ def apply_list_to_array_formatter(formatter: Any) -> Callable[..., Any]:

 # -- eth_subscribe -- #


 def subscription_formatter(value: Any) -> Union[HexBytes, HexStr, Dict[str, Any]]:
-    if is_string(value):
-        if len(value.replace("0x", "")) == 64:
-            # transaction hash, from `newPendingTransactions` subscription w/o full_txs
-            return HexBytes(value)
-
+    if is_hexstr(value):
         # subscription id from the original subscription request
         return HexStr(value)

-    response_key_set = set(value.keys())
-
-    # handle dict subscription responses
-    if either_set_is_a_subset(response_key_set, set(BLOCK_FORMATTERS.keys())):
-        # block format, newHeads
-        return block_formatter(value)
+    elif isinstance(value, dict):
+        # subscription messages

-    elif either_set_is_a_subset(response_key_set, set(LOG_ENTRY_FORMATTERS.keys())):
-        # logs
-        return log_entry_formatter(value)
+        result = value.get("result")
+        result_formatter = None

-    elif either_set_is_a_subset(
-        response_key_set, set(TRANSACTION_RESULT_FORMATTERS.keys())
-    ):
-        # transaction subscription type (newPendingTransactions), full transactions
-        return transaction_result_formatter(value)
-
-    elif any(_ in response_key_set for _ in {"syncing", "status"}):
-        # geth syncing response
-        return type_aware_apply_formatters_to_dict(GETH_SYNCING_SUBSCRIPTION_FORMATTERS)
-
-    elif either_set_is_a_subset(response_key_set, set(SYNCING_FORMATTERS.keys())):
-        # syncing response object
-        return syncing_formatter
+        if isinstance(result, str) and len(result.replace("0x", "")) == 64:
+            # transaction hash, from `newPendingTransactions` subscription w/o full_txs
+            result_formatter = HexBytes
+
+        elif isinstance(result, (dict, AttributeDict)):
+            result_key_set = set(result.keys())
+
+            # handle dict subscription responses
+            if either_set_is_a_subset(
+                result_key_set,
+                set(BLOCK_FORMATTERS.keys()),
+                percentage=90,
+            ):
+                # block format, newHeads
+                result_formatter = block_formatter
+
+            elif either_set_is_a_subset(
+                result_key_set, set(LOG_ENTRY_FORMATTERS.keys()), percentage=90
+            ):
+                # logs
+                result_formatter = log_entry_formatter
+
+            elif either_set_is_a_subset(
+                result_key_set, set(TRANSACTION_RESULT_FORMATTERS.keys()), percentage=90
+            ):
+                # newPendingTransactions, full transactions
+                result_formatter = transaction_result_formatter
+
+            elif any(_ in result_key_set for _ in {"syncing", "status"}):
+                # geth syncing response
+                result_formatter = type_aware_apply_formatters_to_dict(
+                    GETH_SYNCING_SUBSCRIPTION_FORMATTERS
+                )
+
+            elif either_set_is_a_subset(
+                result_key_set,
+                set(SYNCING_FORMATTERS.keys()),
+                percentage=90,
+            ):
+                # syncing response object
+                result_formatter = syncing_formatter
+
+        if result_formatter is not None:
+            value["result"] = result_formatter(result)

-    # fallback to returning the value as-is
     return value
diff --git a/web3/_utils/utility_methods.py b/web3/_utils/utility_methods.py
index 5b622b13e6..66ad373ff2 100644
--- a/web3/_utils/utility_methods.py
+++ b/web3/_utils/utility_methods.py
@@ -57,15 +57,26 @@ def none_in_dict(
     return not any_in_dict(values, d)


-def either_set_is_a_subset(set1: Set[Any], set2: Set[Any]) -> bool:
+def either_set_is_a_subset(
+    set1: Set[Any],
+    set2: Set[Any],
+    percentage: int = 100,
+) -> bool:
     """
     Returns a bool based on whether two sets might have some differences but are
     mostly the same. This can be useful when comparing formatters to an actual
     response for formatting.

-    :param set1: A set of values
-    :param set2: A second set of values
-    :return: True if the intersection of the two sets is equal to the first set;
-        False if the intersection of the two sets is NOT equal to the first set
+    :param set1: A set of values.
+    :param set2: A second set of values.
+    :param percentage: The percentage of either set that must be present in the
+        other set; defaults to 100.
+    :return: True if at least ``percentage`` percent of either set is present in
+        the other set.
     """
-    return set1.intersection(set2) == set1 or set2.intersection(set1) == set2
+    threshold = percentage / 100
+
+    return (
+        len(set1.intersection(set2)) >= len(set1) * threshold
+        or len(set2.intersection(set1)) >= len(set2) * threshold
+    )
diff --git a/web3/main.py b/web3/main.py
index 4191fc795e..b97aea0c51 100644
--- a/web3/main.py
+++ b/web3/main.py
@@ -60,6 +60,9 @@
     build_strict_registry,
     map_abi_data,
 )
+from web3._utils.compat import (
+    Self,
+)
 from web3._utils.empty import (
     empty,
 )
@@ -127,6 +130,7 @@
 from web3.providers.websocket import (
     WebsocketProvider,
 )
+from web3.providers.websocket.websocket_connection import WebsocketConnection
 from web3.testing import (
     Testing,
 )
@@ -142,9 +146,6 @@
 if TYPE_CHECKING:
     from web3.pm import PM  # noqa: F401
     from web3._utils.empty import Empty  # noqa: F401
-    from web3.manager import (  # noqa: F401
-        _AsyncPersistentRecvStream,
-    )


 def get_async_default_modules() -> Dict[str, Union[Type[Module], Sequence[Any]]]:
@@ -538,9 +539,10 @@ def __init__(
                 "Provider must inherit from PersistentConnectionProvider class."
) AsyncWeb3.__init__(self, provider, middlewares, modules, external_modules, ens) + self.ws = WebsocketConnection(self) # async for w3 in w3.persistent_websocket(provider) - async def __aiter__(self) -> AsyncIterator["_PersistentConnectionWeb3"]: + async def __aiter__(self) -> AsyncIterator[Self]: while True: try: yield self @@ -549,7 +551,7 @@ async def __aiter__(self) -> AsyncIterator["_PersistentConnectionWeb3"]: continue # async with w3.persistent_websocket(provider) as w3 - async def __aenter__(self) -> "_PersistentConnectionWeb3": + async def __aenter__(self) -> Self: await self.provider.connect() return self @@ -560,6 +562,3 @@ async def __aexit__( exc_tb: TracebackType, ) -> None: await self.provider.disconnect() - - def listen_to_websocket(self) -> "_AsyncPersistentRecvStream": - return self.manager.persistent_recv_stream() diff --git a/web3/manager.py b/web3/manager.py index 3e94f30d91..a0dc46ac96 100644 --- a/web3/manager.py +++ b/web3/manager.py @@ -1,5 +1,3 @@ -import asyncio -import json import logging from typing import ( TYPE_CHECKING, @@ -24,6 +22,9 @@ ConnectionClosedOK, ) +from web3._utils.async_caching import ( + async_lock, +) from web3._utils.caching import ( generate_cache_key, ) @@ -64,7 +65,7 @@ ) if TYPE_CHECKING: - from web3 import ( # noqa: F401 + from web3.main import ( # noqa: F401 AsyncWeb3, Web3, ) @@ -146,6 +147,12 @@ def __init__( self.middleware_onion = NamedElementOnion(middlewares) + if isinstance(provider, PersistentConnectionProvider): + # set up the request processor to be able to properly process ordered + # responses from the persistent connection as FIFO + provider = cast(PersistentConnectionProvider, self.provider) + self._request_processor = provider._request_processor + w3: Union["AsyncWeb3", "Web3"] = None _provider = None @@ -284,12 +291,17 @@ def formatted_response( apply_null_result_formatters(null_result_formatters, response, params) return response.get("result") - # Response from eth_subscribe includes response["params"]["result"] + # Response from eth_subscription includes response["params"]["result"] elif ( - response.get("params") is not None + response.get("method") == "eth_subscription" + and response.get("params") is not None + and response["params"].get("subscription") is not None and response["params"].get("result") is not None ): - return response["params"]["result"] + return { + "subscription": response["params"]["subscription"], + "result": response["params"]["result"], + } # Any other response type raises BadResponseFormat else: @@ -326,11 +338,7 @@ async def coro_request( ) # persistent connection - async def ws_send( - self, - method: Union[RPCEndpoint, Callable[..., RPCEndpoint]], - params: Any, - ) -> RPCResponse: + async def ws_send(self, method: RPCEndpoint, params: Any) -> RPCResponse: provider = cast(PersistentConnectionProvider, self._provider) request_func = await provider.request_func( cast("AsyncWeb3", self.w3), @@ -340,11 +348,8 @@ async def ws_send( "Making request to open websocket connection - " f"uri: {provider.endpoint_uri}, method: {method}" ) - await request_func(method, params) - return await asyncio.wait_for( - self.ws_recv(), - timeout=provider.call_timeout, - ) + response = await request_func(method, params) + return await self._process_ws_response(response) async def ws_recv(self) -> Any: return await self._ws_recv_stream().__anext__() @@ -359,36 +364,53 @@ async def _ws_recv_stream(self) -> AsyncGenerator[RPCResponse, None]: "can listen to websocket recv streams." 
) - response = json.loads( - await asyncio.wait_for( - self._provider.ws.recv(), - timeout=self._provider.call_timeout, - ) + cached_responses = len(self._request_processor._raw_response_cache.items()) + if cached_responses > 0: + async with async_lock( + self._provider._thread_pool, + self._provider._lock, + ): + self._provider.logger.debug( + f"{cached_responses} cached response(s) in raw response cache. " + f"Processing as FIFO ahead of any new responses from open " + f"socket connection." + ) + for ( + cache_key, + cached_response, + ) in self._request_processor._raw_response_cache.items(): + self._request_processor.pop_raw_response(cache_key) + yield await self._process_ws_response(cached_response) + else: + response = await self._provider._ws_recv() + yield await self._process_ws_response(response) + + async def _process_ws_response(self, response: RPCResponse) -> RPCResponse: + provider = cast(PersistentConnectionProvider, self._provider) + request_info = self._request_processor.get_request_information_for_response( + response ) - request_info = self._provider._get_request_information_for_response(response) if request_info is None: self.logger.debug("No cache key found for response, returning raw response") - yield response - + return response else: if request_info.method == "eth_subscribe" and "result" in response.keys(): # if response for the initial eth_subscribe request, which returns the # subscription id subscription_id = response["result"] cache_key = generate_cache_key(subscription_id) - if cache_key not in self._provider._async_response_processing_cache: + if cache_key not in self._request_processor._request_information_cache: # cache by subscription id in order to process each response for the # subscription as it comes in - self._provider.logger.debug( + provider.logger.debug( f"Caching eth_subscription info:\n " f"cache_key={cache_key},\n " f"request_info={request_info.__dict__}" ) - self._provider._async_response_processing_cache.cache( + self._request_processor._request_information_cache.cache( cache_key, request_info ) - # pipe response back through middleware response processors if len(request_info.middleware_response_processors) > 0: response = pipe(response, *request_info.middleware_response_processors) @@ -404,7 +426,7 @@ async def _ws_recv_stream(self) -> AsyncGenerator[RPCResponse, None]: error_formatters, null_formatters, ) - yield apply_result_formatters(result_formatters, partly_formatted_response) + return apply_result_formatters(result_formatters, partly_formatted_response) class _AsyncPersistentRecvStream: @@ -421,6 +443,7 @@ def __init__(self, manager: RequestManager, *args: Any, **kwargs: Any) -> None: def __aiter__(self) -> AsyncGenerator[RPCResponse, None]: while True: try: + # solely listen to the stream, no request id necessary return self.manager._ws_recv_stream() except ConnectionClosedOK: pass diff --git a/web3/middleware/attrdict.py b/web3/middleware/attrdict.py index e7509e1c6f..cdb1d707df 100644 --- a/web3/middleware/attrdict.py +++ b/web3/middleware/attrdict.py @@ -70,8 +70,10 @@ async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: if async_w3.provider.has_persistent_connection: # asynchronous response processing provider = cast("PersistentConnectionProvider", async_w3.provider) - provider._append_middleware_response_processor(_handle_async_response) - return None + provider._request_processor.append_middleware_response_processor( + _handle_async_response + ) + return response else: return 
_handle_async_response(response) diff --git a/web3/middleware/fixture.py b/web3/middleware/fixture.py index bc22399ffc..f2fc722eec 100644 --- a/web3/middleware/fixture.py +++ b/web3/middleware/fixture.py @@ -125,11 +125,13 @@ async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: if async_w3.provider.has_persistent_connection: provider = cast("PersistentConnectionProvider", async_w3.provider) - await make_request(method, params) - provider._append_middleware_response_processor( + response = await make_request(method, params) + provider._request_processor.append_middleware_response_processor( + # processed asynchronously later but need to pass the actual + # response to the next middleware lambda _: {"result": result} ) - return None + return response else: return {"result": result} else: @@ -169,11 +171,13 @@ async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: if async_w3.provider.has_persistent_connection: provider = cast("PersistentConnectionProvider", async_w3.provider) - await make_request(method, params) - provider._append_middleware_response_processor( + response = await make_request(method, params) + provider._request_processor.append_middleware_response_processor( + # processed asynchronously later but need to pass the actual + # response to the next middleware lambda _: error_response ) - return None + return response else: return cast(RPCResponse, error_response) else: diff --git a/web3/middleware/formatting.py b/web3/middleware/formatting.py index 6fd3c82b3c..ef84050474 100644 --- a/web3/middleware/formatting.py +++ b/web3/middleware/formatting.py @@ -167,14 +167,14 @@ async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: if async_w3.provider.has_persistent_connection: # asynchronous response processing provider = cast("PersistentConnectionProvider", async_w3.provider) - provider._append_middleware_response_processor( + provider._request_processor.append_middleware_response_processor( _apply_response_formatters( method, formatters["result_formatters"], formatters["error_formatters"], ) ) - return None + return response else: return _apply_response_formatters( method, diff --git a/web3/module.py b/web3/module.py index 776fc21824..41ad2ca156 100644 --- a/web3/module.py +++ b/web3/module.py @@ -7,6 +7,7 @@ Optional, TypeVar, Union, + cast, ) from eth_abi.codec import ( @@ -29,6 +30,7 @@ PersistentConnectionProvider, ) from web3.types import ( + RPCEndpoint, RPCResponse, ) @@ -93,16 +95,18 @@ async def caller(*args: Any, **kwargs: Any) -> Union[RPCResponse, AsyncLogFilter if isinstance(async_w3.provider, PersistentConnectionProvider): # TODO: The typing does not seem to be correct for response_formatters. # For now, keep the expected typing but ignore it here. 
- cache_key = async_w3.provider._cache_request_information( + provider = async_w3.provider + cache_key = provider._request_processor.cache_request_information( method_str, params, response_formatters # type: ignore ) try: + method_str = cast(RPCEndpoint, method_str) return await async_w3.manager.ws_send(method_str, params) except Exception as e: - if async_w3.provider._async_response_processing_cache.get_cache_entry( - cache_key - ): - async_w3.provider._pop_cached_request_information(cache_key) + if cache_key in provider._request_processor._request_information_cache: + provider._request_processor.pop_cached_request_information( + cache_key + ) raise e else: ( diff --git a/web3/providers/__init__.py b/web3/providers/__init__.py index 0ad0d62499..883b88060e 100644 --- a/web3/providers/__init__.py +++ b/web3/providers/__init__.py @@ -11,9 +11,6 @@ from .ipc import ( IPCProvider, ) -from .persistent import ( - PersistentConnectionProvider, -) from .rpc import ( HTTPProvider, ) @@ -21,6 +18,9 @@ WebsocketProvider, WebsocketProviderV2, ) +from .persistent import ( + PersistentConnectionProvider, +) from .auto import ( AutoProvider, ) diff --git a/web3/providers/persistent.py b/web3/providers/persistent.py index 84d823cdf1..74896cfea8 100644 --- a/web3/providers/persistent.py +++ b/web3/providers/persistent.py @@ -1,35 +1,28 @@ from abc import ( ABC, ) -from copy import ( - copy, +from concurrent.futures import ( + ThreadPoolExecutor, ) import logging +import threading from typing import ( - Any, - Callable, Optional, - Tuple, ) from websockets.legacy.client import ( WebSocketClientProtocol, ) -from web3._utils.caching import ( - RequestInformation, - generate_cache_key, -) from web3.providers.async_base import ( AsyncJSONBaseProvider, ) +from web3.providers.websocket.request_processor import ( + RequestProcessor, +) from web3.types import ( - RPCEndpoint, RPCResponse, ) -from web3.utils import ( - SimpleCache, -) DEFAULT_PERSISTENT_CONNECTION_TIMEOUT = 20 @@ -37,7 +30,11 @@ class PersistentConnectionProvider(AsyncJSONBaseProvider, ABC): logger = logging.getLogger("web3.providers.PersistentConnectionProvider") has_persistent_connection = True - ws: Optional[WebSocketClientProtocol] = None + + _ws: Optional[WebSocketClientProtocol] = None + _request_processor: RequestProcessor + _thread_pool: ThreadPoolExecutor = ThreadPoolExecutor() + _lock: threading.Lock = threading.Lock() def __init__( self, @@ -47,8 +44,9 @@ def __init__( ) -> None: super().__init__() self.endpoint_uri = endpoint_uri - self._async_response_processing_cache: SimpleCache = SimpleCache( - request_cache_size + self._request_processor = RequestProcessor( + self, + request_info_cache_size=request_cache_size, ) self.call_timeout = call_timeout @@ -58,117 +56,5 @@ async def connect(self) -> None: async def disconnect(self) -> None: raise NotImplementedError("Must be implemented by subclasses") - def _cache_request_information( - self, - method: RPCEndpoint, - params: Any, - response_formatters: Tuple[Callable[..., Any], ...], - ) -> str: - # copy the request counter and find the next request id without incrementing - # since this is done when / if the request is successfully sent - request_id = next(copy(self.request_counter)) - cache_key = generate_cache_key(request_id) - - self._bump_cache_if_key_present(cache_key, request_id) - - request_info = RequestInformation(method, params, response_formatters) - self.logger.debug( - f"Caching request info:\n request_id={request_id},\n" - f" cache_key={cache_key},\n 
request_info={request_info.__dict__}" - ) - self._async_response_processing_cache.cache( - cache_key, - request_info, - ) - return cache_key - - def _bump_cache_if_key_present(self, cache_key: str, request_id: int) -> None: - """ - If the cache key is present in the cache, bump the cache key and request id - by one to make room for the new request. This behavior is necessary when a - request is made but inner requests, say to `eth_estimateGas` if the `gas` is - missing, are made before the original request is sent. - """ - if cache_key in self._async_response_processing_cache: - original_request_info = ( - self._async_response_processing_cache.get_cache_entry(cache_key) - ) - bump = generate_cache_key(request_id + 1) - - # recursively bump the cache if the new key is also present - self._bump_cache_if_key_present(bump, request_id + 1) - - self.logger.debug( - f"Caching internal request. Bumping original request in cache:\n" - f" request_id=[{request_id}] -> [{request_id + 1}],\n" - f" cache_key=[{cache_key}] -> [{bump}],\n" - f" request_info={original_request_info.__dict__}" - ) - self._async_response_processing_cache.cache(bump, original_request_info) - - def _pop_cached_request_information( - self, cache_key: str - ) -> Optional[RequestInformation]: - request_info = self._async_response_processing_cache.pop(cache_key) - if request_info is not None: - self.logger.debug( - f"Request info popped from cache:\n" - f" cache_key={cache_key},\n request_info={request_info.__dict__}" - ) - return request_info - - def _get_request_information_for_response( - self, - response: RPCResponse, - ) -> RequestInformation: - if "method" in response and response["method"] == "eth_subscription": - if "params" not in response: - raise ValueError("Subscription response must have params field") - if "subscription" not in response["params"]: - raise ValueError( - "Subscription response params must have subscription field" - ) - - # retrieve the request info from the cache using the subscription id - cache_key = generate_cache_key(response["params"]["subscription"]) - request_info = ( - # don't pop the request info from the cache, since we need to keep it - # to process future subscription responses - # i.e. 
subscription request information remains in the cache - self._async_response_processing_cache.get_cache_entry(cache_key) - ) - - else: - # retrieve the request info from the cache using the request id - cache_key = generate_cache_key(response["id"]) - request_info = ( - # pop the request info from the cache since we don't need to keep it - # this keeps the cache size bounded - self._pop_cached_request_information(cache_key) - ) - if ( - request_info is not None - and request_info.method == "eth_unsubscribe" - and response.get("result") is True - ): - # if successful unsubscribe request, remove the subscription request - # information from the cache since it is no longer needed - subscription_id = request_info.params[0] - subscribe_cache_key = generate_cache_key(subscription_id) - self._pop_cached_request_information(subscribe_cache_key) - - return request_info - - def _append_middleware_response_processor( - self, - middleware_response_processor: Callable[..., Any], - ) -> None: - request_id = next(copy(self.request_counter)) - 1 - cache_key = generate_cache_key(request_id) - current_request_cached_info: RequestInformation = ( - self._async_response_processing_cache.get_cache_entry(cache_key) - ) - if current_request_cached_info: - current_request_cached_info.middleware_response_processors.append( - middleware_response_processor - ) + async def _ws_recv(self) -> RPCResponse: + raise NotImplementedError("Must be implemented by subclasses") diff --git a/web3/providers/websocket/__init__.py b/web3/providers/websocket/__init__.py index daf0ad634a..1afa41f867 100644 --- a/web3/providers/websocket/__init__.py +++ b/web3/providers/websocket/__init__.py @@ -3,6 +3,9 @@ RESTRICTED_WEBSOCKET_KWARGS, WebsocketProvider, ) +from .websocket_connection import ( + WebsocketConnection, +) from .websocket_v2 import ( WebsocketProviderV2, ) diff --git a/web3/providers/websocket/request_processor.py b/web3/providers/websocket/request_processor.py new file mode 100644 index 0000000000..1102d15267 --- /dev/null +++ b/web3/providers/websocket/request_processor.py @@ -0,0 +1,195 @@ +from copy import ( + copy, +) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Optional, + Tuple, +) +from uuid import ( + uuid4, +) + +from web3._utils.caching import ( + RequestInformation, + generate_cache_key, +) +from web3.types import ( + RPCEndpoint, + RPCResponse, +) +from web3.utils import ( + SimpleCache, +) + +if TYPE_CHECKING: + from web3.providers.persistent import ( + PersistentConnectionProvider, + ) + + +class RequestProcessor: + _request_information_cache: SimpleCache + + def __init__( + self, + provider: "PersistentConnectionProvider", + request_info_cache_size: int = 500, + ) -> None: + self._provider = provider + + self._request_information_cache = SimpleCache(request_info_cache_size) + self._raw_response_cache = SimpleCache(500) + + # request information cache + + def cache_request_information( + self, + method: RPCEndpoint, + params: Any, + response_formatters: Tuple[Callable[..., Any], ...], + ) -> str: + # copy the request counter and find the next request id without incrementing + # since this is done when / if the request is successfully sent + request_id = next(copy(self._provider.request_counter)) + cache_key = generate_cache_key(request_id) + + self._bump_cache_if_key_present(cache_key, request_id) + + request_info = RequestInformation( + method, + params, + response_formatters, + ) + self._provider.logger.debug( + f"Caching request info:\n request_id={request_id},\n" + f" 
cache_key={cache_key},\n request_info={request_info.__dict__}" + ) + self._request_information_cache.cache( + cache_key, + request_info, + ) + return cache_key + + def _bump_cache_if_key_present(self, cache_key: str, request_id: int) -> None: + """ + If the cache key is present in the cache, bump the cache key and request id + by one to make room for the new request. This behavior is necessary when a + request is made but inner requests, say to `eth_estimateGas` if the `gas` is + missing, are made before the original request is sent. + """ + if cache_key in self._request_information_cache: + original_request_info = self._request_information_cache.get_cache_entry( + cache_key + ) + bump = generate_cache_key(request_id + 1) + + # recursively bump the cache if the new key is also present + self._bump_cache_if_key_present(bump, request_id + 1) + + self._provider.logger.debug( + f"Caching internal request. Bumping original request in cache:\n" + f" request_id=[{request_id}] -> [{request_id + 1}],\n" + f" cache_key=[{cache_key}] -> [{bump}],\n" + f" request_info={original_request_info.__dict__}" + ) + self._request_information_cache.cache(bump, original_request_info) + + def pop_cached_request_information( + self, cache_key: str + ) -> Optional[RequestInformation]: + request_info = self._request_information_cache.pop(cache_key) + if request_info is not None: + self._provider.logger.debug( + f"Request info popped from cache:\n" + f" cache_key={cache_key},\n request_info={request_info.__dict__}" + ) + return request_info + + def get_request_information_for_response( + self, + response: RPCResponse, + ) -> RequestInformation: + if "method" in response and response["method"] == "eth_subscription": + if "params" not in response: + raise ValueError("Subscription response must have params field") + if "subscription" not in response["params"]: + raise ValueError( + "Subscription response params must have subscription field" + ) + + # retrieve the request info from the cache using the subscription id + cache_key = generate_cache_key(response["params"]["subscription"]) + request_info = ( + # don't pop the request info from the cache, since we need to keep it + # to process future subscription responses + # i.e. 
subscription request information remains in the cache + self._request_information_cache.get_cache_entry(cache_key) + ) + + else: + # retrieve the request info from the cache using the request id + cache_key = generate_cache_key(response["id"]) + request_info = ( + # pop the request info from the cache since we don't need to keep it + # this keeps the cache size bounded + self.pop_cached_request_information(cache_key) + ) + if ( + request_info is not None + and request_info.method == "eth_unsubscribe" + and response.get("result") is True + ): + # if successful unsubscribe request, remove the subscription request + # information from the cache since it is no longer needed + subscription_id = request_info.params[0] + subscribe_cache_key = generate_cache_key(subscription_id) + self.pop_cached_request_information(subscribe_cache_key) + + return request_info + + def append_middleware_response_processor( + self, + middleware_response_processor: Callable[..., Any], + ) -> None: + request_id = next(copy(self._provider.request_counter)) - 1 + cache_key = generate_cache_key(request_id) + current_request_cached_info: RequestInformation = ( + self._request_information_cache.get_cache_entry(cache_key) + ) + if current_request_cached_info: + current_request_cached_info.middleware_response_processors.append( + middleware_response_processor + ) + + # raw response cache + + def cache_raw_response(self, raw_response: Any) -> None: + # get id or generate a uuid if not present (i.e. subscription response) + response_id = raw_response.get("id", f"sub-{uuid4()}") + cache_key = generate_cache_key(response_id) + self._provider.logger.debug( + f"Caching raw response:\n response_id={response_id},\n" + f" cache_key={cache_key},\n raw_response={raw_response}" + ) + self._raw_response_cache.cache(cache_key, raw_response) + + def pop_raw_response(self, cache_key: str) -> Any: + raw_response = self._raw_response_cache.pop(cache_key) + self._provider.logger.debug( + f"Cached response processed and popped from cache:\n" + f" cache_key={cache_key},\n" + f" raw_response={raw_response}" + ) + + # request processor class methods + + def clear_caches(self) -> None: + """ + Clear the request information and raw response caches. + """ + + self._request_information_cache.clear() + self._raw_response_cache.clear() diff --git a/web3/providers/websocket/websocket_connection.py b/web3/providers/websocket/websocket_connection.py new file mode 100644 index 0000000000..a8f4e70eb7 --- /dev/null +++ b/web3/providers/websocket/websocket_connection.py @@ -0,0 +1,36 @@ +from typing import ( + TYPE_CHECKING, + Any, +) + +from web3.types import ( + RPCEndpoint, + RPCResponse, +) + +if TYPE_CHECKING: + from web3.main import ( # noqa: F401 + _PersistentConnectionWeb3, + ) + from web3.manager import ( # noqa: F401 + _AsyncPersistentRecvStream, + ) + + +class WebsocketConnection: + """ + A class that houses the public API for interacting with the websocket connection + via a `_PersistentConnectionWeb3` instance. 
+ """ + + def __init__(self, w3: "_PersistentConnectionWeb3"): + self._w3 = w3 + + async def send(self, method: RPCEndpoint, params: Any) -> RPCResponse: + return await self._w3.manager.ws_send(method, params) + + async def recv(self) -> Any: + return await self._w3.manager.ws_recv() + + def listen_to_websocket(self) -> "_AsyncPersistentRecvStream": + return self._w3.manager.persistent_recv_stream() diff --git a/web3/providers/websocket/websocket_v2.py b/web3/providers/websocket/websocket_v2.py index 213a761d25..6bc513e990 100644 --- a/web3/providers/websocket/websocket_v2.py +++ b/web3/providers/websocket/websocket_v2.py @@ -1,4 +1,5 @@ import asyncio +import json import logging import os from typing import ( @@ -21,6 +22,12 @@ WebSocketException, ) +from web3._utils.async_caching import ( + async_lock, +) +from web3._utils.caching import ( + generate_cache_key, +) from web3.exceptions import ( ProviderConnectionError, Web3ValidationError, @@ -30,6 +37,8 @@ ) from web3.types import ( RPCEndpoint, + RPCId, + RPCResponse, ) DEFAULT_PING_INTERVAL = 30 # 30 seconds @@ -91,11 +100,11 @@ def __str__(self) -> str: return f"Websocket connection: {self.endpoint_uri}" async def is_connected(self, show_traceback: bool = False) -> bool: - if not self.ws: + if not self._ws: return False try: - await self.ws.pong() + await self._ws.pong() return True except WebSocketException as e: @@ -113,7 +122,7 @@ async def connect(self) -> None: while _connection_attempts != self._max_connection_retries: try: _connection_attempts += 1 - self.ws = await connect(self.endpoint_uri, **self.websocket_kwargs) + self._ws = await connect(self.endpoint_uri, **self.websocket_kwargs) break except WebSocketException as e: if _connection_attempts == self._max_connection_retries: @@ -130,16 +139,61 @@ async def connect(self) -> None: _backoff_time *= _backoff_rate_change async def disconnect(self) -> None: - await self.ws.close() - self.ws = None + await self._ws.close() + self._ws = None - # clear the provider request cache after disconnecting - self._async_response_processing_cache.clear() + # clear the request information cache after disconnecting + self._request_processor.clear_caches() + self.logger.debug( + f'Successfully disconnected from endpoint: "{self.endpoint_uri}" ' + "and the request processor transient caches were cleared." 
+ ) - async def make_request(self, method: RPCEndpoint, params: Any) -> None: + async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: request_data = self.encode_rpc_request(method, params) - if self.ws is None: + if self._ws is None: await self.connect() - await asyncio.wait_for(self.ws.send(request_data), timeout=self.call_timeout) + await asyncio.wait_for(self._ws.send(request_data), timeout=self.call_timeout) + + current_request_id = json.loads(request_data)["id"] + + response = await self._ws_recv() + response_id = response.get("id") + + if response_id != current_request_id: + request_cache_key = generate_cache_key(current_request_id) + if request_cache_key in self._request_processor._raw_response_cache: + async with async_lock(self._thread_pool, self._lock): + # if response is already cached, pop it from the cache + response = self._request_processor.pop_raw_response( + request_cache_key + ) + else: + async with async_lock(self._thread_pool, self._lock): + # cache response + self._request_processor.cache_raw_response(response) + response = await asyncio.wait_for( + self._get_response_for_request_id(current_request_id), + self.call_timeout, + ) + return response + + async def _get_response_for_request_id(self, request_id: RPCId) -> RPCResponse: + response = await self._ws_recv() + response_id = response.get("id") + + while response_id != request_id: + response = await self._ws_recv() + response_id = response.get("id") + if response_id != request_id: + self._request_processor.cache_raw_response( + response, + ) + return response + + async def _ws_recv(self) -> RPCResponse: + return json.loads( + await asyncio.wait_for(self._ws.recv(), timeout=self.call_timeout) + ) diff --git a/web3/types.py b/web3/types.py index 5e234434ff..ef154c06d9 100644 --- a/web3/types.py +++ b/web3/types.py @@ -281,9 +281,12 @@ class GethSyncingSubscriptionResponse(SubscriptionResponse): result: GethSyncingSubscriptionResult +RPCId = Optional[Union[int, str]] + + class RPCResponse(TypedDict, total=False): error: Union[RPCError, str] - id: Optional[Union[int, str]] + id: RPCId jsonrpc: Literal["2.0"] result: Any diff --git a/web3/utils/caching.py b/web3/utils/caching.py index 9ba290b071..c315e90844 100644 --- a/web3/utils/caching.py +++ b/web3/utils/caching.py @@ -4,6 +4,7 @@ from typing import ( Any, Dict, + List, Optional, Tuple, ) @@ -38,8 +39,8 @@ def get_cache_entry(self, key: str) -> Optional[Any]: def clear(self) -> None: self._data.clear() - def items(self) -> Dict[str, Any]: - return self._data + def items(self) -> List[Tuple[str, Any]]: + return list(self._data.items()) def pop(self, key: str) -> Optional[Any]: if key not in self._data:
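A closing note on the ``percentage`` argument added to ``either_set_is_a_subset`` in ``web3/_utils/utility_methods.py`` above, since the threshold math is easy to misread: the check now passes when at least ``percentage`` percent of either set appears in the other. A standalone sketch with toy sets, runnable against this branch:

    from web3._utils.utility_methods import either_set_is_a_subset

    response_keys = {"a", "b", "c", "d", "e", "f", "g", "h", "i", "x"}
    formatter_keys = {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}

    # 9 of the 10 keys in each set appear in the other (a 90% overlap),
    # so a threshold of 90 matches ...
    assert either_set_is_a_subset(response_keys, formatter_keys, percentage=90)

    # ... while the default of 100 requires one set to be fully contained in the other
    assert not either_set_is_a_subset(response_keys, formatter_keys)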