From e5e7011eca50a49acd7f8c0ca937ad43faf393e6 Mon Sep 17 00:00:00 2001 From: WinPlay02 Date: Sun, 21 Apr 2024 22:29:58 +0200 Subject: [PATCH] feat: prepare and pool processes (#87) Closes #85 ### Summary of Changes - Use a process pool to keep started processes waiting - The maximum number of pipeline processes is now set to `4`. - Reuse started processes. This should be correct, as the same pipeline process cannot be used by multiple pipelines at the same time. As the `metapath` is reset to remove the custom-generated Safe-DS pipeline code, only global library imports (and settings) should remain. If this is a concern, `max_tasks_per_child` can be set to `1`, in which case pipeline processes are not reused. - Reuse the shared memory location for saving placeholders, if the memoization infrastructure has added such a location to the object being saved --------- Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com> Co-authored-by: Lars Reimann --- src/safeds_runner/server/_pipeline_manager.py | 45 ++++++++++-- .../server/test_websocket_mock.py | 73 +++++++++---------- 2 files changed, 75 insertions(+), 43 deletions(-) diff --git a/src/safeds_runner/server/_pipeline_manager.py b/src/safeds_runner/server/_pipeline_manager.py index 4735cb5..d40584d 100644 --- a/src/safeds_runner/server/_pipeline_manager.py +++ b/src/safeds_runner/server/_pipeline_manager.py @@ -2,6 +2,7 @@ import asyncio import json +import linecache import logging import multiprocessing import os @@ -9,6 +10,7 @@ import runpy import threading import typing +from concurrent.futures import ProcessPoolExecutor from functools import cached_property from multiprocessing.managers import SyncManager from pathlib import Path @@ -17,6 +19,13 @@ import stack_data from ._memoization_map import MemoizationMap +from ._memoization_utils import ( + ExplicitIdentityWrapper, + ExplicitIdentityWrapperLazy, + _has_explicit_identity_memory, + _is_deterministically_hashable, + _is_not_primitive, +) from 
._messages import ( Message, MessageDataProgram, @@ -53,6 +62,10 @@ def _multiprocessing_manager(self) -> SyncManager: def _messages_queue(self) -> queue.Queue[Message]: return self._multiprocessing_manager.Queue() + @cached_property + def _process_pool(self) -> ProcessPoolExecutor: + return ProcessPoolExecutor(max_workers=4, mp_context=multiprocessing.get_context("spawn")) + @cached_property def _messages_queue_thread(self) -> threading.Thread: return threading.Thread(target=self._handle_queue_messages, daemon=True, args=(asyncio.get_event_loop(),)) @@ -75,6 +88,8 @@ def startup(self) -> None: _mq = self._messages_queue # Initialize it here before starting a thread to avoid potential race condition if not self._messages_queue_thread.is_alive(): self._messages_queue_thread.start() + # Ensure that pool is started + _pool = self._process_pool def _handle_queue_messages(self, event_loop: asyncio.AbstractEventLoop) -> None: """ @@ -144,7 +159,7 @@ def execute_pipeline( self._placeholder_map[execution_id], self._memoization_map, ) - process.execute() + process.execute(self._process_pool) def get_placeholder(self, execution_id: str, placeholder_name: str) -> tuple[str | None, Any]: """ @@ -167,6 +182,8 @@ def get_placeholder(self, execution_id: str, placeholder_name: str) -> tuple[str if placeholder_name not in self._placeholder_map[execution_id]: return None, None value = self._placeholder_map[execution_id][placeholder_name] + if isinstance(value, ExplicitIdentityWrapper | ExplicitIdentityWrapperLazy): + value = value.value return _get_placeholder_type(value), value def shutdown(self) -> None: @@ -176,6 +193,7 @@ def shutdown(self) -> None: This should only be called if this PipelineManager is not intended to be reused again. 
""" self._multiprocessing_manager.shutdown() + self._process_pool.shutdown(wait=True, cancel_futures=True) class PipelineProcess: @@ -210,7 +228,6 @@ def __init__( self._messages_queue = messages_queue self._placeholder_map = placeholder_map self._memoization_map = memoization_map - self._process = multiprocessing.Process(target=self._execute, daemon=True) def _send_message(self, message_type: str, value: dict[Any, Any] | str) -> None: self._messages_queue.put(Message(message_type, self._id, value)) @@ -236,8 +253,16 @@ def save_placeholder(self, placeholder_name: str, value: Any) -> None: import torch value = Image(value._image_tensor, torch.device("cpu")) - self._placeholder_map[placeholder_name] = value placeholder_type = _get_placeholder_type(value) + if _is_deterministically_hashable(value) and _has_explicit_identity_memory(value): + value = ExplicitIdentityWrapperLazy.existing(value) + elif ( + not _is_deterministically_hashable(value) + and _is_not_primitive(value) + and _has_explicit_identity_memory(value) + ): + value = ExplicitIdentityWrapper.existing(value) + self._placeholder_map[placeholder_name] = value self._send_message( message_type_placeholder_type, create_placeholder_description(placeholder_name, placeholder_type), @@ -284,15 +309,23 @@ def _execute(self) -> None: except BaseException as error: # noqa: BLE001 self._send_exception(error) finally: + linecache.clearcache() pipeline_finder.detach() - def execute(self) -> None: + def _catch_subprocess_error(self, error: BaseException) -> None: + # This is a callback to log an unexpected failure, executing this is never expected + logging.exception("Pipeline process unexpectedly failed", exc_info=error) # pragma: no cover + + def execute(self, pool: ProcessPoolExecutor) -> None: """ - Execute this pipeline in a newly created process. + Execute this pipeline in a process from the provided process pool. Results, progress and errors are communicated back to the main process. 
""" - self._process.start() + future = pool.submit(self._execute) + exception = future.exception() + if exception is not None: + self._catch_subprocess_error(exception) # pragma: no cover # Pipeline process object visible in child process diff --git a/tests/safeds_runner/server/test_websocket_mock.py b/tests/safeds_runner/server/test_websocket_mock.py index 195b441..afd3adc 100644 --- a/tests/safeds_runner/server/test_websocket_mock.py +++ b/tests/safeds_runner/server/test_websocket_mock.py @@ -2,7 +2,6 @@ import json import logging import multiprocessing -import os import sys import time import typing @@ -142,7 +141,8 @@ ) @pytest.mark.asyncio() async def test_should_fail_message_validation_ws(websocket_message: str) -> None: - test_client = SafeDsServer().app.test_client() + sds_server = SafeDsServer() + test_client = sds_server.app.test_client() async with test_client.websocket("/WSMain") as test_websocket: await test_websocket.send(websocket_message) disconnected = False @@ -151,6 +151,7 @@ async def test_should_fail_message_validation_ws(websocket_message: str) -> None except WebsocketDisconnectError as _disconnect: disconnected = True assert disconnected + sds_server.app_pipeline_manager.shutdown() @pytest.mark.parametrize( @@ -352,13 +353,6 @@ def test_should_fail_message_validation_reason_placeholder_query( assert invalid_message == exception_message -@pytest.mark.skipif( - sys.platform.startswith("win") and os.getenv("COVERAGE_RCFILE") is not None, - reason=( - "skipping multiprocessing tests on windows if coverage is enabled, as pytest " - "causes Manager to hang, when using multiprocessing coverage" - ), -) @pytest.mark.parametrize( argnames="message,expected_response_runtime_error", argvalues=[ @@ -388,7 +382,8 @@ async def test_should_execute_pipeline_return_exception( message: str, expected_response_runtime_error: Message, ) -> None: - test_client = SafeDsServer().app.test_client() + sds_server = SafeDsServer() + test_client = 
sds_server.app.test_client() async with test_client.websocket("/WSMain") as test_websocket: await test_websocket.send(message) received_message = await test_websocket.receive() @@ -404,15 +399,9 @@ async def test_should_execute_pipeline_return_exception( assert isinstance(frame["file"], str) assert "line" in frame assert isinstance(frame["line"], int) + sds_server.app_pipeline_manager.shutdown() -@pytest.mark.skipif( - sys.platform.startswith("win") and os.getenv("COVERAGE_RCFILE") is not None, - reason=( - "skipping multiprocessing tests on windows if coverage is enabled, as pytest " - "causes Manager to hang, when using multiprocessing coverage" - ), -) @pytest.mark.parametrize( argnames="initial_messages,initial_execution_message_wait,appended_messages,expected_responses", argvalues=[ @@ -426,11 +415,15 @@ async def test_should_execute_pipeline_return_exception( "code": { "": { "gen_test_a": ( - "import safeds_runner\nimport base64\nfrom safeds.data.image.containers import Image\n\ndef pipe():\n\tvalue1 =" + "import safeds_runner\nimport base64\nfrom safeds.data.image.containers import Image\nfrom safeds.data.tabular.containers import Table\nimport safeds_runner\nfrom safeds_runner.server._json_encoder import SafeDsEncoder\n\ndef pipe():\n\tvalue1 =" " 1\n\tsafeds_runner.save_placeholder('value1'," " value1)\n\tsafeds_runner.save_placeholder('obj'," " object())\n\tsafeds_runner.save_placeholder('image'," - " Image.from_bytes(base64.b64decode('iVBORw0KGgoAAAANSUhEUgAAAAQAAAAECAYAAACp8Z5+AAAAD0lEQVQIW2NkQAOMpAsAAADuAAVDMQ2mAAAAAElFTkSuQmCC')))\n" + " Image.from_bytes(base64.b64decode('iVBORw0KGgoAAAANSUhEUgAAAAQAAAAECAYAAACp8Z5+AAAAD0lEQVQIW2NkQAOMpAsAAADuAAVDMQ2mAAAAAElFTkSuQmCC')))\n\t" + "table = safeds_runner.memoized_static_call(\"safeds.data.tabular.containers.Table.from_dict\", Table.from_dict, [{'a': [1, 2], 'b': [3, 4]}], [])\n\t" + "safeds_runner.save_placeholder('table',table)\n\t" + 'object_mem = safeds_runner.memoized_static_call("random.object.call", 
SafeDsEncoder, [], [])\n\t' + "safeds_runner.save_placeholder('object_mem',object_mem)\n" ), "gen_test_a_pipe": ( "from gen_test_a import pipe\n\nif __name__ == '__main__':\n\tpipe()" @@ -442,10 +435,12 @@ async def test_should_execute_pipeline_return_exception( }, ), ], - 4, + 6, [ # Query Placeholder json.dumps({"type": "placeholder_query", "id": "abcdefg", "data": {"name": "value1", "window": {}}}), + # Query Placeholder (memoized type) + json.dumps({"type": "placeholder_query", "id": "abcdefg", "data": {"name": "table", "window": {}}}), # Query not displayable Placeholder json.dumps({"type": "placeholder_query", "id": "abcdefg", "data": {"name": "obj", "window": {}}}), # Query invalid placeholder @@ -456,6 +451,12 @@ async def test_should_execute_pipeline_return_exception( Message(message_type_placeholder_type, "abcdefg", create_placeholder_description("value1", "Int")), Message(message_type_placeholder_type, "abcdefg", create_placeholder_description("obj", "object")), Message(message_type_placeholder_type, "abcdefg", create_placeholder_description("image", "Image")), + Message(message_type_placeholder_type, "abcdefg", create_placeholder_description("table", "Table")), + Message( + message_type_placeholder_type, + "abcdefg", + create_placeholder_description("object_mem", "SafeDsEncoder"), + ), # Validate Progress Information Message(message_type_runtime_progress, "abcdefg", create_runtime_progress_done()), # Query Result Valid @@ -464,6 +465,12 @@ async def test_should_execute_pipeline_return_exception( "abcdefg", create_placeholder_value(MessageQueryInformation("value1"), "Int", 1), ), + # Query Result Valid (memoized) + Message( + message_type_placeholder_value, + "abcdefg", + create_placeholder_value(MessageQueryInformation("table"), "Table", {"a": [1, 2], "b": [3, 4]}), + ), # Query Result not displayable Message( message_type_placeholder_value, @@ -489,7 +496,8 @@ async def test_should_execute_pipeline_return_valid_placeholder( expected_responses: 
list[Message], ) -> None: # Initial execution - test_client = SafeDsServer().app.test_client() + sds_server = SafeDsServer() + test_client = sds_server.app.test_client() async with test_client.websocket("/WSMain") as test_websocket: for message in initial_messages: await test_websocket.send(message) @@ -506,15 +514,9 @@ async def test_should_execute_pipeline_return_valid_placeholder( received_message = await test_websocket.receive() next_message = Message.from_dict(json.loads(received_message)) assert next_message == expected_responses.pop(0) + sds_server.app_pipeline_manager.shutdown() -@pytest.mark.skipif( - sys.platform.startswith("win") and os.getenv("COVERAGE_RCFILE") is not None, - reason=( - "skipping multiprocessing tests on windows if coverage is enabled, as pytest " - "causes Manager to hang, when using multiprocessing coverage" - ), -) @pytest.mark.parametrize( argnames="messages,expected_response", argvalues=[ @@ -576,22 +578,17 @@ async def test_should_execute_pipeline_return_valid_placeholder( ) @pytest.mark.asyncio() async def test_should_successfully_execute_simple_flow(messages: list[str], expected_response: Message) -> None: - test_client = SafeDsServer().app.test_client() + sds_server = SafeDsServer() + test_client = sds_server.app.test_client() async with test_client.websocket("/WSMain") as test_websocket: for message in messages: await test_websocket.send(message) received_message = await test_websocket.receive() query_result_invalid = Message.from_dict(json.loads(received_message)) assert query_result_invalid == expected_response + sds_server.app_pipeline_manager.shutdown() -@pytest.mark.skipif( - sys.platform.startswith("win") and os.getenv("COVERAGE_RCFILE") is not None, - reason=( - "skipping multiprocessing tests on windows if coverage is enabled, as pytest " - "causes Manager to hang, when using multiprocessing coverage" - ), -) @pytest.mark.parametrize( argnames="messages", argvalues=[ @@ -613,10 +610,12 @@ def 
helper_should_shut_itself_down_run_in_subprocess(sub_messages: list[str]) -> async def helper_should_shut_itself_down_run_in_subprocess_async(sub_messages: list[str]) -> None: - test_client = SafeDsServer().app.test_client() + sds_server = SafeDsServer() + test_client = sds_server.app.test_client() async with test_client.websocket("/WSMain") as test_websocket: for message in sub_messages: await test_websocket.send(message) + sds_server.app_pipeline_manager.shutdown() @pytest.mark.timeout(45)