From 205186e1a8e2f0848b15702238827145dca282d9 Mon Sep 17 00:00:00 2001 From: fselmo Date: Wed, 28 Aug 2024 11:14:07 -0600 Subject: [PATCH] Changes from comments on PR #3463: - ``_build_tkey()`` -> ``_build_name()`` - Add test for unnamed ``Web3Middleware`` added to onion, not just ``Web3MiddlewareBuilder.build()`` (functools.curry) references. --- tests/core/middleware/test_middleware.py | 34 ++++++++++++++++++------ web3/datastructures.py | 20 +++++++------- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/tests/core/middleware/test_middleware.py b/tests/core/middleware/test_middleware.py index 40031d0ab0..39a2365879 100644 --- a/tests/core/middleware/test_middleware.py +++ b/tests/core/middleware/test_middleware.py @@ -12,18 +12,26 @@ ) +class TestMiddleware(Web3Middleware): + def response_processor(self, method, response): + if method == "eth_blockNumber": + response["result"] = 1234 + + return response + + +class TestMiddleware2(Web3Middleware): + def response_processor(self, method, response): + if method == "eth_blockNumber": + response["result"] = 4321 + + return response + + def test_middleware_class_eq_magic_method(): w3_a = Web3() w3_b = Web3() - class TestMiddleware(Web3Middleware): - def request_processor(self, method, params): - return 1234 - - class TestMiddleware2(Web3Middleware): - def request_processor(self, method, params): - return 4321 - mw1w3_a = TestMiddleware(w3_a) assert mw1w3_a is not None assert mw1w3_a != "" @@ -61,3 +69,13 @@ def test_unnamed_middleware_are_given_unique_keys(w3): with pytest.raises(Web3ValueError): # adding the same middleware again should cause an error w3.middleware_onion.add(request_formatting_middleware) + + +def test_unnamed_class_middleware_are_given_unique_keys(w3): + w3.middleware_onion.add(TestMiddleware) + w3.middleware_onion.add(TestMiddleware2) + assert isinstance(w3.eth.block_number, int) + + with pytest.raises(Web3ValueError): + # adding the same middleware again should cause an error + w3.middleware_onion.add(TestMiddleware) diff --git a/web3/datastructures.py b/web3/datastructures.py index e9b07348b8..90932aa489 100644 --- a/web3/datastructures.py +++ b/web3/datastructures.py @@ -182,7 +182,7 @@ def add(self, element: TValue, name: Optional[TKey] = None) -> None: if name is None: name = cast(TKey, element) - name = self._build_tkey(name) + name = self._build_name(name) if name in self._queue: if name is element: @@ -219,7 +219,7 @@ def inject( if name is None: name = cast(TKey, element) - name = self._build_tkey(name) + name = self._build_name(name) self._queue.move_to_end(name, last=False) elif layer == len(self._queue): @@ -233,7 +233,7 @@ def clear(self) -> None: self._queue.clear() def replace(self, old: TKey, new: TKey) -> TValue: - old_name = self._build_tkey(old) + old_name = self._build_name(old) if old_name not in self._queue: raise Web3ValueError( @@ -249,7 +249,7 @@ def replace(self, old: TKey, new: TKey) -> TValue: return to_be_replaced @staticmethod - def _build_tkey(value: TKey) -> TKey: + def _build_name(value: TKey) -> TKey: try: value.__hash__() return value @@ -261,12 +261,12 @@ def _build_tkey(value: TKey) -> TKey: ) # This will either be ``Web3Middleware`` class or the ``build`` method of a # ``Web3MiddlewareBuilder``. Instantiate with empty ``Web3`` and use a - # unique identifier with the ``__hash__()`` as the TKey. + # unique identifier with the ``__hash__()`` as the name. v = value(None) return cast(TKey, f"{v.__class__}<{v.__hash__()}>") def remove(self, old: TKey) -> None: - old_name = self._build_tkey(old) + old_name = self._build_name(old) if old_name not in self._queue: raise Web3ValueError("You can only remove something that has been added") del self._queue[old_name] @@ -280,8 +280,8 @@ def middleware(self) -> Sequence[Any]: return [(val, key) for key, val in reversed(self._queue.items())] def _replace_with_new_name(self, old: TKey, new: TKey) -> None: - old_name = self._build_tkey(old) - new_name = self._build_tkey(new) + old_name = self._build_name(old) + new_name = self._build_name(new) self._queue[new_name] = new found_old = False @@ -303,11 +303,11 @@ def __add__(self, other: Any) -> "NamedElementOnion[TKey, TValue]": return NamedElementOnion(cast(List[Any], combined.items())) def __contains__(self, element: Any) -> bool: - element_name = self._build_tkey(element) + element_name = self._build_name(element) return element_name in self._queue def __getitem__(self, element: TKey) -> TValue: - element_name = self._build_tkey(element) + element_name = self._build_name(element) return self._queue[element_name] def __len__(self) -> int: