Refactor mp pool
There is now a pool object which can be re-used. It also ensures
the worker processes start up correctly, preventing deadlocks
where the target function never even runs. More can be done in
this area.

Fixes #33.
ali1234 committed Nov 19, 2019
1 parent aeb9e0d commit a30a990
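
A minimal sketch of the intended usage, based on the diff below (double is a hypothetical pure generator; nothing here is part of the commit itself):

from teletext.mp import PureGeneratorPool

def double(items):
    for item in items:
        yield item * 2

if __name__ == '__main__':
    # __enter__ starts the workers and waits for each started_event --
    # the handshake that prevents the startup deadlock described above.
    with PureGeneratorPool(double, processes=4) as pool:
        print(list(pool.apply(range(10))))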
Showing 2 changed files with 87 additions and 41 deletions.
126 changes: 86 additions & 40 deletions teletext/mp.py
@@ -1,12 +1,12 @@
import itertools
import queue
import signal
import time

import multiprocessing as mp

from .sigint import SigIntDefer


def denumerate(quit_event, work_queue, tmp_queue):
"""
Strips sequence numbers from work_queue items and yields the work.
@@ -32,23 +32,29 @@ def renumerate(iterator, done_queue, tmp_queue):
done_queue.put((n, item))


def slave(function, quit_event, work_queue, done_queue, args, kwargs):
"""The main function for subprocesses. """
def worker(function, started_event, stopped_event, quit_event, work_queue, done_queue, args, kwargs):
"""
The main function for subprocesses.
"""
try:
signal.signal(signal.SIGINT, signal.SIG_IGN)
tmp_queue = mp.Queue() # holds work item numbers to be recombined with the result
started_event.set()
renumerate(function(denumerate(quit_event, work_queue, tmp_queue), *args, **kwargs), done_queue, tmp_queue)
finally:
stopped_event.set()

signal.signal(signal.SIGINT, signal.SIG_IGN)
tmp_queue = mp.Queue() # holds work item numbers to be recombined with the result
renumerate(function(denumerate(quit_event, work_queue, tmp_queue), *args, **kwargs), done_queue, tmp_queue)

class PureGeneratorPool(object):

def itermap(function, iterator, processes=1, *args, **kwargs):
"""
Implements a multiprocessing pool similar in function to multiprocessing.Pool.
However, Pool.map(f, i) calls f on every item in i individually. f is expected
to return the result. itermap(f, i) calls f exactly once for each process it
starts, and then delivers an iterator containing work items. f is expected to
yield results. In practice, this means you can pass large objects to f and they
will only be pickled once rather than for every item in i. It also allows you
to do one-time setup at the beginning of f.
Implements a parallel processing pool similar to multiprocessing.Pool. However,
Pool.map(f, i) calls f on every item in i individually. f is expected to return
the result. PureGeneratorPool.apply(f, i) calls f exactly once for each process
it starts, and then delivers an iterator containing work items. f is expected
to yield results. In practice, this means you can pass large objects to f and
they will only be pickled once rather than for every item in i. It also allows
you to do one-time setup at the beginning of f.
f must be a "pure generator". This means it must yield exactly one result for
each item in the iterator, and that result must only depend on the current
@@ -59,62 +65,102 @@ def itermap(function, iterator, processes=1, *args, **kwargs):
is a pure generator.
itermap() preserves the ordering of items in the input iterator.
apply() preserves the ordering of items in the input iterator.
"""
if processes <= 1:
yield from function(iterator, *args, **kwargs)
else:
iterator = enumerate(iterator)

def __init__(self, function, processes=1, *args, **kwargs):
self._processes = processes
self._function = function
self._args = args
self._kwargs = kwargs
self._pool = []

if self._processes <= 1:
# Don't need to set up processes, queues, events.
return

ctx = mp.get_context('spawn')

# Work items are placed on this queue by the main process.
work_queue = ctx.Queue()
self._work_queue = ctx.Queue()
# Sub-processes place results on this queue.
done_queue = ctx.Queue()
self._done_queue = ctx.Queue()
# Tells sub-processes that we are done and they should exit.
quit_event = ctx.Event()

pool = [ctx.Process(
target=slave, args=(function, quit_event, work_queue, done_queue, args, kwargs), daemon=True
) for id in range(processes)]

with SigIntDefer() as sigint:
try:
for p in pool:
p.start()
self._quit_event = ctx.Event()

for id in range(processes):
started_event = ctx.Event()
stopped_event = ctx.Event()
p = ctx.Process(target=worker, daemon=True, args=(
function, started_event, stopped_event, self._quit_event,
self._work_queue, self._done_queue, self._args, self._kwargs
))
self._pool.append((p, started_event, stopped_event))

@property
def started(self):
return all(p[1].is_set() for p in self._pool)

@property
def stopped(self):
return any(p[2].is_set() for p in self._pool)

def __enter__(self):
for p in self._pool:
p[0].start()
for p in self._pool:
if not p[1].wait(timeout=1):
raise TimeoutError('Timed out waiting for worker process to start.')
return self

def apply(self, iterable):
if self._pool:
iterable = enumerate(iterable)

with SigIntDefer() as sigint:

sent_count = 0
received_count = 0

# Prime the queue with some items.
for item in itertools.islice(iterator, 100):
work_queue.put(item)
for item in itertools.islice(iterable, 100):
self._work_queue.put(item)
sent_count += 1

# Dict to use for sorting received items back into
# their original order.
received = {}

while received_count < sent_count:
n, item = done_queue.get()
n, item = self._done_queue.get()
received[n] = item
while received_count in received:
yield received[received_count]
del received[received_count]
received_count += 1
if sigint.fired:
quit_event.set()
self._quit_event.set()
else:
try:
work_queue.put(next(iterator))
self._work_queue.put(next(iterable))
sent_count += 1
except StopIteration:
quit_event.set()
self._quit_event.set()
else:
# Single process implementation.
yield from self._function(iterable, *self._args, **self._kwargs)

def __exit__(self, *args):
for p in self._pool:
p[0].join()


def itermap(function, iterable, processes=1, *args, **kwargs):

"""One-shot function to make a PureGeneratorPool and apply it."""

finally:
for p in pool:
p.join()
with PureGeneratorPool(function, processes, *args, **kwargs) as pool:
yield from pool.apply(iterable)


if __name__ in ['__main__', '__mp_main__']:
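The docstring above describes the pure-generator contract: exactly one result yielded per input item, with one-time setup allowed before the loop. A minimal sketch of a conforming function (scale and its lookup table are invented for illustration), run through the single-process path:

from teletext.mp import itermap

def scale(items, factor=1):
    # One-time setup: under itermap this runs once per worker process,
    # and factor is pickled once rather than once per work item.
    lookup = {n: n * factor for n in range(100)}
    for item in items:
        # Exactly one yield per input item, depending only on that item.
        yield lookup.get(item, item * factor)

# processes=1 skips the pool machinery and simply calls
# scale(range(5), factor=3) directly.
assert list(itermap(scale, range(5), processes=1, factor=3)) == [0, 3, 6, 9, 12]
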
2 changes: 1 addition & 1 deletion teletext/tests/test_mp.py
@@ -18,7 +18,7 @@ def test_single(self):
result = list(itermap(func, self.input, processes=1))
self.assertListEqual(result, self.result)

@unittest.skip # breaks coverage - https://github.com/nedbat/coveragepy/issues/745
#@unittest.skip # breaks coverage - https://github.com/nedbat/coveragepy/issues/745
def test_multi(self):
result = list(itermap(func, self.input, processes=2))
self.assertListEqual(result, self.result)
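
apply() keeps results in input order by buffering any that arrive early. A standalone sketch of just that reordering step (reorder is an invented name; the real logic lives inline in apply() above):

def reorder(tagged_results):
    # Results arrive as (n, item) in arbitrary order; buffer them in a
    # dict until the next expected index shows up, then drain in order.
    received = {}
    expected = 0
    for n, item in tagged_results:
        received[n] = item
        while expected in received:
            yield received.pop(expected)
            expected += 1

print(list(reorder([(1, 'b'), (2, 'c'), (0, 'a')])))  # prints ['a', 'b', 'c']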
