diff --git a/tests/core/middleware/test_simple_cache_middleware.py b/tests/core/middleware/test_simple_cache_middleware.py index 11ba2b8026..a68dbc79e2 100644 --- a/tests/core/middleware/test_simple_cache_middleware.py +++ b/tests/core/middleware/test_simple_cache_middleware.py @@ -5,6 +5,7 @@ from web3 import Web3 from web3._utils.caching import ( + SimpleCache, generate_cache_key, ) from web3.middleware import ( @@ -51,13 +52,33 @@ def w3(w3_base, result_generator_middleware): return w3_base -def test_simple_cache_middleware_pulls_from_cache(w3): - def cache_class(): - return { - generate_cache_key(f"{threading.get_ident()}:{('fake_endpoint', [1])}"): { - "result": "value-a" - }, - } +def dict_cache_class_return_value_a(): + # test dictionary-based cache + return { + generate_cache_key(f"{threading.get_ident()}:{('fake_endpoint', [1])}"): { + "result": "value-a" + }, + } + + +def simple_cache_class_return_value_a(): + # test `SimpleCache` class cache + _cache = SimpleCache() + _cache.cache( + generate_cache_key(f"{threading.get_ident()}:{('fake_endpoint', [1])}"), + {"result": "value-a"}, + ) + return _cache + + +@pytest.mark.parametrize( + "cache_class", + ( + dict_cache_class_return_value_a, + simple_cache_class_return_value_a, + ), +) +def test_simple_cache_middleware_pulls_from_cache(w3, cache_class): w3.middleware_onion.add( construct_simple_cache_middleware( @@ -69,10 +90,11 @@ def cache_class(): assert w3.manager.request_blocking("fake_endpoint", [1]) == "value-a" -def test_simple_cache_middleware_populates_cache(w3): +@pytest.mark.parametrize("cache_class", (dict, SimpleCache)) +def test_simple_cache_middleware_populates_cache(w3, cache_class): w3.middleware_onion.add( construct_simple_cache_middleware( - cache_class=dict, + cache_class=cache_class, rpc_whitelist={RPCEndpoint("fake_endpoint")}, ) ) @@ -83,7 +105,8 @@ def test_simple_cache_middleware_populates_cache(w3): assert w3.manager.request_blocking("fake_endpoint", [1]) != result -def test_simple_cache_middleware_does_not_cache_none_responses(w3_base): +@pytest.mark.parametrize("cache_class", (dict, SimpleCache)) +def test_simple_cache_middleware_does_not_cache_none_responses(w3_base, cache_class): counter = itertools.count() w3 = w3_base @@ -101,7 +124,7 @@ def result_cb(_method, _params): w3.middleware_onion.add( construct_simple_cache_middleware( - cache_class=dict, + cache_class=cache_class, rpc_whitelist={RPCEndpoint("fake_endpoint")}, ) ) @@ -112,7 +135,8 @@ def result_cb(_method, _params): assert next(counter) == 2 -def test_simple_cache_middleware_does_not_cache_error_responses(w3_base): +@pytest.mark.parametrize("cache_class", (dict, SimpleCache)) +def test_simple_cache_middleware_does_not_cache_error_responses(w3_base, cache_class): w3 = w3_base w3.middleware_onion.add( construct_error_generator_middleware( @@ -124,7 +148,7 @@ def test_simple_cache_middleware_does_not_cache_error_responses(w3_base): w3.middleware_onion.add( construct_simple_cache_middleware( - cache_class=dict, + cache_class=cache_class, rpc_whitelist={RPCEndpoint("fake_endpoint")}, ) ) @@ -137,10 +161,14 @@ def test_simple_cache_middleware_does_not_cache_error_responses(w3_base): assert str(err_a) != str(err_b) -def test_simple_cache_middleware_does_not_cache_endpoints_not_in_whitelist(w3): +@pytest.mark.parametrize("cache_class", (dict, SimpleCache)) +def test_simple_cache_middleware_does_not_cache_endpoints_not_in_whitelist( + w3, + cache_class, +): w3.middleware_onion.add( construct_simple_cache_middleware( - cache_class=dict, + cache_class=cache_class, rpc_whitelist={RPCEndpoint("fake_endpoint")}, ) ) @@ -156,7 +184,7 @@ def test_simple_cache_middleware_does_not_cache_endpoints_not_in_whitelist(w3): async def _async_simple_cache_middleware_for_testing(make_request, async_w3): middleware = await async_construct_simple_cache_middleware( - cache_class=dict, + cache_class=SimpleCache, rpc_whitelist={RPCEndpoint("fake_endpoint")}, ) return await middleware(make_request, async_w3) @@ -173,17 +201,17 @@ def async_w3(): @pytest.mark.asyncio -async def test_async_simple_cache_middleware_pulls_from_cache(async_w3): +@pytest.mark.parametrize( + "cache_class", + ( + dict_cache_class_return_value_a, + simple_cache_class_return_value_a, + ), +) +async def test_async_simple_cache_middleware_pulls_from_cache(async_w3, cache_class): # remove the pre-loaded simple cache middleware to replace with test-specific: async_w3.middleware_onion.remove("simple_cache") - def cache_class(): - return { - generate_cache_key(f"{threading.get_ident()}:{('fake_endpoint', [1])}"): { - "result": "value-a" - }, - } - async def _properly_awaited_middleware(make_request, _async_w3): middleware = await async_construct_simple_cache_middleware( cache_class=cache_class,