diff --git a/pyproject.toml b/pyproject.toml index 93f709e..7d127f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pfun" -version = "0.12.1" +version = "0.12.2" description = "" authors = ["Sune Debel "] readme = "README.md" diff --git a/src/pfun/effect.pyx b/src/pfun/effect.pyx index 9de1f34..185628e 100644 --- a/src/pfun/effect.pyx +++ b/src/pfun/effect.pyx @@ -27,18 +27,25 @@ cdef class RuntimeEnv: """ cdef object r cdef object exit_stack - cdef object process_executor - cdef object thread_executor + cdef object max_processes + cdef object max_threads + cdef readonly object process_executor + cdef readonly object thread_executor - def __cinit__(self, r, exit_stack, process_executor, thread_executor): + def __cinit__(self, r, exit_stack, max_processes, max_threads): self.r = r self.exit_stack = exit_stack - self.process_executor = process_executor - self.thread_executor = thread_executor + self.max_processes = max_processes + self.max_threads = max_threads + self.process_executor = None + self.thread_executor = None async def run_in_process_executor(self, f, *args, **kwargs): loop = asyncio.get_running_loop() payload = dill.dumps((f, args, kwargs)) + if self.process_executor is None: + self.process_executor = ProcessPoolExecutor(max_workers=self.max_processes) + self.exit_stack.enter_context(self.process_executor) return dill.loads( await loop.run_in_executor( self.process_executor, run_dill_encoded, payload @@ -47,6 +54,9 @@ cdef class RuntimeEnv: async def run_in_thread_executor(self, f, *args, **kwargs): loop = asyncio.get_running_loop() + if self.thread_executor is None: + self.thread_executor = ThreadPoolExecutor(max_workers=self.max_threads) + self.exit_stack.enter_context(self.thread_executor) return await loop.run_in_executor( self.thread_executor, lambda: f(*args, **kwargs) ) @@ -100,12 +110,8 @@ cdef class CEffect: Exception """ stack = AsyncExitStack() - process_executor = ProcessPoolExecutor(max_workers=max_processes) - thread_executor = ThreadPoolExecutor(max_workers=max_threads) async with stack: - stack.enter_context(process_executor) - stack.enter_context(thread_executor) - env = RuntimeEnv(r, stack, process_executor, thread_executor) + env = RuntimeEnv(r, stack, max_processes, max_threads) effect = await self.do(env) if isinstance(effect, CSuccess): return effect.result diff --git a/tests/test_effect.py b/tests/test_effect.py index e1413e8..2544d0f 100644 --- a/tests/test_effect.py +++ b/tests/test_effect.py @@ -1,3 +1,4 @@ +from contextlib import ExitStack from subprocess import CalledProcessError from unittest import mock @@ -236,6 +237,15 @@ def test_catch_cpu_bound(self, f): def test_catch_io_bound(self, f): assert effect.catch_io_bound(Exception)(f)(None).run(None) == f(None) + @pytest.mark.asyncio + @given(anything()) + async def test_process_and_thread_pool_initialized_lazily(self, value): + with ExitStack() as stack: + env = effect.RuntimeEnv(None, stack, 1, 1) + await effect.success(value).do(env) + assert env.process_executor is None + assert env.thread_executor is None + def test_success_repr(self): assert repr(effect.success('value')) == 'success(\'value\')'