feat: Memoization #38

Merged
merged 12 commits on Jan 24, 2024
99 changes: 99 additions & 0 deletions src/safeds_runner/server/pipeline_manager.py
@@ -6,8 +6,10 @@
import queue
import runpy
import threading
import typing
from functools import cached_property
from multiprocessing.managers import SyncManager
from pathlib import Path
from typing import Any

import simple_websocket
@@ -25,6 +27,8 @@
)
from safeds_runner.server.module_manager import InMemoryFinder

MemoizationMap: typing.TypeAlias = dict[tuple[str, tuple[Any, ...], tuple[Any, ...]], Any]
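# Illustration only (the function name and values are hypothetical): an entry is keyed by the fully
# qualified function name plus hashable tuples of the visible and hidden call parameters, e.g.
# ("my_package.my_module.load_table", ("train.csv",), (1706086455123456789,)) -> <loaded table>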


class PipelineManager:
"""
@@ -56,6 +60,10 @@ def _messages_queue_thread(self) -> threading.Thread:
daemon=True,
)

@cached_property
def _memoization_map(self) -> MemoizationMap:
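# A SyncManager dict proxy is used so the map can be shared with the pipeline processes and
# memoized results persist across pipeline runs within this runner instance.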
return self._multiprocessing_manager.dict() # type: ignore[return-value]

def startup(self) -> None:
"""
Prepare the runner for running Safe-DS pipelines.
@@ -132,6 +140,7 @@ def execute_pipeline(
execution_id,
self._messages_queue,
self._placeholder_map[execution_id],
self._memoization_map,
)
process.execute()

@@ -176,6 +185,7 @@ def __init__(
execution_id: str,
messages_queue: queue.Queue[Message],
placeholder_map: dict[str, Any],
memoization_map: MemoizationMap,
):
"""
Create a new process which will execute the given pipeline, when started.
@@ -190,11 +200,14 @@ def __init__(
A queue to write outgoing messages to.
placeholder_map : dict[str, Any]
A map to save calculated placeholders in.
memoization_map : MemoizationMap
A map in which results of memoized function calls are stored.
"""
self._pipeline = pipeline
self._id = execution_id
self._messages_queue = messages_queue
self._placeholder_map = placeholder_map
self._memoization_map = memoization_map
self._process = multiprocessing.Process(target=self._execute, daemon=True)

def _send_message(self, message_type: str, value: dict[Any, Any] | str) -> None:
@@ -222,6 +235,17 @@ def save_placeholder(self, placeholder_name: str, value: Any) -> None:
create_placeholder_description(placeholder_name, placeholder_type),
)

def get_memoization_map(self) -> MemoizationMap:
"""
Get the shared memoization map.

Returns
-------
MemoizationMap
The memoization map that stores results of memoized function calls
"""
return self._memoization_map

def _execute(self) -> None:
logging.info(
"Executing %s.%s.%s...",
@@ -278,6 +302,81 @@ def runner_save_placeholder(placeholder_name: str, value: Any) -> None:
current_pipeline.save_placeholder(placeholder_name, value)


def runner_memoized_function_call(
function_name: str,
function_callable: typing.Callable,
parameters: list[Any],
hidden_parameters: list[Any],
) -> Any:
"""
Call a function that can be memoized and save the result.

If the function has already been called with the same parameters (including hidden parameters), the stored result is reused instead of calling the function again.

Parameters
----------
function_name : str
Fully qualified function name
function_callable : typing.Callable
Function that is called if its result is not yet present in the memoization map; the result is then stored
parameters : list[Any]
List of parameters for the function
hidden_parameters : list[Any]
List of hidden parameters for the function (for example a file's last modification time), used so that some impure functions can still be memoized

Returns
-------
Any
The result of the function call, either taken from the memoization map or computed and then stored
"""
if current_pipeline is None:
return None # pragma: no cover
memoization_map = current_pipeline.get_memoization_map()
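# Convert the (possibly nested) parameter lists to tuples so the lookup key is hashable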
key = (function_name, _convert_list_to_tuple(parameters), _convert_list_to_tuple(hidden_parameters))
if key in memoization_map:
return memoization_map[key]
result = function_callable(*parameters)
memoization_map[key] = result
return result
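
As a usage sketch only (not part of this diff; the Safe-DS Table API and the CSV file name are assumptions), generated pipeline code could wrap an impure file read like this, passing the file's modification time as a hidden parameter so the cached result is invalidated when the file changes:

from safeds.data.tabular.containers import Table  # assumed Safe-DS import, illustration only

table = runner_memoized_function_call(
    "safeds.data.tabular.containers.Table.from_csv_file",
    Table.from_csv_file,
    ["titanic.csv"],                     # visible parameters: part of the memoization key
    [runner_filemtime("titanic.csv")],   # hidden parameter: changes when the file changes
)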


def _convert_list_to_tuple(values: list) -> tuple:
"""
Recursively convert a mutable list of values to an immutable tuple containing the same values, to make the values hashable.

Parameters
----------
values : list
Values that should be converted to a tuple

Returns
-------
tuple
A tuple containing all elements of the provided list, with nested lists converted recursively
"""
return tuple(_convert_list_to_tuple(value) if isinstance(value, list) else value for value in values)
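
A quick illustration of the conversion, with arbitrary values:

_convert_list_to_tuple([1, [2, [3, 4]], "a"])  # returns (1, (2, (3, 4)), "a")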


def runner_filemtime(filename: str) -> int | None:
"""
Get the last modification timestamp of the provided file.

Parameters
----------
filename : str
Name of the file

Returns
-------
int | None
Last modification timestamp in nanoseconds if the provided file exists, otherwise None
"""
try:
return Path(filename).stat().st_mtime_ns
except FileNotFoundError:
return None
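
For example (the paths and the returned timestamp are made up), the nanosecond timestamp makes this helper suitable as a hidden parameter when memoizing file-reading functions:

runner_filemtime("data/train.csv")    # e.g. 1706086455123456789 if the file exists
runner_filemtime("missing_file.csv")  # None, because the file does not exist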


def get_backtrace_info(error: BaseException) -> list[dict[str, Any]]:
"""
Create a simplified backtrace from an exception.
83 changes: 83 additions & 0 deletions tests/safeds_runner/server/test_memoization.py
@@ -0,0 +1,83 @@
import tempfile
import typing
from datetime import UTC, datetime
from queue import Queue
from typing import Any

import pytest
from safeds_runner.server import pipeline_manager
from safeds_runner.server.messages import MessageDataProgram, ProgramMainInformation
from safeds_runner.server.pipeline_manager import PipelineProcess


@pytest.mark.parametrize(
argnames="function_name,params,hidden_params,expected_result",
argvalues=[
("function_pure", [1, 2, 3], [], "abc"),
("function_impure_readfile", ["filea.txt"], [1234567891], "abc"),
],
ids=["function_pure", "function_impure_readfile"],
)
def test_memoization_already_present_values(
function_name: str,
params: list,
hidden_params: list,
expected_result: Any,
) -> None:
pipeline_manager.current_pipeline = PipelineProcess(
MessageDataProgram({}, ProgramMainInformation("", "", "")),
"",
Queue(),
{},
{},
)
pipeline_manager.current_pipeline.get_memoization_map()[
(
function_name,
pipeline_manager._convert_list_to_tuple(params),
pipeline_manager._convert_list_to_tuple(hidden_params),
)
] = expected_result
result = pipeline_manager.runner_memoized_function_call(function_name, lambda *_: None, params, hidden_params)
assert result == expected_result


@pytest.mark.parametrize(
argnames="function_name,function,params,hidden_params,expected_result",
argvalues=[
("function_pure", lambda a, b, c: a + b + c, [1, 2, 3], [], 6),
("function_impure_readfile", lambda filename: filename.split(".")[0], ["abc.txt"], [1234567891], "abc"),
],
ids=["function_pure", "function_impure_readfile"],
)
def test_memoization_not_present_values(
function_name: str,
function: typing.Callable,
params: list,
hidden_params: list,
expected_result: Any,
) -> None:
pipeline_manager.current_pipeline = PipelineProcess(
MessageDataProgram({}, ProgramMainInformation("", "", "")),
"",
Queue(),
{},
{},
)
# Save value in map
result = pipeline_manager.runner_memoized_function_call(function_name, function, params, hidden_params)
assert result == expected_result
# Test if value is actually saved by calling another function that does not return the expected result
result2 = pipeline_manager.runner_memoized_function_call(function_name, lambda *_: None, params, hidden_params)
assert result2 == expected_result


def test_file_mtime_exists() -> None:
with tempfile.NamedTemporaryFile() as file:
file_mtime = pipeline_manager.runner_filemtime(file.name)
assert file_mtime is not None


def test_file_mtime_not_exists() -> None:
file_mtime = pipeline_manager.runner_filemtime(f"file_not_exists.{datetime.now(tz=UTC).timestamp()}")
assert file_mtime is None