From c616558470963e90253d72b349cbc0dc5cc8e2eb Mon Sep 17 00:00:00 2001 From: rythorpe Date: Sun, 17 Jul 2022 17:36:58 -0400 Subject: [PATCH] communicate with mpi_comm_spawn_child --- examples/howto/plot_simulate_mpi_backend.py | 4 +-- hnn_core/mpi_child.py | 3 -- hnn_core/mpi_comm_spawn_child.py | 40 +++++++++++++++++++-- hnn_core/parallel_backends.py | 11 +++--- 4 files changed, 45 insertions(+), 13 deletions(-) diff --git a/examples/howto/plot_simulate_mpi_backend.py b/examples/howto/plot_simulate_mpi_backend.py index bb60608f5..84a3d3cbb 100644 --- a/examples/howto/plot_simulate_mpi_backend.py +++ b/examples/howto/plot_simulate_mpi_backend.py @@ -49,8 +49,8 @@ # to create a parent MPI job that uses MPIBackend to spawn child jobs, set # mpi_comm_spawn=True and call this script from terminal using -# $ mpiexec -np 1 --oversubscribe python -m mpi4py /home/ryan/hnn-core/examples/howto/plot_simulate_mpi_backend.py -with MPIBackend(n_procs=5, mpi_cmd='mpiexec', mpi_comm_spawn=False): +# ``$ mpiexec -np 1 --oversubscribe python -m mpi4py /path/to/hnn-core/examples/howto/plot_simulate_mpi_backend.py`` +with MPIBackend(n_procs=5, mpi_cmd='mpiexec', mpi_comm_spawn=True): dpls = simulate_dipole(net, tstop=200., n_trials=1) trial_idx = 0 diff --git a/hnn_core/mpi_child.py b/hnn_core/mpi_child.py index e87e4ced4..3e3f99413 100644 --- a/hnn_core/mpi_child.py +++ b/hnn_core/mpi_child.py @@ -9,9 +9,6 @@ import pickle import base64 import re -import shlex -from os import environ -from mpi4py import MPI from hnn_core.parallel_backends import _extract_data, _extract_data_length diff --git a/hnn_core/mpi_comm_spawn_child.py b/hnn_core/mpi_comm_spawn_child.py index 95e245b5a..3444de894 100644 --- a/hnn_core/mpi_comm_spawn_child.py +++ b/hnn_core/mpi_comm_spawn_child.py @@ -1,8 +1,16 @@ +"""Script for running parallel simulations with MPI when called with mpiexec. 
+This script is called directly from MPIBackend.simulate() +""" +# Authors: Blake Caldwell +# Ryan Thorpe + +from hnn_core.network_builder import _simulate_single_trial class MPISimulation(object): """The MPISimulation class. + Parameters ---------- skip_mpi_import : bool | None @@ -22,6 +30,7 @@ def __init__(self, skip_mpi_import=False): else: from mpi4py import MPI + self.intercomm = MPI.Comm.Get_parent() self.comm = MPI.COMM_WORLD self.rank = self.comm.Get_rank() @@ -31,5 +40,32 @@ def __enter__(self): def __exit__(self, type, value, traceback): # skip Finalize() if we didn't import MPI on __init__ if hasattr(self, 'comm'): - from mpi4py import MPI - MPI.Finalize() \ No newline at end of file + self.intercomm.Disconnect() + + def _read_net(self): + """Receive net and associated objects broadcast by the parent over the intercommunicator""" + + return self.intercomm.bcast(None, root=0) + + def run(self, net, tstop, dt, n_trials): + """Run MPI simulation(s) and return the simulation data for each trial""" + + sim_data = [] + for trial_idx in range(n_trials): + single_sim_data = _simulate_single_trial(net, tstop, dt, trial_idx) + + # go ahead and append trial data for each rank, though + # only rank 0 has data that should be sent back to MPIBackend + sim_data.append(single_sim_data) + + return sim_data + + +if __name__ == '__main__': + """This file is called on command-line from nrniv""" + + with MPISimulation() as mpi_sim: + net, tstop, dt, n_trials = mpi_sim._read_net() + sim_data = mpi_sim.run(net, tstop, dt, n_trials) + if mpi_sim.rank == 0: + mpi_sim.intercomm.send(sim_data, dest=0) diff --git a/hnn_core/parallel_backends.py b/hnn_core/parallel_backends.py index 3f1394efd..fab6e1581 100644 --- a/hnn_core/parallel_backends.py +++ b/hnn_core/parallel_backends.py @@ -118,18 +118,18 @@ def spawn_job(command, obj, n_procs, info=None): from mpi4py import MPI if info is None: - intercomm = MPI.COMM_SELF.Spawn('nrniv', args=command, + intercomm = MPI.COMM_SELF.Spawn(command[0], args=command[1:], 
maxprocs=n_procs) else: - intercomm = MPI.COMM_SELF.Spawn('nrniv', args=command, + intercomm = MPI.COMM_SELF.Spawn(command[0], args=command[1:], info=info, maxprocs=n_procs) - + # send Network + sim param objects to each child process intercomm.bcast(obj, root=MPI.ROOT) # receive data - child_data = intercomm.recv(source=0) + child_data = intercomm.recv(source=MPI.ANY_SOURCE) # close spawned communicator intercomm.Disconnect() @@ -742,7 +742,7 @@ def __exit__(self, type, value, traceback): _BACKEND = self._old_backend # always kill nrniv processes for good measure - if self.n_procs > 1: + if (not self.mpi_comm_spawn) and self.n_procs > 1: kill_proc_name('nrniv') def simulate(self, net, tstop, dt, n_trials, postproc=False): @@ -781,7 +781,6 @@ def simulate(self, net, tstop, dt, n_trials, postproc=False): if self.mpi_comm_spawn: sim_data = spawn_job( - parent_comm=self._selfcomm, command=self.command, obj=sim_objs, n_procs=self.n_procs,