Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove rigid body data structure from Compound #1203

Merged
merged 6 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
272 changes: 3 additions & 269 deletions mbuild/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,10 @@ class Compound(object):
compound is the root of the containment hierarchy.
referrers : set
Other compounds that reference this part with labels.
rigid_id : int, default=None
The ID of the rigid body that this Compound belongs to. Only Particles
(the bottom of the containment hierarchy) can have integer values for
`rigid_id`. Compounds containing rigid particles will always have
`rigid_id == None`. See also `contains_rigid`.
boundingbox : mb.Box
The bounds (xmin, xmax, ymin, ymax, zmin, zmax) of particles in Compound
center
contains_rigid
mass
max_rigid_id
n_particles
n_bonds
root
Expand Down Expand Up @@ -183,10 +176,6 @@ def __init__(

self.port_particle = port_particle

self._rigid_id = None
self._contains_rigid = False
self._check_if_contains_rigid_bodies = False

self.element = element
if mass and float(mass) < 0.0:
raise ValueError("Cannot set a Compound mass value less than zero")
Expand Down Expand Up @@ -583,230 +572,6 @@ def charge(self, value):
"not at the bottom of the containment hierarchy."
)

@property
def rigid_id(self):
"""Get the rigid_id of the Compound."""
return self._rigid_id

@rigid_id.setter
def rigid_id(self, value):
if self._contains_only_ports():
self._rigid_id = value
for ancestor in self.ancestors():
ancestor._check_if_contains_rigid_bodies = True
else:
raise AttributeError(
"rigid_id is immutable for Compounds that are "
"not at the bottom of the containment hierarchy."
)

@property
def contains_rigid(self):
"""Return True if the Compound contains rigid bodies.

If the Compound contains any particle with a rigid_id != None
then contains_rigid will return True. If the Compound has no
children (i.e. the Compound resides at the bottom of the containment
hierarchy) then contains_rigid will return False.

Returns
-------
bool,
True if the Compound contains any particle with a rigid_id != None

Notes
-----
The private variable '_check_if_contains_rigid_bodies' is used to help
cache the status of 'contains_rigid'.
If '_check_if_contains_rigid_bodies' is False, then the rigid body
containment of the Compound has not changed, and the particle tree is
not traversed, boosting performance.
"""
if self._check_if_contains_rigid_bodies:
self._check_if_contains_rigid_bodies = False
if any(p.rigid_id is not None for p in self._particles()):
self._contains_rigid = True
else:
self._contains_rigid = False
return self._contains_rigid

@property
def max_rigid_id(self):
"""Return the maximum rigid body ID contained in the Compound.

This is usually used by compound.root to determine the maximum
rigid_id in the containment hierarchy.

Returns
-------
int or None
The maximum rigid body ID contained in the Compound. If no
rigid body IDs are found, None is returned
"""
try:
return max(
[p.rigid_id for p in self.particles() if p.rigid_id is not None]
)
except ValueError:
return

def rigid_particles(self, rigid_id=None):
"""Generate all particles in rigid bodies.

If a rigid_id is specified, then this function will only yield particles
with a matching rigid_id.

Parameters
----------
rigid_id : int, optional
Include only particles with this rigid body ID

Yields
------
mb.Compound
The next particle with a rigid_id that is not None, or the next
particle with a matching rigid_id if specified
"""
for particle in self.particles():
if rigid_id is not None:
if particle.rigid_id == rigid_id:
yield particle
else:
if particle.rigid_id is not None:
yield particle

def label_rigid_bodies(self, discrete_bodies=None, rigid_particles=None):
"""Designate which Compounds should be treated as rigid bodies.

If no arguments are provided, this function will treat the compound
as a single rigid body by providing all particles in `self` with the
same rigid_id. If `discrete_bodies` is not None, each instance of
a Compound with a name found in `discrete_bodies` will be treated as a
unique rigid body. If `rigid_particles` is not None, only Particles
(Compounds at the bottom of the containment hierarchy) matching this
name will be considered part of the rigid body.

Parameters
----------
discrete_bodies : str or list of str, optional, default=None
Name(s) of Compound instances to be treated as unique rigid bodies.
Compound instances matching this (these) name(s) will be provided
with unique rigid_ids
rigid_particles : str or list of str, optional, default=None
Name(s) of Compound instances at the bottom of the containment
hierarchy (Particles) to be included in rigid bodies. Only Particles
matching this (these) name(s) will have their rigid_ids altered to
match the rigid body number.

Examples
--------
Creating a rigid benzene

>>> import mbuild as mb
>>> from mbuild.utils.io import get_fn
>>> benzene = mb.load(get_fn('benzene.mol2'))
>>> benzene.label_rigid_bodies()

Creating a semi-rigid benzene, where only the carbons are treated as
a rigid body

>>> import mbuild as mb
>>> from mbuild.utils.io import get_fn
>>> benzene = mb.load(get_fn('benzene.mol2'))
>>> benzene.label_rigid_bodies(rigid_particles='C')

Create a box of rigid benzenes, where each benzene has a unique rigid
body ID.

>>> import mbuild as mb
>>> from mbuild.utils.io import get_fn
>>> benzene = mb.load(get_fn('benzene.mol2'))
>>> benzene.name = 'Benzene'
>>> filled = mb.fill_box(benzene,
... n_compounds=10,
... box=[0, 0, 0, 4, 4, 4])
>>> filled.label_rigid_bodies(distinct_bodies='Benzene')

Create a box of semi-rigid benzenes, where each benzene has a unique
rigid body ID and only the carbon portion is treated as rigid.

>>> import mbuild as mb
>>> from mbuild.utils.io import get_fn
>>> benzene = mb.load(get_fn('benzene.mol2'))
>>> benzene.name = 'Benzene'
>>> filled = mb.fill_box(benzene,
... n_compounds=10,
... box=[0, 0, 0, 4, 4, 4])
>>> filled.label_rigid_bodies(distinct_bodies='Benzene',
... rigid_particles='C')
"""
if discrete_bodies is not None:
if isinstance(discrete_bodies, str):
discrete_bodies = [discrete_bodies]
if rigid_particles is not None:
if isinstance(rigid_particles, str):
rigid_particles = [rigid_particles]

if self.root.max_rigid_id is not None:
rigid_id = self.root.max_rigid_id + 1
warn(
f"{rigid_id} rigid bodies already exist. Incrementing 'rigid_id'"
f"starting from {rigid_id}."
)
else:
rigid_id = 0

for successor in self.successors():
if discrete_bodies and successor.name not in discrete_bodies:
continue
for particle in successor.particles():
if rigid_particles and particle.name not in rigid_particles:
continue
particle.rigid_id = rigid_id
if discrete_bodies:
rigid_id += 1

def unlabel_rigid_bodies(self):
"""Remove all rigid body labels from the Compound."""
self._check_if_contains_rigid_bodies = True
for child in self.children:
child._check_if_contains_rigid_bodies = True
for particle in self.particles():
particle.rigid_id = None

def _increment_rigid_ids(self, increment):
"""Increment the rigid_id of all rigid Particles in a Compound.

Adds `increment` to the rigid_id of all Particles in `self` that
already have an integer rigid_id.
"""
for particle in self.particles():
if particle.rigid_id is not None:
particle.rigid_id += increment

def _reorder_rigid_ids(self):
"""Reorder rigid body IDs ensuring consecutiveness.

Primarily used internally to ensure consecutive rigid_ids following
removal of a Compound.
"""
max_rigid = self.max_rigid_id
unique_rigid_ids = sorted(
set([p.rigid_id for p in self.rigid_particles()])
)
n_unique_rigid = len(unique_rigid_ids)
if max_rigid and n_unique_rigid != max_rigid + 1:
missing_rigid_id = (
unique_rigid_ids[-1] * (unique_rigid_ids[-1] + 1)
) / 2 - sum(unique_rigid_ids)
for successor in self.successors():
if successor.rigid_id is not None:
if successor.rigid_id > missing_rigid_id:
successor.rigid_id -= 1
if self.rigid_id:
if self.rigid_id > missing_rigid_id:
self.rigid_id -= 1

def add(
self,
new_child,
Expand All @@ -815,7 +580,6 @@ def add(
replace=False,
inherit_periodicity=None,
inherit_box=False,
reset_rigid_ids=True,
check_box_size=True,
):
"""Add a part to the Compound.
Expand All @@ -840,11 +604,6 @@ def add(
Compound being added
inherit_box: bool, optional, default=False
Replace the box of self with the box of the Compound being added
reset_rigid_ids : bool, optional, default=True
If the Compound to be added contains rigid bodies, reset the
rigid_ids such that values remain distinct from rigid_ids
already present in `self`. Can be set to False if attempting
to add Compounds to an existing rigid body.
check_box_size : bool, optional, default=True
Checks and warns if compound box is smaller than its bounding box after adding new_child.
"""
Expand Down Expand Up @@ -894,15 +653,10 @@ def add(
self.add(
child,
label=label_list[i],
reset_rigid_ids=reset_rigid_ids,
check_box_size=False,
)
else:
self.add(
child,
reset_rigid_ids=reset_rigid_ids,
check_box_size=False,
)
self.add(child, check_box_size=False)

return

Expand All @@ -919,13 +673,6 @@ def add(
)
self._mass = 0

if new_child.contains_rigid or new_child.rigid_id is not None:
if self.contains_rigid and reset_rigid_ids:
new_child._increment_rigid_ids(increment=self.max_rigid_id + 1)
self._check_if_contains_rigid_bodies = True
if self.rigid_id is not None:
self.rigid_id = None

# Create children and labels on the first add operation
if self.children is None:
self.children = list()
Expand Down Expand Up @@ -1067,7 +814,7 @@ def _check_if_empty(child):
for particle in particles_to_remove:
_check_if_empty(particle)

# Fix rigid_ids and remove obj from bondgraph
# Remove obj from bondgraph
for removed_part in to_remove:
self._remove(removed_part)

Expand All @@ -1077,11 +824,6 @@ def _check_if_empty(child):
removed_part.parent.children.remove(removed_part)
self._remove_references(removed_part)

# Check and reorder rigid id
for _ in particles_to_remove:
if self.contains_rigid:
self.root._reorder_rigid_ids()

# Remove ghost ports
self._prune_ghost_ports()

Expand Down Expand Up @@ -1148,10 +890,7 @@ def _prune_ghost_ports(self):
self._remove_references(port)

def _remove(self, removed_part):
"""Worker for remove(). Fixes rigid IDs and removes bonds."""
if removed_part.rigid_id is not None:
for ancestor in removed_part.ancestors():
ancestor._check_if_contains_rigid_bodies = True
"""Worker for remove(). Removes bonds."""
if self.root.bond_graph.has_node(removed_part):
for neighbor in nx.neighbors(
self.root.bond_graph.copy(), removed_part
Expand Down Expand Up @@ -3600,12 +3339,7 @@ def _clone(self, clone_of=None, root_container=None):
newone._pos = deepcopy(self._pos)
newone.port_particle = deepcopy(self.port_particle)
newone._box = deepcopy(self._box)
newone._check_if_contains_rigid_bodies = deepcopy(
self._check_if_contains_rigid_bodies
)
newone._periodicity = deepcopy(self._periodicity)
newone._contains_rigid = deepcopy(self._contains_rigid)
newone._rigid_id = deepcopy(self._rigid_id)
newone._charge = deepcopy(self._charge)
newone._mass = deepcopy(self._mass)
if hasattr(self, "index"):
Expand Down
Loading
Loading