From 8cf0b5702eedec1cadd2225e6665f4cdcb69b6f8 Mon Sep 17 00:00:00 2001 From: WinPlay02 Date: Tue, 2 Apr 2024 20:43:24 +0200 Subject: [PATCH] fix: sending images to the vscode extension fails, if the tensor is not local to the cpu (#63) Related-to: https://github.com/Safe-DS/DSL/pull/954 (the PR requires this fix to work) Fixes the following error: `RuntimeError: Attempted to send CUDA tensor received from another process; this is not currently supported. Consider cloning before sending.` This error only occurs on systems that support graphics acceleration of tensor operations. On other systems, the tensor is local to the CPU by default and does not need to be converted. The conversion process is a noop, if the tensor is already local to the CPU. --------- Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com> --- src/safeds_runner/server/_pipeline_manager.py | 6 + .../server/test_websocket_mock.py | 383 ++++++++++-------- 2 files changed, 223 insertions(+), 166 deletions(-) diff --git a/src/safeds_runner/server/_pipeline_manager.py b/src/safeds_runner/server/_pipeline_manager.py index 8353521..02584c2 100644 --- a/src/safeds_runner/server/_pipeline_manager.py +++ b/src/safeds_runner/server/_pipeline_manager.py @@ -229,6 +229,12 @@ def save_placeholder(self, placeholder_name: str, value: Any) -> None: value : Any Actual value of the placeholder. """ + from safeds.data.image.containers import Image + + if isinstance(value, Image): + import torch + + value = Image(value._image_tensor, torch.device("cpu")) self._placeholder_map[placeholder_name] = value placeholder_type = _get_placeholder_type(value) self._send_message( diff --git a/tests/safeds_runner/server/test_websocket_mock.py b/tests/safeds_runner/server/test_websocket_mock.py index e407624..c30168b 100644 --- a/tests/safeds_runner/server/test_websocket_mock.py +++ b/tests/safeds_runner/server/test_websocket_mock.py @@ -45,58 +45,76 @@ json.dumps({"type": "placeholder_query", "id": "123", "data": {"a": "v"}}), json.dumps({"type": "placeholder_query", "id": "123", "data": {"name": "v", "window": {"begin": "a"}}}), json.dumps({"type": "placeholder_query", "id": "123", "data": {"name": "v", "window": {"size": "a"}}}), - json.dumps({ - "type": "program", - "id": "1234", - "data": {"main": {"modulepath": "1", "module": "2", "pipeline": "3"}}, - }), + json.dumps( + { + "type": "program", + "id": "1234", + "data": {"main": {"modulepath": "1", "module": "2", "pipeline": "3"}}, + }, + ), json.dumps({"type": "program", "id": "1234", "data": {"code": {"": {"entry": ""}}}}), - json.dumps({ - "type": "program", - "id": "1234", - "data": {"code": {"": {"entry": ""}}, "main": {"modulepath": "1", "module": "2"}}, - }), - json.dumps({ - "type": "program", - "id": "1234", - "data": {"code": {"": {"entry": ""}}, "main": {"modulepath": "1", "pipeline": "3"}}, - }), - json.dumps({ - "type": "program", - "id": "1234", - "data": {"code": {"": {"entry": ""}}, "main": {"module": "2", "pipeline": "3"}}, - }), - json.dumps({ - "type": "program", - "id": "1234", - "data": { - "code": {"": {"entry": ""}}, - "main": {"modulepath": "1", "module": "2", "pipeline": "3", "other": "4"}, + json.dumps( + { + "type": "program", + "id": "1234", + "data": {"code": {"": {"entry": ""}}, "main": {"modulepath": "1", "module": "2"}}, + }, + ), + json.dumps( + { + "type": "program", + "id": "1234", + "data": {"code": {"": {"entry": ""}}, "main": {"modulepath": "1", "pipeline": "3"}}, + }, + ), + json.dumps( + { + "type": "program", + "id": "1234", + "data": {"code": {"": {"entry": ""}}, "main": {"module": "2", "pipeline": "3"}}, + }, + ), + json.dumps( + { + "type": "program", + "id": "1234", + "data": { + "code": {"": {"entry": ""}}, + "main": {"modulepath": "1", "module": "2", "pipeline": "3", "other": "4"}, + }, + }, + ), + json.dumps( + { + "type": "program", + "id": "1234", + "data": { + "code": {"": {"entry": ""}}, + "main": {"modulepath": "1", "module": "2", "pipeline": "3", "other": {"4": "a"}}, + }, + }, + ), + json.dumps( + { + "type": "program", + "id": "1234", + "data": {"code": "a", "main": {"modulepath": "1", "module": "2", "pipeline": "3"}}, }, - }), - json.dumps({ - "type": "program", - "id": "1234", - "data": { - "code": {"": {"entry": ""}}, - "main": {"modulepath": "1", "module": "2", "pipeline": "3", "other": {"4": "a"}}, + ), + json.dumps( + { + "type": "program", + "id": "1234", + "data": {"code": {"": "a"}, "main": {"modulepath": "1", "module": "2", "pipeline": "3"}}, + }, + ), + json.dumps( + { + "type": "program", + "id": "1234", + "data": {"code": {"": {"a": {"b": "c"}}}, "main": {"modulepath": "1", "module": "2", "pipeline": "3"}}, }, - }), - json.dumps({ - "type": "program", - "id": "1234", - "data": {"code": "a", "main": {"modulepath": "1", "module": "2", "pipeline": "3"}}, - }), - json.dumps({ - "type": "program", - "id": "1234", - "data": {"code": {"": "a"}, "main": {"modulepath": "1", "module": "2", "pipeline": "3"}}, - }), - json.dumps({ - "type": "program", - "id": "1234", - "data": {"code": {"": {"a": {"b": "c"}}}, "main": {"modulepath": "1", "module": "2", "pipeline": "3"}}, - }), + ), ], ids=[ "no_json", @@ -164,11 +182,13 @@ def test_should_fail_message_validation_reason_general(websocket_message: str, e argvalues=[ (json.dumps({"type": "program", "id": "1234", "data": "a"}), "Message data is not a JSON object"), ( - json.dumps({ - "type": "program", - "id": "1234", - "data": {"main": {"modulepath": "1", "module": "2", "pipeline": "3"}}, - }), + json.dumps( + { + "type": "program", + "id": "1234", + "data": {"main": {"modulepath": "1", "module": "2", "pipeline": "3"}}, + }, + ), "No 'code' parameter given", ), ( @@ -176,73 +196,92 @@ def test_should_fail_message_validation_reason_general(websocket_message: str, e "No 'main' parameter given", ), ( - json.dumps({ - "type": "program", - "id": "1234", - "data": {"code": {"": {"entry": ""}}, "main": {"modulepath": "1", "module": "2"}}, - }), + json.dumps( + { + "type": "program", + "id": "1234", + "data": {"code": {"": {"entry": ""}}, "main": {"modulepath": "1", "module": "2"}}, + }, + ), "Invalid 'main' parameter given", ), ( - json.dumps({ - "type": "program", - "id": "1234", - "data": {"code": {"": {"entry": ""}}, "main": {"modulepath": "1", "pipeline": "3"}}, - }), + json.dumps( + { + "type": "program", + "id": "1234", + "data": {"code": {"": {"entry": ""}}, "main": {"modulepath": "1", "pipeline": "3"}}, + }, + ), "Invalid 'main' parameter given", ), ( - json.dumps({ - "type": "program", - "id": "1234", - "data": {"code": {"": {"entry": ""}}, "main": {"module": "2", "pipeline": "3"}}, - }), + json.dumps( + { + "type": "program", + "id": "1234", + "data": {"code": {"": {"entry": ""}}, "main": {"module": "2", "pipeline": "3"}}, + }, + ), "Invalid 'main' parameter given", ), ( - json.dumps({ - "type": "program", - "id": "1234", - "data": { - "code": {"": {"entry": ""}}, - "main": {"modulepath": "1", "module": "2", "pipeline": "3", "other": "4"}, + json.dumps( + { + "type": "program", + "id": "1234", + "data": { + "code": {"": {"entry": ""}}, + "main": {"modulepath": "1", "module": "2", "pipeline": "3", "other": "4"}, + }, }, - }), + ), "Invalid 'main' parameter given", ), ( - json.dumps({ - "type": "program", - "id": "1234", - "data": { - "code": {"": {"entry": ""}}, - "main": {"modulepath": "1", "module": "2", "pipeline": "3", "other": {"4": "a"}}, + json.dumps( + { + "type": "program", + "id": "1234", + "data": { + "code": {"": {"entry": ""}}, + "main": {"modulepath": "1", "module": "2", "pipeline": "3", "other": {"4": "a"}}, + }, }, - }), + ), "Invalid 'main' parameter given", ), ( - json.dumps({ - "type": "program", - "id": "1234", - "data": {"code": "a", "main": {"modulepath": "1", "module": "2", "pipeline": "3"}}, - }), + json.dumps( + { + "type": "program", + "id": "1234", + "data": {"code": "a", "main": {"modulepath": "1", "module": "2", "pipeline": "3"}}, + }, + ), "Invalid 'code' parameter given", ), ( - json.dumps({ - "type": "program", - "id": "1234", - "data": {"code": {"": "a"}, "main": {"modulepath": "1", "module": "2", "pipeline": "3"}}, - }), + json.dumps( + { + "type": "program", + "id": "1234", + "data": {"code": {"": "a"}, "main": {"modulepath": "1", "module": "2", "pipeline": "3"}}, + }, + ), "Invalid 'code' parameter given", ), ( - json.dumps({ - "type": "program", - "id": "1234", - "data": {"code": {"": {"a": {"b": "c"}}}, "main": {"modulepath": "1", "module": "2", "pipeline": "3"}}, - }), + json.dumps( + { + "type": "program", + "id": "1234", + "data": { + "code": {"": {"a": {"b": "c"}}}, + "main": {"modulepath": "1", "module": "2", "pipeline": "3"}, + }, + }, + ), "Invalid 'code' parameter given", ), ], @@ -309,19 +348,21 @@ def test_should_fail_message_validation_reason_placeholder_query( argnames="message,expected_response_runtime_error", argvalues=[ ( - json.dumps({ - "type": "program", - "id": "abcdefgh", - "data": { - "code": { - "": { - "gen_test_a": "def pipe():\n\traise Exception('Test Exception')\n", - "gen_test_a_pipe": "from gen_test_a import pipe\n\nif __name__ == '__main__':\n\tpipe()", + json.dumps( + { + "type": "program", + "id": "abcdefgh", + "data": { + "code": { + "": { + "gen_test_a": "def pipe():\n\traise Exception('Test Exception')\n", + "gen_test_a_pipe": "from gen_test_a import pipe\n\nif __name__ == '__main__':\n\tpipe()", + }, }, + "main": {"modulepath": "", "module": "test_a", "pipeline": "pipe"}, }, - "main": {"modulepath": "", "module": "test_a", "pipeline": "pipe"}, }, - }), + ), Message(message_type_runtime_error, "abcdefgh", {"message": "Test Exception"}), ), ], @@ -362,28 +403,31 @@ async def test_should_execute_pipeline_return_exception( argvalues=[ ( [ - json.dumps({ - "type": "program", - "id": "abcdefg", - "data": { - "code": { - "": { - "gen_test_a": ( - "import safeds_runner\n\ndef pipe():\n\tvalue1 =" - " 1\n\tsafeds_runner.save_placeholder('value1'," - " value1)\n\tsafeds_runner.save_placeholder('obj'," - " object())\n" - ), - "gen_test_a_pipe": ( - "from gen_test_a import pipe\n\nif __name__ == '__main__':\n\tpipe()" - ), + json.dumps( + { + "type": "program", + "id": "abcdefg", + "data": { + "code": { + "": { + "gen_test_a": ( + "import safeds_runner\nimport base64\nfrom safeds.data.image.containers import Image\n\ndef pipe():\n\tvalue1 =" + " 1\n\tsafeds_runner.save_placeholder('value1'," + " value1)\n\tsafeds_runner.save_placeholder('obj'," + " object())\n\tsafeds_runner.save_placeholder('image'," + " Image.from_bytes(base64.b64decode('iVBORw0KGgoAAAANSUhEUgAAAAQAAAAECAYAAACp8Z5+AAAAD0lEQVQIW2NkQAOMpAsAAADuAAVDMQ2mAAAAAElFTkSuQmCC')))\n" + ), + "gen_test_a_pipe": ( + "from gen_test_a import pipe\n\nif __name__ == '__main__':\n\tpipe()" + ), + }, }, + "main": {"modulepath": "", "module": "test_a", "pipeline": "pipe"}, }, - "main": {"modulepath": "", "module": "test_a", "pipeline": "pipe"}, }, - }), + ), ], - 3, + 4, [ # Query Placeholder json.dumps({"type": "placeholder_query", "id": "abcdefg", "data": {"name": "value1", "window": {}}}), @@ -396,6 +440,7 @@ async def test_should_execute_pipeline_return_exception( # Validate Placeholder Information Message(message_type_placeholder_type, "abcdefg", create_placeholder_description("value1", "Int")), Message(message_type_placeholder_type, "abcdefg", create_placeholder_description("obj", "object")), + Message(message_type_placeholder_type, "abcdefg", create_placeholder_description("image", "Image")), # Validate Progress Information Message(message_type_runtime_progress, "abcdefg", create_runtime_progress_done()), # Query Result Valid @@ -460,34 +505,36 @@ async def test_should_execute_pipeline_return_valid_placeholder( argvalues=[ ( [ - json.dumps({ - "type": "program", - "id": "123456789", - "data": { - "code": { - "": { - "gen_b": ( - "from a.stub import u\n" - "from v.u.s.testing import add1\n" - "\n" - "def c():\n" - "\ta1 = 1\n" - "\ta2 = True or False\n" - "\tprint('test2')\n" - "\tprint('new dynamic output')\n" - "\tprint(f'Add1: {add1(1, 2)}')\n" - "\treturn a1 + a2\n" - ), - "gen_b_c": "from gen_b import c\n\nif __name__ == '__main__':\n\tc()", - }, - "a": {"stub": "def u():\n\treturn 1"}, - "v.u.s": { - "testing": "import a.stub;\n\ndef add1(v1, v2):\n\treturn v1 + v2 + a.stub.u()\n", + json.dumps( + { + "type": "program", + "id": "123456789", + "data": { + "code": { + "": { + "gen_b": ( + "from a.stub import u\n" + "from v.u.s.testing import add1\n" + "\n" + "def c():\n" + "\ta1 = 1\n" + "\ta2 = True or False\n" + "\tprint('test2')\n" + "\tprint('new dynamic output')\n" + "\tprint(f'Add1: {add1(1, 2)}')\n" + "\treturn a1 + a2\n" + ), + "gen_b_c": "from gen_b import c\n\nif __name__ == '__main__':\n\tc()", + }, + "a": {"stub": "def u():\n\treturn 1"}, + "v.u.s": { + "testing": "import a.stub;\n\ndef add1(v1, v2):\n\treturn v1 + v2 + a.stub.u()\n", + }, }, + "main": {"modulepath": "", "module": "b", "pipeline": "c"}, }, - "main": {"modulepath": "", "module": "b", "pipeline": "c"}, }, - }), + ), ], Message(message_type_runtime_progress, "123456789", create_runtime_progress_done()), ), @@ -495,11 +542,13 @@ async def test_should_execute_pipeline_return_valid_placeholder( # Query Result Invalid (no pipeline exists) [ json.dumps({"type": "invalid_message_type", "id": "unknown-code-id-never-generated", "data": ""}), - json.dumps({ - "type": "placeholder_query", - "id": "unknown-code-id-never-generated", - "data": {"name": "v", "window": {}}, - }), + json.dumps( + { + "type": "placeholder_query", + "id": "unknown-code-id-never-generated", + "data": {"name": "v", "window": {}}, + }, + ), ], Message( message_type_placeholder_value, @@ -693,19 +742,21 @@ def test_windowed_placeholder(query: MessageQueryInformation, type_: str, value: argnames="query,expected_response", argvalues=[ ( - json.dumps({ - "type": "program", - "id": "abcdefgh", - "data": { - "code": { - "": { - "gen_test_a": "def pipe():\n\tpass\n", - "gen_test_a_pipe": "from gen_test_a import pipe\n\nif __name__ == '__main__':\n\tpipe()", + json.dumps( + { + "type": "program", + "id": "abcdefgh", + "data": { + "code": { + "": { + "gen_test_a": "def pipe():\n\tpass\n", + "gen_test_a_pipe": "from gen_test_a import pipe\n\nif __name__ == '__main__':\n\tpipe()", + }, }, + "main": {"modulepath": "", "module": "test_a", "pipeline": "pipe"}, }, - "main": {"modulepath": "", "module": "test_a", "pipeline": "pipe"}, }, - }), + ), Message(message_type_runtime_progress, "abcdefgh", "done"), ), ],