Add support for async_simple_cache_middleware #2579

Merged · 4 commits · Nov 2, 2022
3 changes: 2 additions & 1 deletion docs/providers.rst
@@ -484,5 +484,6 @@ Supported Middleware
- :meth:`Gas Price Strategy <web3.middleware.gas_price_strategy_middleware>`
- :meth:`Buffered Gas Estimate Middleware <web3.middleware.buffered_gas_estimate_middleware>`
- :meth:`Stalecheck Middleware <web3.middleware.make_stalecheck_middleware>`
- :meth:`Validation middleware <web3.middleware.validation>`
- :meth:`Validation Middleware <web3.middleware.validation>`
- :ref:`Geth POA Middleware <geth-poa>`
- :meth:`Simple Cache Middleware <web3.middleware.simple_cache_middleware>`
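For reference, a minimal sketch of attaching the synchronous Simple Cache Middleware listed above. The construction parameters follow the tests in this PR; the whitelisted endpoint and the eth-tester provider are illustrative assumptions, not part of the diff:

```python
from web3 import Web3
from web3._utils.caching import SimpleCache
from web3.middleware import construct_simple_cache_middleware
from web3.providers.eth_tester import EthereumTesterProvider
from web3.types import RPCEndpoint

w3 = Web3(EthereumTesterProvider())  # any provider works; eth-tester used here

# Cache responses for a chosen set of RPC methods, backed by SimpleCache.
cache_middleware = construct_simple_cache_middleware(
    cache_class=SimpleCache,
    rpc_whitelist={RPCEndpoint("web3_clientVersion")},  # illustrative whitelist
)
w3.middleware_onion.add(cache_middleware)
```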
1 change: 1 addition & 0 deletions newsfragments/2579.feature.rst
@@ -0,0 +1 @@
Async support for caching certain methods via ``async_simple_cache_middleware`` as well as constructing custom async caching middleware via ``async_construct_simple_cache_middleware``.
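A minimal usage sketch of the new async construction helper, mirroring the wrapper pattern used in the test fixture below; the wrapper coroutine name and the whitelisted endpoint are illustrative, and `AsyncEthereumTesterProvider` stands in for any async-capable provider:

```python
from web3 import Web3
from web3._utils.caching import SimpleCache
from web3.middleware.async_cache import async_construct_simple_cache_middleware
from web3.providers.eth_tester import AsyncEthereumTesterProvider
from web3.types import RPCEndpoint


async def custom_async_cache_middleware(make_request, async_w3):
    # construction is itself a coroutine, so wrap it in a plain async middleware
    middleware = await async_construct_simple_cache_middleware(
        cache_class=SimpleCache,
        rpc_whitelist={RPCEndpoint("eth_chainId")},  # illustrative whitelist
    )
    return await middleware(make_request, async_w3)


async_w3 = Web3(
    provider=AsyncEthereumTesterProvider(),
    middlewares=[(custom_async_cache_middleware, "simple_cache")],
)
```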
220 changes: 197 additions & 23 deletions tests/core/middleware/test_simple_cache_middleware.py
@@ -1,19 +1,34 @@
import itertools
import pytest
import threading
import uuid

from web3 import Web3
from web3._utils.caching import (
SimpleCache,
generate_cache_key,
)
from web3.middleware import (
construct_error_generator_middleware,
construct_result_generator_middleware,
construct_simple_cache_middleware,
)
from web3.middleware.async_cache import (
async_construct_simple_cache_middleware,
)
from web3.middleware.fixture import (
async_construct_error_generator_middleware,
async_construct_result_generator_middleware,
)
from web3.providers.base import (
BaseProvider,
)
from web3.providers.eth_tester import (
AsyncEthereumTesterProvider,
)
from web3.types import (
RPCEndpoint,
)


@pytest.fixture
@@ -25,8 +40,8 @@ def w3_base():
def result_generator_middleware():
return construct_result_generator_middleware(
{
"fake_endpoint": lambda *_: str(uuid.uuid4()),
"not_whitelisted": lambda *_: str(uuid.uuid4()),
RPCEndpoint("fake_endpoint"): lambda *_: str(uuid.uuid4()),
RPCEndpoint("not_whitelisted"): lambda *_: str(uuid.uuid4()),
}
)

@@ -37,27 +52,50 @@ def w3(w3_base, result_generator_middleware):
return w3_base


def test_simple_cache_middleware_pulls_from_cache(w3):
def cache_class():
return {
generate_cache_key(("fake_endpoint", [1])): {"result": "value-a"},
}
def dict_cache_class_return_value_a():
# test dictionary-based cache
return {
generate_cache_key(f"{threading.get_ident()}:{('fake_endpoint', [1])}"): {
"result": "value-a"
},
}


def simple_cache_class_return_value_a():
# test `SimpleCache` class cache
_cache = SimpleCache()
_cache.cache(
generate_cache_key(f"{threading.get_ident()}:{('fake_endpoint', [1])}"),
{"result": "value-a"},
)
return _cache


@pytest.mark.parametrize(
"cache_class",
(
dict_cache_class_return_value_a,
simple_cache_class_return_value_a,
),
)
def test_simple_cache_middleware_pulls_from_cache(w3, cache_class):

w3.middleware_onion.add(
construct_simple_cache_middleware(
cache_class=cache_class,
rpc_whitelist={"fake_endpoint"},
rpc_whitelist={RPCEndpoint("fake_endpoint")},
)
)

assert w3.manager.request_blocking("fake_endpoint", [1]) == "value-a"


def test_simple_cache_middleware_populates_cache(w3):
@pytest.mark.parametrize("cache_class", (dict, SimpleCache))
def test_simple_cache_middleware_populates_cache(w3, cache_class):
w3.middleware_onion.add(
construct_simple_cache_middleware(
cache_class=dict,
rpc_whitelist={"fake_endpoint"},
cache_class=cache_class,
rpc_whitelist={RPCEndpoint("fake_endpoint")},
)
)

@@ -67,26 +105,27 @@ def test_simple_cache_middleware_populates_cache(w3):
assert w3.manager.request_blocking("fake_endpoint", [1]) != result


def test_simple_cache_middleware_does_not_cache_none_responses(w3_base):
@pytest.mark.parametrize("cache_class", (dict, SimpleCache))
def test_simple_cache_middleware_does_not_cache_none_responses(w3_base, cache_class):
counter = itertools.count()
w3 = w3_base

def result_cb(method, params):
def result_cb(_method, _params):
next(counter)
return None

w3.middleware_onion.add(
construct_result_generator_middleware(
{
"fake_endpoint": result_cb,
RPCEndpoint("fake_endpoint"): result_cb,
}
)
)

w3.middleware_onion.add(
construct_simple_cache_middleware(
cache_class=dict,
rpc_whitelist={"fake_endpoint"},
cache_class=cache_class,
rpc_whitelist={RPCEndpoint("fake_endpoint")},
)
)

@@ -96,20 +135,21 @@ def result_cb(method, params):
assert next(counter) == 2


def test_simple_cache_middleware_does_not_cache_error_responses(w3_base):
@pytest.mark.parametrize("cache_class", (dict, SimpleCache))
def test_simple_cache_middleware_does_not_cache_error_responses(w3_base, cache_class):
w3 = w3_base
w3.middleware_onion.add(
construct_error_generator_middleware(
{
"fake_endpoint": lambda *_: f"msg-{uuid.uuid4()}",
RPCEndpoint("fake_endpoint"): lambda *_: f"msg-{uuid.uuid4()}",
}
)
)

w3.middleware_onion.add(
construct_simple_cache_middleware(
cache_class=dict,
rpc_whitelist={"fake_endpoint"},
cache_class=cache_class,
rpc_whitelist={RPCEndpoint("fake_endpoint")},
)
)

@@ -121,15 +161,149 @@ def test_simple_cache_middleware_does_not_cache_error_responses(w3_base):
assert str(err_a) != str(err_b)


def test_simple_cache_middleware_does_not_cache_endpoints_not_in_whitelist(w3):
@pytest.mark.parametrize("cache_class", (dict, SimpleCache))
def test_simple_cache_middleware_does_not_cache_endpoints_not_in_whitelist(
w3,
cache_class,
):
w3.middleware_onion.add(
construct_simple_cache_middleware(
cache_class=dict,
rpc_whitelist={"fake_endpoint"},
cache_class=cache_class,
rpc_whitelist={RPCEndpoint("fake_endpoint")},
)
)

result_a = w3.manager.request_blocking("not_whitelisted", [])
result_b = w3.manager.request_blocking("not_whitelisted", [])

assert result_a != result_b


# -- async -- #


async def _async_simple_cache_middleware_for_testing(make_request, async_w3):
middleware = await async_construct_simple_cache_middleware(
cache_class=SimpleCache,
rpc_whitelist={RPCEndpoint("fake_endpoint")},
)
return await middleware(make_request, async_w3)


@pytest.fixture
def async_w3():
return Web3(
provider=AsyncEthereumTesterProvider(),
middlewares=[
(_async_simple_cache_middleware_for_testing, "simple_cache"),
],
)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"cache_class",
(
dict_cache_class_return_value_a,
simple_cache_class_return_value_a,
),
)
async def test_async_simple_cache_middleware_pulls_from_cache(async_w3, cache_class):
async def _properly_awaited_middleware(make_request, _async_w3):
middleware = await async_construct_simple_cache_middleware(
cache_class=cache_class,
rpc_whitelist={RPCEndpoint("fake_endpoint")},
)
return await middleware(make_request, _async_w3)

async_w3.middleware_onion.inject(
_properly_awaited_middleware,
layer=0,
)

_result = await async_w3.manager.coro_request("fake_endpoint", [1])
assert _result == "value-a"


@pytest.mark.asyncio
async def test_async_simple_cache_middleware_populates_cache(async_w3):
async_w3.middleware_onion.inject(
await async_construct_result_generator_middleware(
{
RPCEndpoint("fake_endpoint"): lambda *_: str(uuid.uuid4()),
}
),
"result_generator",
layer=0,
)

result = await async_w3.manager.coro_request("fake_endpoint", [])

_empty_params = await async_w3.manager.coro_request("fake_endpoint", [])
_non_empty_params = await async_w3.manager.coro_request("fake_endpoint", [1])

assert _empty_params == result
assert _non_empty_params != result


@pytest.mark.asyncio
async def test_async_simple_cache_middleware_does_not_cache_none_responses(async_w3):
counter = itertools.count()

def result_cb(_method, _params):
next(counter)
return None

async_w3.middleware_onion.inject(
await async_construct_result_generator_middleware(
{
RPCEndpoint("fake_endpoint"): result_cb,
},
),
"result_generator",
layer=0,
)

await async_w3.manager.coro_request("fake_endpoint", [])
await async_w3.manager.coro_request("fake_endpoint", [])

assert next(counter) == 2


@pytest.mark.asyncio
async def test_async_simple_cache_middleware_does_not_cache_error_responses(async_w3):
async_w3.middleware_onion.inject(
await async_construct_error_generator_middleware(
{
RPCEndpoint("fake_endpoint"): lambda *_: f"msg-{uuid.uuid4()}",
}
),
"error_generator",
layer=0,
)

with pytest.raises(ValueError) as err_a:
await async_w3.manager.coro_request("fake_endpoint", [])
with pytest.raises(ValueError) as err_b:
await async_w3.manager.coro_request("fake_endpoint", [])

assert str(err_a) != str(err_b)


@pytest.mark.asyncio
async def test_async_simple_cache_middleware_does_not_cache_non_whitelist_endpoints(
async_w3,
):
async_w3.middleware_onion.inject(
await async_construct_result_generator_middleware(
{
RPCEndpoint("not_whitelisted"): lambda *_: str(uuid.uuid4()),
}
),
layer=0,
)

result_a = await async_w3.manager.coro_request("not_whitelisted", [])
result_b = await async_w3.manager.coro_request("not_whitelisted", [])

assert result_a != result_b
35 changes: 35 additions & 0 deletions tests/core/utilities/test_caching_utils.py
@@ -0,0 +1,35 @@
import asyncio
from concurrent.futures import (
ThreadPoolExecutor,
)
import pytest
import threading

from web3._utils.async_caching import (
async_lock,
)

# --- async -- #


@pytest.mark.asyncio
async def test_async_lock_releases_if_a_task_is_cancelled():
# inspired by issue #2693
# Note: this test will raise a `TimeoutError` if `request.async_lock` is not
# applied correctly

_thread_pool = ThreadPoolExecutor(max_workers=1)
_lock = threading.Lock()

async def _utilize_async_lock():
async with async_lock(_thread_pool, _lock):
await asyncio.sleep(0.2)

asyncio.create_task(_utilize_async_lock())

inner = asyncio.create_task(_utilize_async_lock())
await asyncio.sleep(0.1)
inner.cancel()

outer = asyncio.wait_for(_utilize_async_lock(), 2)
await outer