diff --git a/src/async_solipsism/loop.py b/src/async_solipsism/loop.py index 567bdf6..13174f7 100644 --- a/src/async_solipsism/loop.py +++ b/src/async_solipsism/loop.py @@ -24,7 +24,7 @@ from .exceptions import SolipsismError -__all__ = ('EventLoop', 'stream_pairs') +__all__ = ('EventLoop', 'EventLoopPolicy', 'stream_pairs') class EventLoop(asyncio.selector_events.BaseSelectorEventLoop): @@ -197,6 +197,11 @@ def _stop_serving(self, sock): super()._stop_serving(sock) +class EventLoopPolicy(asyncio.DefaultEventLoopPolicy): + def new_event_loop(self) -> EventLoop: + return EventLoop() + + async def stream_pairs(capacity=None): sock1, sock2 = _socket.socketpair(capacity=capacity) streams1 = await asyncio.open_connection(sock=sock1) diff --git a/test/test_loop.py b/test/test_loop.py index cdcf7d8..8ab82a0 100644 --- a/test/test_loop.py +++ b/test/test_loop.py @@ -25,13 +25,12 @@ @pytest.fixture -def event_loop(): - loop = async_solipsism.EventLoop() - yield loop - loop.close() +def event_loop_policy(): + return async_solipsism.EventLoopPolicy() -async def test_sleep(event_loop): +async def test_sleep(): + event_loop = asyncio.get_running_loop() assert event_loop.time() == 0.0 await asyncio.sleep(2) assert event_loop.time() == 2.0 @@ -53,7 +52,9 @@ async def zzz(): ] ) @pytest.mark.parametrize('delay', [False, True]) -async def test_delayed_sock_recv(method, delay, event_loop): +async def test_delayed_sock_recv(method, delay): + event_loop = asyncio.get_running_loop() + async def delayed_write(wsock): await asyncio.sleep(1) wsock.send(b'Hello') @@ -76,7 +77,9 @@ async def delayed_write(wsock): @pytest.mark.parametrize('size', [10, 10**7]) -async def test_sock_sendall(size, event_loop): +async def test_sock_sendall(size): + event_loop = asyncio.get_running_loop() + async def delayed_read(rsock): n = 0 while True: @@ -99,7 +102,8 @@ async def delayed_read(rsock): assert n == size -async def test_connect_existing(event_loop, mocker): +async def test_connect_existing(mocker): + event_loop = asyncio.get_running_loop() sock1, sock2 = async_solipsism.socketpair() transport1, protocol1 = await event_loop.connect_accepted_socket( mocker.MagicMock, sock1) @@ -128,7 +132,9 @@ async def test_stream(): @pytest.mark.parametrize('manual_socket', [False, True]) -async def test_server(event_loop, manual_socket): +async def test_server(manual_socket): + event_loop = asyncio.get_running_loop() + def callback(reader, writer): server_conn.set_result((reader, writer)) @@ -156,7 +162,7 @@ def callback(reader, writer): await server.wait_closed() -async def test_unused_port(event_loop): +async def test_unused_port(): def callback(reader, writer): pass @@ -172,7 +178,8 @@ def callback(reader, writer): await server2.wait_closed() -async def test_close_server(event_loop): +async def test_close_server(): + event_loop = asyncio.get_running_loop() server = await asyncio.start_server(lambda reader, writer: None, 'test.invalid', 1234) server.close() await server.wait_closed() @@ -180,18 +187,21 @@ async def test_close_server(event_loop): await event_loop.create_connection('test.invalid', 1234) -async def test_create_connection_no_listener(event_loop): +async def test_create_connection_no_listener(): + event_loop = asyncio.get_running_loop() with pytest.raises(ConnectionRefusedError): await event_loop.create_connection('test.invalid', 1234) -async def test_run_in_executor_implicit(event_loop): +async def test_run_in_executor_implicit(): + event_loop = asyncio.get_running_loop() thread_id = await event_loop.run_in_executor(None, threading.get_ident) assert isinstance(thread_id, int) assert thread_id != threading.get_ident() -async def test_run_in_executor_explicit(event_loop): +async def test_run_in_executor_explicit(): + event_loop = asyncio.get_running_loop() my_executor = concurrent.futures.ThreadPoolExecutor(1) expected_thread_id = my_executor.submit(threading.get_ident).result() assert isinstance(expected_thread_id, int) @@ -201,7 +211,8 @@ async def test_run_in_executor_explicit(event_loop): @pytest.mark.skipif("sys.version_info < (3, 7)") -async def test_sendfile(event_loop, tmp_path): +async def test_sendfile(tmp_path): + event_loop = asyncio.get_running_loop() tmp_file = tmp_path / 'test_sendfile.txt' tmp_file.write_bytes(b'Hello world\n') ((reader1, writer1), (reader2, writer2)) = await async_solipsism.stream_pairs() @@ -215,14 +226,17 @@ async def test_sendfile(event_loop, tmp_path): await writer2.wait_closed() -async def test_call_soon_threadsafe(event_loop): +async def test_call_soon_threadsafe(): + event_loop = asyncio.get_running_loop() future = event_loop.create_future() event_loop.call_soon_threadsafe(future.set_result, 3) result = await future assert result == 3 -async def test_call_soon_threadsafe_wrong_thread(event_loop): +async def test_call_soon_threadsafe_wrong_thread(): + event_loop = asyncio.get_running_loop() + def thread_func(): event_loop.call_soon_threadsafe(lambda: None)