Skip to content

Commit

Permalink
Improve compatibility with pytest-asyncio 0.23
Browse files Browse the repository at this point in the history
  • Loading branch information
bmerry committed Mar 23, 2024
1 parent 393e7b5 commit 448a203
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 18 deletions.
7 changes: 6 additions & 1 deletion src/async_solipsism/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .exceptions import SolipsismError


__all__ = ('EventLoop', 'stream_pairs')
__all__ = ('EventLoop', 'EventLoopPolicy', 'stream_pairs')


class EventLoop(asyncio.selector_events.BaseSelectorEventLoop):
Expand Down Expand Up @@ -197,6 +197,11 @@ def _stop_serving(self, sock):
super()._stop_serving(sock)


class EventLoopPolicy(asyncio.DefaultEventLoopPolicy):
def new_event_loop(self) -> EventLoop:
return EventLoop()


async def stream_pairs(capacity=None):
sock1, sock2 = _socket.socketpair(capacity=capacity)
streams1 = await asyncio.open_connection(sock=sock1)
Expand Down
48 changes: 31 additions & 17 deletions test/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@


@pytest.fixture
def event_loop():
loop = async_solipsism.EventLoop()
yield loop
loop.close()
def event_loop_policy():
return async_solipsism.EventLoopPolicy()


async def test_sleep(event_loop):
async def test_sleep():
event_loop = asyncio.get_running_loop()
assert event_loop.time() == 0.0
await asyncio.sleep(2)
assert event_loop.time() == 2.0
Expand All @@ -53,7 +52,9 @@ async def zzz():
]
)
@pytest.mark.parametrize('delay', [False, True])
async def test_delayed_sock_recv(method, delay, event_loop):
async def test_delayed_sock_recv(method, delay):
event_loop = asyncio.get_running_loop()

async def delayed_write(wsock):
await asyncio.sleep(1)
wsock.send(b'Hello')
Expand All @@ -76,7 +77,9 @@ async def delayed_write(wsock):


@pytest.mark.parametrize('size', [10, 10**7])
async def test_sock_sendall(size, event_loop):
async def test_sock_sendall(size):
event_loop = asyncio.get_running_loop()

async def delayed_read(rsock):
n = 0
while True:
Expand All @@ -99,7 +102,8 @@ async def delayed_read(rsock):
assert n == size


async def test_connect_existing(event_loop, mocker):
async def test_connect_existing(mocker):
event_loop = asyncio.get_running_loop()
sock1, sock2 = async_solipsism.socketpair()
transport1, protocol1 = await event_loop.connect_accepted_socket(
mocker.MagicMock, sock1)
Expand Down Expand Up @@ -128,7 +132,9 @@ async def test_stream():


@pytest.mark.parametrize('manual_socket', [False, True])
async def test_server(event_loop, manual_socket):
async def test_server(manual_socket):
event_loop = asyncio.get_running_loop()

def callback(reader, writer):
server_conn.set_result((reader, writer))

Expand Down Expand Up @@ -156,7 +162,7 @@ def callback(reader, writer):
await server.wait_closed()


async def test_unused_port(event_loop):
async def test_unused_port():
def callback(reader, writer):
pass

Expand All @@ -172,26 +178,30 @@ def callback(reader, writer):
await server2.wait_closed()


async def test_close_server(event_loop):
async def test_close_server():
event_loop = asyncio.get_running_loop()
server = await asyncio.start_server(lambda reader, writer: None, 'test.invalid', 1234)
server.close()
await server.wait_closed()
with pytest.raises(ConnectionRefusedError):
await event_loop.create_connection('test.invalid', 1234)


async def test_create_connection_no_listener(event_loop):
async def test_create_connection_no_listener():
event_loop = asyncio.get_running_loop()
with pytest.raises(ConnectionRefusedError):
await event_loop.create_connection('test.invalid', 1234)


async def test_run_in_executor_implicit(event_loop):
async def test_run_in_executor_implicit():
event_loop = asyncio.get_running_loop()
thread_id = await event_loop.run_in_executor(None, threading.get_ident)
assert isinstance(thread_id, int)
assert thread_id != threading.get_ident()


async def test_run_in_executor_explicit(event_loop):
async def test_run_in_executor_explicit():
event_loop = asyncio.get_running_loop()
my_executor = concurrent.futures.ThreadPoolExecutor(1)
expected_thread_id = my_executor.submit(threading.get_ident).result()
assert isinstance(expected_thread_id, int)
Expand All @@ -201,7 +211,8 @@ async def test_run_in_executor_explicit(event_loop):


@pytest.mark.skipif("sys.version_info < (3, 7)")
async def test_sendfile(event_loop, tmp_path):
async def test_sendfile(tmp_path):
event_loop = asyncio.get_running_loop()
tmp_file = tmp_path / 'test_sendfile.txt'
tmp_file.write_bytes(b'Hello world\n')
((reader1, writer1), (reader2, writer2)) = await async_solipsism.stream_pairs()
Expand All @@ -215,14 +226,17 @@ async def test_sendfile(event_loop, tmp_path):
await writer2.wait_closed()


async def test_call_soon_threadsafe(event_loop):
async def test_call_soon_threadsafe():
event_loop = asyncio.get_running_loop()
future = event_loop.create_future()
event_loop.call_soon_threadsafe(future.set_result, 3)
result = await future
assert result == 3


async def test_call_soon_threadsafe_wrong_thread(event_loop):
async def test_call_soon_threadsafe_wrong_thread():
event_loop = asyncio.get_running_loop()

def thread_func():
event_loop.call_soon_threadsafe(lambda: None)

Expand Down

0 comments on commit 448a203

Please sign in to comment.