Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

'MDAnalysis.analysis.nucleicacids' parallelization #4727

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
18 changes: 14 additions & 4 deletions package/MDAnalysis/analysis/nucleicacids.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@

import MDAnalysis as mda
from .distances import calc_bonds
from .base import AnalysisBase, Results
from .base import AnalysisBase, ResultsGroup
from MDAnalysis.core.groups import Residue, ResidueGroup


Expand Down Expand Up @@ -161,6 +161,12 @@ class NucPairDist(AnalysisBase):
helper for selecting atom pairs for distance analysis.
"""
talagayev marked this conversation as resolved.
Show resolved Hide resolved

_analysis_algorithm_is_parallelizable = True

@classmethod
def get_supported_backends(cls):
return ('serial', 'multiprocessing', 'dask',)

_s1: mda.AtomGroup
_s2: mda.AtomGroup
_n_sel: int
talagayev marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -276,7 +282,7 @@ def select_strand_atoms(
return (sel1, sel2)

def _prepare(self) -> None:
self._res_array: np.ndarray = np.zeros(
self.results.distances: np.ndarray = np.zeros(
[self.n_frames, self._n_sel]
)

Expand All @@ -285,13 +291,17 @@ def _single_frame(self) -> None:
self._s1.positions, self._s2.positions
)

self._res_array[self._frame_index, :] = dist
self.results.distances[self._frame_index, :] = dist

def _conclude(self) -> None:
self.results['distances'] = self._res_array
self.results['pair_distances'] = self.results['distances']
# TODO: remove pair_distances in 3.0.0

def _get_aggregator(self):
return ResultsGroup(lookup={
'distances': ResultsGroup.ndarray_vstack,
}
)

class WatsonCrickDist(NucPairDist):
r"""
Expand Down
8 changes: 8 additions & 0 deletions testsuite/MDAnalysisTests/analysis/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from MDAnalysis.analysis.hydrogenbonds.hbond_analysis import (
HydrogenBondAnalysis,
)
from MDAnalysis.analysis.nucleicacids import NucPairDist
from MDAnalysis.lib.util import is_installed


Expand Down Expand Up @@ -141,3 +142,10 @@ def client_DSSP(request):
@pytest.fixture(scope='module', params=params_for_cls(HydrogenBondAnalysis))
def client_HydrogenBondAnalysis(request):
return request.param


# MDAnalysis.analysis.nucleicacids

@pytest.fixture(scope="module", params=params_for_cls(NucPairDist))
def client_NucPairDist(request):
return request.param
12 changes: 6 additions & 6 deletions testsuite/MDAnalysisTests/analysis/test_nucleicacids.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ def test_empty_ag_error(strand):


@pytest.fixture(scope='module')
def wc_rna(strand):
def wc_rna(strand, client_NucPairDist):
strand1 = ResidueGroup([strand.residues[0], strand.residues[21]])
strand2 = ResidueGroup([strand.residues[1], strand.residues[22]])

WC = WatsonCrickDist(strand1, strand2)
WC.run()
WC.run(**client_NucPairDist)
return WC


Expand Down Expand Up @@ -114,23 +114,23 @@ def test_wc_dis_results_keyerrs(wc_rna, key):
wc_rna.results[key]


def test_minor_dist(strand):
def test_minor_dist(strand, client_NucPairDist):
strand1 = ResidueGroup([strand.residues[2], strand.residues[19]])
strand2 = ResidueGroup([strand.residues[16], strand.residues[4]])

MI = MinorPairDist(strand1, strand2)
MI.run()
MI.run(**client_NucPairDist)

assert MI.results.distances[0, 0] == approx(15.06506, rel=1e-3)
assert MI.results.distances[0, 1] == approx(3.219116, rel=1e-3)


def test_major_dist(strand):
def test_major_dist(strand, client_NucPairDist):
strand1 = ResidueGroup([strand.residues[1], strand.residues[4]])
strand2 = ResidueGroup([strand.residues[11], strand.residues[8]])

MA = MajorPairDist(strand1, strand2)
MA.run()
MA.run(**client_NucPairDist)

assert MA.results.distances[0, 0] == approx(26.884272, rel=1e-3)
assert MA.results.distances[0, 1] == approx(13.578535, rel=1e-3)
Loading