Skip to content

Commit

Permalink
[contacts.ContactGroup.select_by_residues] refactor from to_new_Conta…
Browse files Browse the repository at this point in the history
…ctGroup
  • Loading branch information
gph82 committed Mar 13, 2024
1 parent 2a3b482 commit 21cb1f2
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 45 deletions.
75 changes: 39 additions & 36 deletions mdciao/contacts/contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5992,32 +5992,35 @@ def select_by_frames(self, frames) -> ContactPair:
Control what frames of the trajectory data
gets used in the returned ContactGroups. Several modes
of input are possible.
* integer `n`:
select the first `n` frames
of each trajectory. If `n` is negative,
then select the last `n` frames of each
trajectory. If a trajectory has
less than `n` frames, all frames are selected.
* dict:
keyed with trajectory indices, valued
with a list of trajectory frames. E.g.
if `frames = {2 : [101,100], 0: [10, 20]}`,
then the new ContactGroup has two trajectories
which consist of old trajectories
2 and 0, with the frames 101,100 and 10,20,
respectively. The output order corresponds
the input order both in terms of keys and
values of the input dictionary.
* list of pairs of integers :
individual frames
of individual trajectories merted into
a single ContactGroup, e.g.
[[ti,fj],[tk,fl],[tm,fn]] means the
new ContactGroup has three frames
* frame j of trajectory i
* frame k of trajectory l
* frame n of trajectory m
* integer `n`:
select the first `n` frames
of each trajectory. If `n` is negative,
then select the last `n` frames of each
trajectory. If a trajectory has
less than `n` frames, all frames are selected.
* dict:
keyed with trajectory indices, valued
with a list of trajectory frames. E.g.
if `frames = {2 : [101,100], 0: [10, 20]}`,
then the new ContactGroup has two trajectories
which consist of old trajectories
2 and 0, with the frames 101,100 and 10,20,
respectively. The output order corresponds
the input order both in terms of keys and
values of the input dictionary.
* list of pairs of integers:
individual frames
of individual trajectories merged into
a single ContactGroup, e.g.
>>> frames = [[i,j],
>>> [k,l],
>>> [m,n]]
means the new ContactGroup has three frames
* frame j of trajectory i
* frame k of trajectory l
* frame n of trajectory m
Returns
-------
Expand Down Expand Up @@ -6118,23 +6121,23 @@ def select_by_frames(self, frames) -> ContactPair:
# name=self.name # Unsure about what's best here, keep it or modify it. It's not used anywhere
)

def to_new_ContactGroup(self,
CSVexpression=None,
residue_indices=None,
allow_multiple_matches=False, merge=True,
keep_interface=True,
n_residues=1):
def select_by_residues(self,
CSVexpression=None,
residue_indices=None,
allow_multiple_matches=False, merge=True,
keep_interface=True,
n_residues=1):
r"""
Creates a new :obj:`ContactGroup` from this une using a CSV expression to filter for residues
Return a copy this :obj:`ContactGroup`, but with a sub-selection of :obj:`ContactGroup.contact_pairs` based on residues.
The returned :obj:`ContactGroup` has the same trajectories and frames as the original.
The filtering of ContactPairs against `CSVexpression` or `residue_indices`
can be such that:
The filtering of ContactPairs is done using `CSVexpression` or `residue_indices`
so that:
* one residue match per ContactPair is enough, or
* both residues of the ContactPair need to match
for the ContactPair to be selected for the new ContactGroup.
See `n_residues` for more info.
Parameters
----------
CSVexpression : str or None, default is None
Expand Down
18 changes: 9 additions & 9 deletions tests/test_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1821,7 +1821,7 @@ def test_repframe_w_traj_violines_many_frames_just_runs(self):
assert len(repframes)==len(RMSDd)==len(values)==len(trajs)==10
assert isinstance(trajs[0], md.Trajectory)

def test_to_new_ContactGroup_CSV(self):
def test_select_by_residues_CSV(self):
CG = _mdcsites([{"name": "test_random",
"pairs": {"residx": [[100, 200],
[100, 300],
Expand All @@ -1837,15 +1837,15 @@ def test_to_new_ContactGroup_CSV(self):
keys = [str(CG.top.residue(ii)) for ii in [200,10, 1]]
CSV = ','.join(keys)

new_CG : contacts.ContactGroup = CG.to_new_ContactGroup(CSVexpression=CSV)
new_CG : contacts.ContactGroup = CG.select_by_residues(CSVexpression=CSV)
assert new_CG.n_ctcs == 4
assert new_CG._contacts[0] is CG._contacts[0]
assert new_CG._contacts[1] is CG._contacts[2]
assert new_CG._contacts[2] is CG._contacts[3]
assert new_CG._contacts[3] is CG._contacts[5]
assert isinstance(new_CG,contacts.ContactGroup)

new_CG_dict = CG.to_new_ContactGroup(CSVexpression=CSV, merge=False)
new_CG_dict = CG.select_by_residues(CSVexpression=CSV, merge=False)

self.assertSequenceEqual(list(new_CG_dict.keys()),CSV.split(","))

Expand All @@ -1860,7 +1860,7 @@ def test_to_new_ContactGroup_CSV(self):
assert new_CG_dict[keys[2]] is None


def test_to_new_ContactGroup_residue_indices(self):
def test_select_by_residues_residue_indices(self):
CG = _mdcsites([{"name": "test_random",
"pairs": {"residx": [[100, 200],
[100, 300],
Expand All @@ -1875,14 +1875,14 @@ def test_to_new_ContactGroup_residue_indices(self):
figures=False)["test_random"]

residue_indices = [10, 40, 1]
new_CG : contacts.ContactGroup = CG.to_new_ContactGroup(residue_indices=residue_indices)
new_CG : contacts.ContactGroup = CG.select_by_residues(residue_indices=residue_indices)
assert new_CG.n_ctcs == 3
assert new_CG._contacts[0] is CG._contacts[2]
assert new_CG._contacts[1] is CG._contacts[4]
assert new_CG._contacts[2] is CG._contacts[5]
assert isinstance(new_CG,contacts.ContactGroup)

new_CG_dict = CG.to_new_ContactGroup(residue_indices=residue_indices, merge=False)
new_CG_dict = CG.select_by_residues(residue_indices=residue_indices, merge=False)
self.assertSequenceEqual(list(new_CG_dict.keys()),residue_indices)
assert new_CG_dict[residue_indices[0]].n_ctcs == 2
assert new_CG_dict[residue_indices[0]]._contacts[0] is CG._contacts[2]
Expand All @@ -1894,7 +1894,7 @@ def test_to_new_ContactGroup_residue_indices(self):

assert new_CG_dict[residue_indices[2]] is None

def test_to_new_ContactGroup_residue_indices_n_residues_is_2(self):
def test_to_select_by_residues_residue_indices_n_residues_is_2(self):
CG = _mdcsites([{"name": "test_random",
"pairs": {"residx": [[100, 200],
[100, 300],
Expand All @@ -1910,7 +1910,7 @@ def test_to_new_ContactGroup_residue_indices_n_residues_is_2(self):
figures=False)["test_random"]

residue_indices = [10, 40, 20, 50]
new_CG : contacts.ContactGroup = CG.to_new_ContactGroup(residue_indices=residue_indices, n_residues=2)
new_CG : contacts.ContactGroup = CG.select_by_residues(residue_indices=residue_indices, n_residues=2)

assert isinstance(new_CG, contacts.ContactGroup)
assert new_CG.n_ctcs == 3
Expand All @@ -1919,7 +1919,7 @@ def test_to_new_ContactGroup_residue_indices_n_residues_is_2(self):
assert new_CG._contacts[2] is CG._contacts[5]


new_CG_dict = CG.to_new_ContactGroup(residue_indices=residue_indices, merge=False, n_residues=2)
new_CG_dict = CG.select_by_residues(residue_indices=residue_indices, merge=False, n_residues=2)
self.assertSequenceEqual(list(new_CG_dict.keys()),residue_indices)
assert new_CG_dict[residue_indices[0]].n_ctcs == 1
assert new_CG_dict[residue_indices[0]]._contacts[0] is CG._contacts[2]
Expand Down

0 comments on commit 21cb1f2

Please sign in to comment.