Skip to content

Commit

Permalink
speeding up add function for large compounds (#1107)
Browse files Browse the repository at this point in the history
* updated add routine to compose bond graphs of a list together before composing with main compound

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change flattening routine for list, as it was causing a failure in the add remove bond test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* somehow left a conflict statement in... not sure why it was causing a conflict to begin with

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added in sped up routine for loading mdtraj trajectories

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added in sped up routine for loading parmed trajectories

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added new tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added tests for helper functions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added condense function for Compounds

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added tests to condense function

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* addressed comments from cal

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor changes in code and arrangement

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix unit test

---------

Co-authored-by: Christopher Iacovella <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Christopher Iacovella <[email protected]>
Co-authored-by: Co Quach <[email protected]>
  • Loading branch information
5 people committed Apr 28, 2023
1 parent 1211260 commit 08918c7
Show file tree
Hide file tree
Showing 3 changed files with 274 additions and 17 deletions.
133 changes: 123 additions & 10 deletions mbuild/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import ele
import networkx as nx
import numpy as np
from boltons.setutils import IndexedSet
from ele.element import Element, element_from_name, element_from_symbol
from ele.exceptions import ElementError
from treelib import Tree
Expand Down Expand Up @@ -811,8 +812,8 @@ def add(
----------
new_child : mb.Compound or list-like of mb.Compound
The object(s) to be added to this Compound.
label : str, optional, default None
A descriptive string for the part.
label : str, or list-like of str, optional, default None
A descriptive string for the part; if a list, must be the same length/shape as new_child.
containment : bool, optional, default=True
Add the part to self.children.
replace : bool, optional, default=True
Expand All @@ -829,11 +830,56 @@ def add(
to add Compounds to an existing rigid body.
"""
# Support batch add via lists, tuples and sets.
# If iterable, we will first compose all the bondgraphs of individual
# Compounds in the list for efficiency
from mbuild.port import Port

if isinstance(new_child, Iterable) and not isinstance(new_child, str):
for child in new_child:
self.add(child, reset_rigid_ids=reset_rigid_ids)
compound_list = [c for c in _flatten_list(new_child)]
if label is not None and isinstance(label, (list, tuple)):
label_list = [c for c in _flatten_list(label)]
if len(label_list) != len(compound_list):
raise ValueError(
"The list-like object for label must be the same length as"
"the list-like object of child Compounds. "
f"total length of labels: {len(label_list)}, new_child: {len(new_child)}."
)
temp_bond_graphs = []
for child in compound_list:
# create a list of bond graphs of the children to add
if containment:
if child.bond_graph and not isinstance(self, Port):
temp_bond_graphs.append(child.bond_graph)

# compose children bond_graphs; make sure we actually have graphs to compose
children_bond_graph = None
if len(temp_bond_graphs) != 0:
children_bond_graph = nx.compose_all(temp_bond_graphs)

if (
temp_bond_graphs
and not isinstance(self, Port)
and children_bond_graph is not None
):
# If anything is added at self level, it is no longer a particle
# search for self in self.root.bond_graph and remove self
if self.root.bond_graph.has_node(self):
self.root.bond_graph.remove_node(self)
# compose the bond graph of all the children with the root
self.root.bond_graph = nx.compose(
self.root.bond_graph, children_bond_graph
)
for i, child in enumerate(compound_list):
child.bond_graph = None
if label is not None:
self.add(
child,
label=label_list[i],
reset_rigid_ids=reset_rigid_ids,
)
else:
self.add(child, reset_rigid_ids=reset_rigid_ids)

return

if not isinstance(new_child, Compound):
Expand Down Expand Up @@ -1859,9 +1905,9 @@ def _visualize_nglview(self, show_ports=False, color_scheme={}):
mdtraj = import_("mdtraj")
from mdtraj.geometry.sasa import _ATOMIC_RADII

remove_digits = lambda x: "".join(
i for i in x if not i.isdigit() or i == "_"
)
def remove_digits(x):
    """Return ``x`` joined back together with its digit characters dropped."""
    # The underscore test mirrors the original condition: an underscore is
    # always kept, as is every other character that is not a digit.
    kept = [char for char in x if char == "_" or not char.isdigit()]
    return "".join(kept)

for particle in self.particles():
particle.name = remove_digits(particle.name).upper()
if not particle.name:
Expand Down Expand Up @@ -1899,6 +1945,58 @@ def _visualize_nglview(self, show_ports=False, color_scheme={}):
overwrite_nglview_default(widget)
return widget

def condense(self, inplace=True):
    """Condense the hierarchical structure of the Compound to the level of molecules.

    Rebuild the Compound as a three-level hierarchy: the top-level container
    (self), the molecules (i.e., connected Compounds) beneath it, and the
    Particles (i.e., Compounds with no children) beneath those. A Particle
    that has no connection to any other Compound is placed directly in the
    second level, with the top-level container as its parent.

    Parameters
    ----------
    inplace : bool, optional, default=True
        Option to perform the condense operation inplace or return a copy

    Returns
    -------
    mb.Compound or None
        A new condensed Compound if inplace is False, otherwise None.
    """
    # One clone per molecule, gathered before the hierarchy is modified.
    molecule_clones = []

    for members in self.root.bond_graph.connected_components():
        if len(members) == 1:
            # A component with a single member is used as-is.
            molecule_tag = members[0]
        else:
            # Reduce to the ancestors shared by every particle of the
            # component. The first particle's ancestors serve as the
            # reference ordering; the intent is that lower-level ancestors
            # come first, so that index 0 of the intersection is the
            # lowest-level Compound containing the whole molecule.
            shared = IndexedSet(members[0].ancestors())
            for particle in members[1:]:
                shared = shared.intersection(IndexedSet(particle.ancestors()))
            molecule_tag = shared[0]
        molecule_clones.append(clone(molecule_tag))

    if not inplace:
        condensed = Compound(name=self.name)
        condensed.add(molecule_clones)
        return condensed

    # The whole children collection is handed to remove() in a single call.
    # NOTE: the case where a child is a Port still needs to be handled.
    self.remove(self.children)
    self.add(molecule_clones)

def flatten(self, inplace=True):
"""Flatten the hierarchical structure of the Compound.
Expand All @@ -1913,7 +2011,7 @@ def flatten(self, inplace=True):
Return
------
self : mb.Compound or None
return a flatten Compound if inplace is False.
return a flattened Compound if inplace is False.
"""
ports_list = list(self.all_ports())
children_list = list(self.children)
Expand Down Expand Up @@ -2554,8 +2652,10 @@ def _energy_minimize_openbabel(
f"Cannot create a constraint between a Particle and itself: {p1} {p2} ."
)

pid_1 = particle_idx[id(p1)] + 1 # openbabel indices start at 1
pid_2 = particle_idx[id(p2)] + 1 # openbabel indices start at 1
# openbabel indices start at 1
pid_1 = particle_idx[id(p1)] + 1
# openbabel indices start at 1
pid_2 = particle_idx[id(p2)] + 1
dist = (
con_temp[1] * 10.0
) # obenbabel uses angstroms, not nm, convert to angstroms
Expand Down Expand Up @@ -3421,3 +3521,16 @@ def _clone_bonds(self, clone_of=None):


Particle = Compound


def _flatten_list(c_list):
"""Flatten a list.
Helper function to flatten a list that may be nested, e.g. [comp1, [comp2, comp3]].
"""
if isinstance(c_list, Iterable) and not isinstance(c_list, str):
for c in c_list:
if isinstance(c, Iterable) and not isinstance(c, str):
yield from _flatten_list(c)
else:
yield c
45 changes: 38 additions & 7 deletions mbuild/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,19 +554,23 @@ def from_parmed(
chains[residue.chain].append(residue)

# Build up compound
chain_list = []
for chain, residues in chains.items():
if len(chain) > 1:
chain_compound = mb.Compound()
compound.add(chain_compound, chain_id)
chain_list.append(chain_compound)
else:
chain_compound = compound
res_list = []
for residue in residues:
if infer_hierarchy:
residue_compound = mb.Compound(name=residue.name)
chain_compound.add(residue_compound)
parent_compound = residue_compound
res_list.append(residue_compound)
else:
parent_compound = chain_compound
atom_list = []
atom_label_list = []
for atom in residue.atoms:
# Angstrom to nm
pos = np.array([atom.xx, atom.xy, atom.xz]) / 10
Expand All @@ -577,8 +581,14 @@ def from_parmed(
new_atom = mb.Particle(
name=str(atom.name), pos=pos, element=element
)
parent_compound.add(new_atom, label="{0}[$]".format(atom.name))
atom_list.append(new_atom)
atom_label_list.append("{0}[$]".format(atom.name))
atom_mapping[atom] = new_atom
parent_compound.add(atom_list, label=atom_label_list)
if infer_hierarchy:
chain_compound.add(res_list)
if len(chain) > 1:
compound.add(chain_list)

# Infer bonds information
for bond in structure.bonds:
Expand Down Expand Up @@ -650,19 +660,29 @@ def from_trajectory(
compound = mb.Compound()

atom_mapping = dict()
# temporary lists to speed up add to the compound
chains_list = []
chains_list_label = []

for chain in traj.topology.chains:
if traj.topology.n_chains > 1:
chain_compound = mb.Compound()
compound.add(chain_compound, "chain[$]")
chains_list.append(chain_compound)
chains_list_label.append("chain[$]")
else:
chain_compound = compound

res_list = []
for res in chain.residues:
if infer_hierarchy:
res_compound = mb.Compound(name=res.name)
chain_compound.add(res_compound)
parent_cmpd = res_compound
res_list.append(res_compound)
else:
parent_cmpd = chain_compound

atom_list = []
atom_label_list = []
for atom in res.atoms:
try:
element = element_from_atomic_number(
Expand All @@ -675,9 +695,17 @@ def from_trajectory(
pos=traj.xyz[frame, atom.index],
element=element,
)
parent_cmpd.add(new_atom, label="{0}[$]".format(atom.name))
atom_list.append(new_atom)
atom_label_list.append("{0}[$]".format(atom.name))
atom_mapping[atom] = new_atom

parent_cmpd.add(atom_list, label=atom_label_list)

if infer_hierarchy:
chain_compound.add(res_list)
if traj.topology.n_chains > 1:
compound.add(chains_list, label=chains_list_label)

for mdtraj_atom1, mdtraj_atom2 in traj.topology.bonds:
atom1 = atom_mapping[mdtraj_atom1]
atom2 = atom_mapping[mdtraj_atom2]
Expand Down Expand Up @@ -846,13 +874,16 @@ def from_rdkit(rdkit_mol, compound=None, coords_only=False, smiles_seed=0):
else:
comp = compound

part_list = []
for i, atom in enumerate(mymol.GetAtoms()):
part = mb.Particle(
name=atom.GetSymbol(),
element=element_from_atomic_number(atom.GetAtomicNum()),
pos=xyz[i],
)
comp.add(part)
part_list.append(part)

comp.add(part_list)

for bond in mymol.GetBonds():
comp.add_bond(
Expand Down
Loading

0 comments on commit 08918c7

Please sign in to comment.