From a30a990295cfbb35ca4736d9f1c02b2598b17cce Mon Sep 17 00:00:00 2001
From: Alistair Buxton
Date: Tue, 19 Nov 2019 19:56:55 +0000
Subject: [PATCH] Refactor mp pool

There is now a pool object which can be re-used. It also ensures the
worker processes start up correctly, preventing deadlocks where the
target function never even runs.

More can be done in this area.

Fixes #33.
---
 teletext/mp.py            | 126 ++++++++++++++++++++++++++------------
 teletext/tests/test_mp.py |   2 +-
 2 files changed, 87 insertions(+), 41 deletions(-)

diff --git a/teletext/mp.py b/teletext/mp.py
index 9c41735..1bb0540 100644
--- a/teletext/mp.py
+++ b/teletext/mp.py
@@ -1,12 +1,12 @@
 import itertools
 import queue
 import signal
-import time
 
 import multiprocessing as mp
 
 from .sigint import SigIntDefer
 
+
 def denumerate(quit_event, work_queue, tmp_queue):
     """
     Strips sequence numbers from work_queue items and yields the work.
@@ -32,23 +32,29 @@ def renumerate(iterator, done_queue, tmp_queue):
         done_queue.put((n, item))
 
 
-def slave(function, quit_event, work_queue, done_queue, args, kwargs):
-    """The main function for subprocesses. """
+def worker(function, started_event, stopped_event, quit_event, work_queue, done_queue, args, kwargs):
+    """
+    The main function for subprocesses.
+    """
+    try:
+        signal.signal(signal.SIGINT, signal.SIG_IGN)
+        tmp_queue = mp.Queue() # holds work item numbers to be recombined with the result
+        started_event.set()
+        renumerate(function(denumerate(quit_event, work_queue, tmp_queue), *args, **kwargs), done_queue, tmp_queue)
+    finally:
+        stopped_event.set()
 
-    signal.signal(signal.SIGINT, signal.SIG_IGN)
-    tmp_queue = mp.Queue() # holds work item numbers to be recombined with the result
-    renumerate(function(denumerate(quit_event, work_queue, tmp_queue), *args, **kwargs), done_queue, tmp_queue)
 
+class PureGeneratorPool(object):
 
-def itermap(function, iterator, processes=1, *args, **kwargs):
     """
-    Implements a multiprocessing pool similar in function to multiprocessing.Pool.
-    However, Pool.map(f, i) calls f on every item in i individually. f is expected
-    to return the result. itermap(f, i) calls f exactly once for each process it
-    starts, and then delivers an iterator containing work items. f is expected to
-    yield results. In practice, this means you can pass large objects to f and they
-    will only be pickled once rather than for every item in i. It also allows you
-    to do one-time setup at the beginning of f.
+    Implements a parallel processing pool similar to multiprocessing.Pool. However,
+    Pool.map(f, i) calls f on every item in i individually. f is expected to return
+    the result. PureGeneratorPool.apply(f, i) calls f exactly once for each process
+    it starts, and then delivers an iterator containing work items. f is expected
+    to yield results. In practice, this means you can pass large objects to f and
+    they will only be pickled once rather than for every item in i. It also allows
+    you to do one-time setup at the beginning of f.
 
     f must be a "pure generator". This means it must yield exactly one result for
     each item in the iterator, and that result must only depend on the current
@@ -59,37 +65,66 @@
     is a pure generator.
 
-    itermap() preserves the ordering of items in the input iterator.
+    apply() preserves the ordering of items in the input iterator.
""" - if processes <= 1: - yield from function(iterator, *args, **kwargs) - else: - iterator = enumerate(iterator) + + def __init__(self, function, processes=1, *args, **kwargs): + self._processes = processes + self._function = function + self._args = args + self._kwargs = kwargs + self._pool = [] + + if self._processes <= 1: + # Don't need to set up processes, queues, events. + return ctx = mp.get_context('spawn') # Work items are placed on this queue by the main process. - work_queue = ctx.Queue() + self._work_queue = ctx.Queue() # Sub-processes place results on this queue. - done_queue = ctx.Queue() + self._done_queue = ctx.Queue() # Tells sub-processes that we are done and they should exit. - quit_event = ctx.Event() - - pool = [ctx.Process( - target=slave, args=(function, quit_event, work_queue, done_queue, args, kwargs), daemon=True - ) for id in range(processes)] - - with SigIntDefer() as sigint: - try: - for p in pool: - p.start() + self._quit_event = ctx.Event() + + for id in range(processes): + started_event = ctx.Event() + stopped_event = ctx.Event() + p = ctx.Process(target=worker, daemon=True, args=( + function, started_event, stopped_event, self._quit_event, + self._work_queue, self._done_queue, self._args, self._kwargs + )) + self._pool.append((p, started_event, stopped_event)) + + @property + def started(self): + return all(p[1].is_set() for p in self._pool) + + @property + def stopped(self): + return any(p[2].is_set() for p in self._pool) + + def __enter__(self): + for p in self._pool: + p[0].start() + for p in self._pool: + if not p[1].wait(timeout=1): + raise TimeoutError('Timed out waiting for worker process to start.') + return self + + def apply(self, iterable): + if self._pool: + iterable = enumerate(iterable) + + with SigIntDefer() as sigint: sent_count = 0 received_count = 0 # Prime the queue with some items. - for item in itertools.islice(iterator, 100): - work_queue.put(item) + for item in itertools.islice(iterable, 100): + self._work_queue.put(item) sent_count += 1 # Dict to use for sorting received items back into @@ -97,24 +132,35 @@ def itermap(function, iterator, processes=1, *args, **kwargs): received = {} while received_count < sent_count: - n, item = done_queue.get() + n, item = self._done_queue.get() received[n] = item while received_count in received: yield received[received_count] del received[received_count] received_count += 1 if sigint.fired: - quit_event.set() + self._quit_event.set() else: try: - work_queue.put(next(iterator)) + self._work_queue.put(next(iterable)) sent_count += 1 except StopIteration: - quit_event.set() + self._quit_event.set() + else: + # Single process implementation. 
+            yield from self._function(iterable, *self._args, **self._kwargs)
+
+    def __exit__(self, *args):
+        for p in self._pool:
+            p[0].join()
+
+
+def itermap(function, iterable, processes=1, *args, **kwargs):
+
+    """One-shot function to make a PureGeneratorPool and apply it."""
 
-        finally:
-            for p in pool:
-                p.join()
+    with PureGeneratorPool(function, processes, *args, **kwargs) as pool:
+        yield from pool.apply(iterable)
 
 
 if __name__ in ['__main__', '__mp_main__']:
diff --git a/teletext/tests/test_mp.py b/teletext/tests/test_mp.py
index e67ba37..d7cfc9c 100644
--- a/teletext/tests/test_mp.py
+++ b/teletext/tests/test_mp.py
@@ -18,7 +18,7 @@ def test_single(self):
         result = list(itermap(func, self.input, processes=1))
         self.assertListEqual(result, self.result)
 
-    @unittest.skip # breaks coverage - https://github.com/nedbat/coveragepy/issues/745
+    #@unittest.skip # breaks coverage - https://github.com/nedbat/coveragepy/issues/745
     def test_multi(self):
         result = list(itermap(func, self.input, processes=2))
         self.assertListEqual(result, self.result)
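
Note (not part of the patch): below is a minimal usage sketch of the patched API, assuming the module is importable as teletext.mp. It shows what a "pure generator" looks like under the contract in the PureGeneratorPool docstring: one-time setup, then exactly one yielded result per input item, with output order matching input order. The function name `double` and the `offset` keyword argument are hypothetical, chosen only for illustration.

    # Usage sketch (illustrative only; assumes the patched teletext.mp).
    from teletext.mp import PureGeneratorPool, itermap

    def double(items, offset=0):
        # One-time setup would go here; with the pool it runs once per
        # worker process rather than once per item.
        for item in items:
            # Exactly one result per input item, depending only on that item.
            yield item * 2 + offset

    if __name__ == '__main__':
        # One-shot helper: results come back in input order.
        print(list(itermap(double, range(10), processes=2)))

        # Equivalent explicit form using the pool object directly.
        with PureGeneratorPool(double, processes=2, offset=1) as pool:
            print(list(pool.apply(range(10))))

Because the pool uses the 'spawn' start context, the generator function has to be importable by the child processes, which is why the sketch keeps `double` at module level and guards the driver code with `if __name__ == '__main__':`.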