Skip to content

Commit

Permalink
communicate with mpi_comm_spawn_child
Browse files Browse the repository at this point in the history
  • Loading branch information
rythorpe committed Jul 17, 2022
1 parent 283847d commit c616558
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 13 deletions.
4 changes: 2 additions & 2 deletions examples/howto/plot_simulate_mpi_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@

# to create a parent MPI job that uses MPIBackend to spawn child jobs, set
# mpi_comm_spawn=True and call this script from terminal using
# $ mpiexec -np 1 --oversubscribe python -m mpi4py /home/ryan/hnn-core/examples/howto/plot_simulate_mpi_backend.py
with MPIBackend(n_procs=5, mpi_cmd='mpiexec', mpi_comm_spawn=False):
# ``$ mpiexec -np 1 --oversubscribe python -m mpi4py /path/to/hnn-core/examples/howto/plot_simulate_mpi_backend.py``
with MPIBackend(n_procs=5, mpi_cmd='mpiexec', mpi_comm_spawn=True):
dpls = simulate_dipole(net, tstop=200., n_trials=1)

trial_idx = 0
Expand Down
3 changes: 0 additions & 3 deletions hnn_core/mpi_child.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
import pickle
import base64
import re
import shlex
from os import environ
from mpi4py import MPI

from hnn_core.parallel_backends import _extract_data, _extract_data_length

Expand Down
40 changes: 38 additions & 2 deletions hnn_core/mpi_comm_spawn_child.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
"""Script for running parallel simulations with MPI when called with mpiexec.
This script is called directly from MPIBackend.simulate()
"""

# Authors: Blake Caldwell <[email protected]>
# Ryan Thorpe <[email protected]>

from hnn_core.network_builder import _simulate_single_trial


class MPISimulation(object):
"""The MPISimulation class.
Parameters
----------
skip_mpi_import : bool | None
Expand All @@ -22,6 +30,7 @@ def __init__(self, skip_mpi_import=False):
else:
from mpi4py import MPI

self.intercomm = MPI.Comm.Get_parent()
self.comm = MPI.COMM_WORLD
self.rank = self.comm.Get_rank()

Expand All @@ -31,5 +40,32 @@ def __enter__(self):
def __exit__(self, type, value, traceback):
# skip Finalize() if we didn't import MPI on __init__
if hasattr(self, 'comm'):
from mpi4py import MPI
MPI.Finalize()
self.intercomm.Disconnect()

def _read_net(self):
"""Read net and associated objects broadcasted to all ranks on stdin"""

return self.intercomm.bcast(None, root=0)

def run(self, net, tstop, dt, n_trials):
    """Run MPI simulation trial(s) and return the collected data.

    Parameters
    ----------
    net : Network
        The network to simulate.
    tstop : float
        Simulation stop time (ms). -- presumably ms, as elsewhere in
        hnn-core; confirm against _simulate_single_trial.
    dt : float
        Integration time step.
    n_trials : int
        Number of trials to simulate.

    Returns
    -------
    sim_data : list
        One entry of trial data per trial. Every rank builds this list,
        but only rank 0's copy is sent back to MPIBackend by the caller
        (see the __main__ guard in this module).
    """
    sim_data = []
    for trial_idx in range(n_trials):
        single_sim_data = _simulate_single_trial(net, tstop, dt, trial_idx)

        # append trial data on every rank, though only rank 0 holds data
        # that will actually be transmitted back to MPIBackend
        sim_data.append(single_sim_data)

    return sim_data


if __name__ == '__main__':
    """This file is called on command-line from nrniv"""

    with MPISimulation() as mpi_sim:
        # receive (net, tstop, dt, n_trials) broadcast by the parent,
        # run the trials, and let only rank 0 report back
        sim_params = mpi_sim._read_net()
        collected = mpi_sim.run(*sim_params)
        if mpi_sim.rank == 0:
            mpi_sim.intercomm.send(collected, dest=0)
11 changes: 5 additions & 6 deletions hnn_core/parallel_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,18 @@ def spawn_job(command, obj, n_procs, info=None):
from mpi4py import MPI

if info is None:
intercomm = MPI.COMM_SELF.Spawn('nrniv', args=command,
intercomm = MPI.COMM_SELF.Spawn(command[0], args=command[1:],
maxprocs=n_procs)
else:
intercomm = MPI.COMM_SELF.Spawn('nrniv', args=command,
intercomm = MPI.COMM_SELF.Spawn(command[0], args=command[1:],
info=info,
maxprocs=n_procs)

# send Network + sim param objects to each child process
intercomm.bcast(obj, root=MPI.ROOT)

# receive data
child_data = intercomm.recv(source=0)
child_data = intercomm.recv(source=MPI.ANY_SOURCE)

# close spawned communicator
intercomm.Disconnect()
Expand Down Expand Up @@ -742,7 +742,7 @@ def __exit__(self, type, value, traceback):
_BACKEND = self._old_backend

# always kill nrniv processes for good measure
if self.n_procs > 1:
if (not self.mpi_comm_spawn) and self.n_procs > 1:
kill_proc_name('nrniv')

def simulate(self, net, tstop, dt, n_trials, postproc=False):
Expand Down Expand Up @@ -781,7 +781,6 @@ def simulate(self, net, tstop, dt, n_trials, postproc=False):

if self.mpi_comm_spawn:
sim_data = spawn_job(
parent_comm=self._selfcomm,
command=self.command,
obj=sim_objs,
n_procs=self.n_procs,
Expand Down

0 comments on commit c616558

Please sign in to comment.