Skip to content

Commit

Permalink
RMSD with multicore and timeout functionality (#113)
Browse files Browse the repository at this point in the history
Parallel implementation with timeout capabilities (based on pebble).

---------

Co-authored-by: Rocco Meli <[email protected]>
  • Loading branch information
Jnelen and RMeli authored Jun 23, 2024
1 parent 8ce0ede commit 29e1fd6
Show file tree
Hide file tree
Showing 8 changed files with 644 additions and 3 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
## Version 0.9.0

Date: XX/YY/ZZZ
Contributors: @RMeli
Contributors: @RMeli, @Jnelen

### Added

* `--version` CLI option [PR # | @RMeli]
* `--version` CLI option [PR #131 | @RMeli]
* `prmsdwrapper`, a first implementation of a multicore version of `rmsdwrapper`. It also supports a timeout functionality. [PR #113 | @Jnelen]

## Version 0.8.0

Expand Down
3 changes: 3 additions & 0 deletions devtools/conda-envs/spyrmsd-test-rdkit-all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ dependencies:
- networkx>=2
- rustworkx

# Parallel
- pebble

# Chemistry
- rdkit

Expand Down
3 changes: 3 additions & 0 deletions devtools/conda-envs/spyrmsd-test-rdkit-nogt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ dependencies:
- networkx>=2
- rustworkx

# Parallel
- pebble

# Chemistry
- rdkit

Expand Down
139 changes: 139 additions & 0 deletions spyrmsd/parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
try:
from pebble import ProcessPool
except ImportError:
errmsg = (
"Parallel execution of SPyRMSD (`prmsdwrapper`) requires `pebble`."
"Please install `pebble`: https://github.com/noxdafox/pebble"
)
raise ImportError(errmsg)

import os
import warnings
from concurrent.futures import TimeoutError
from functools import partial
from typing import List, Optional, Union

import numpy as np

from spyrmsd import molecule
from spyrmsd.rmsd import rmsdwrapper


def prmsdwrapper(
molrefs: Union[molecule.Molecule, List[molecule.Molecule]],
mols: Union[molecule.Molecule, List[molecule.Molecule]],
symmetry: bool = True,
center: bool = False,
minimize: bool = False,
strip: bool = True,
cache: bool = True,
num_workers: Optional[int] = None,
timeout: Optional[float] = None,
chunksize: int = 1,
) -> List[float]:
"""
Compute RMSD between two molecules with a timeout.
Parameters
----------
molrefs: Union[molecule.Molecule, List[molecule.Molecule]]
Reference molecule
mols: Union[molecule.Molecule, List[molecule.Molecule]]
Molecules to compare to reference molecule
symmetry: bool, optional
Symmetry-corrected RMSD (using graph isomorphism)
center: bool, optional
Center molecules at origin
minimize: bool, optional
Minimised RMSD (using the quaternion polynomial method)
strip: bool, optional
Strip hydrogen atoms
cache: bool, optional
Cache graph isomorphisms
num_workers: int
Amount of processor to use for the parallel calculations
timeout: float, optional
After how many seconds to stop the RMSD calculations
chunksize: int, optional
How many molecules to handle per child process
Returns
-------
List[float]
RMSDs
"""

# Ensure the num_workers is less or equal than the max number of CPUs.
# Silencing MyPy since os.cpu_count() can return None
if num_workers is None:
num_workers = os.cpu_count()
num_workers = min(num_workers, os.cpu_count()) # type: ignore[type-var]

if chunksize > 1 and timeout is not None:
# When this is not enforced, it can lead to unexpected results (output list length not matching the input list for example).
# To ensure correctness we force the chunksize to be 1 to avoid potential correctness problems.
warnings.warn(
"When using the timeout feature, a chunksize of 1 is required. The chunksize is set to 1 automatically in order to continue the calculations"
)
chunksize = 1

# Cast the molecules to lists if they aren't already
if not isinstance(molrefs, list):
molrefs = [molrefs]
if not isinstance(mols, list):
mols = [mols]

# Match the length of the molref
if len(molrefs) == 1 and len(molrefs) < len(mols):
molrefs = molrefs * len(mols)

# Ensure molrefs and mols have the same len
if not len(molrefs) == len(mols):
raise ValueError("The 'mols' and 'molrefs' lists have different lengths.")

results = []

timeoutCounter = 0
errorCounter = 0

with ProcessPool(max_workers=num_workers) as pool:
rsmd_partial = partial(
rmsdwrapper,
symmetry=symmetry,
center=center,
minimize=minimize,
strip=strip,
cache=cache,
)

future = pool.map(
rsmd_partial, molrefs, mols, timeout=timeout, chunksize=chunksize
)
iterator = future.result()

# See https://pebble.readthedocs.io/en/latest/#pools
while True:
try:
result = next(iterator)
results.append(result[0])
except StopIteration:
break
except TimeoutError:
timeoutCounter += 1

# Upon timeout, the whole chunk fails. To ensure the length and order of the output is maintained we add np.nan for the whole chunk
# More information regarding pebble error handling: https://github.com/noxdafox/pebble/issues/132#issuecomment-2105267462
results += [np.nan] * chunksize
except Exception:
errorCounter += 1
results.append(np.nan)

if timeoutCounter + errorCounter > 0:
# Calculate total number of np.nan
failedCompoundsTotal = np.count_nonzero(np.isnan(results))

warnings.warn(
f"{failedCompoundsTotal} compounds failed to process successfully and have been added as 'np.nan'."
+ f" {errorCounter} compounds raised an error, {timeoutCounter} chunks timed out"
)

return results
2 changes: 1 addition & 1 deletion spyrmsd/rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def rmsdwrapper(
cache: bool = True,
) -> Any:
"""
Compute RMSD between two molecule.
Compute RMSD between two molecules.
Parameters
----------
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ def pyridine(molpath):
return Mol(mol, "pyridine", 11, 11, 5)


@pytest.fixture
def muparfostat(molpath):
mol = io.loadmol(os.path.join(molpath, "muparfostat.sdf"))
return Mol(mol, "muparfostat", 78, 80, 33)


@pytest.fixture(
params=[
# (name, n_atoms, n_bonds, n_h)
Expand Down
Loading

0 comments on commit 29e1fd6

Please sign in to comment.