diff --git a/gmso/__init__.py b/gmso/__init__.py index 9a5985491..5bda2c48e 100644 --- a/gmso/__init__.py +++ b/gmso/__init__.py @@ -6,11 +6,11 @@ from .core.bond import Bond from .core.bond_type import BondType from .core.box import Box -from .core.dihedral import Dihedral +from .core.dihedral import Dihedral, LayeredDihedral from .core.dihedral_type import DihedralType from .core.element import Element from .core.forcefield import ForceField -from .core.improper import Improper +from .core.improper import Improper, LayeredImproper from .core.improper_type import ImproperType from .core.pairpotential_type import PairPotentialType from .core.subtopology import SubTopology diff --git a/gmso/abc/abstract_connection.py b/gmso/abc/abstract_connection.py index ef8bd4e06..173e35139 100644 --- a/gmso/abc/abstract_connection.py +++ b/gmso/abc/abstract_connection.py @@ -67,7 +67,7 @@ def _get_members_types_or_classes(self, to_return): @root_validator(pre=True) def validate_fields(cls, values): - connection_members = values.get("connection_members") + connection_members = values.get("connection_members", []) if all(isinstance(member, dict) for member in connection_members): connection_members = [ diff --git a/gmso/abc/gmso_base.py b/gmso/abc/gmso_base.py index 957e704da..96d39238b 100644 --- a/gmso/abc/gmso_base.py +++ b/gmso/abc/gmso_base.py @@ -120,3 +120,4 @@ class Config: extra = "forbid" json_encoders = GMSOJSONHandler.json_encoders allow_population_by_field_name = True + validate_assignment = True diff --git a/gmso/core/dihedral.py b/gmso/core/dihedral.py index 6d0910652..7e470e28f 100644 --- a/gmso/core/dihedral.py +++ b/gmso/core/dihedral.py @@ -1,17 +1,21 @@ -from typing import Callable, ClassVar, Optional, Tuple +from typing import Callable, ClassVar, Iterable, Optional, Tuple -from pydantic import Field +from boltons.setutils import IndexedSet +from pydantic import Field, ValidationError, validator from gmso.abc.abstract_connection import Connection from gmso.core.atom import Atom from gmso.core.dihedral_type import DihedralType +from gmso.utils.misc import validate_type -class Dihedral(Connection): +class BaseDihedral(Connection): __base_doc__ = """A 4-partner connection between sites. - This is a subclass of the gmso.Connection superclass. - This class has strictly 4 members in its connection_members. + This is a subclass of the gmso.Connection superclass. This class + has strictly 4 members in its connection_members and used as + a base class to define many different forms of a Dihedral. + The connection_type in this class corresponds to gmso.DihedralType. The connectivity of a dihedral is: m1–m2–m3–m4 @@ -21,7 +25,7 @@ class Dihedral(Connection): Notes ----- Inherits some methods from Connection: - __eq__, __repr__, _validate methods + __eq__, _validate methods Additional _validate methods are presented """ @@ -31,19 +35,6 @@ class Dihedral(Connection): ..., description="The 4 atoms involved in the dihedral." ) - dihedral_type_: Optional[DihedralType] = Field( - default=None, description="DihedralType of this dihedral." - ) - - @property - def dihedral_type(self): - return self.__dict__.get("dihedral_type_") - - @property - def connection_type(self): - # ToDo: Deprecate this? - return self.__dict__.get("dihedral_type_") - def equivalent_members(self): """Get a set of the equivalent connection member tuples @@ -96,18 +87,102 @@ def _equivalent_members_hash(self): ) ) + def is_layered(self): + return hasattr(self, "dihedral_types_") + + def __repr__(self): + return ( + f"<{self.__class__.__name__} {self.name},\n " + f"connection_members: {self.connection_members},\n " + f"potential: {str(self.dihedral_types if self.is_layered() else self.dihedral_type)},\n " + f"id: {id(self)}>" + ) + + class Config: + fields = { + "connection_members_": "connection_members", + } + alias_to_fields = { + "connection_members": "connection_members_", + } + + +class Dihedral(BaseDihedral): + __base_doc__ = """A 4-Partner connection between 4 sites with a **single** dihedral type association + + Notes + ----- + This class inherits from BaseDihedral. + """ + dihedral_type_: Optional[DihedralType] = Field( + default=None, description="DihedralType of this dihedral." + ) + + @property + def dihedral_type(self): + return self.__dict__.get("dihedral_type_") + + @property + def connection_type(self): + # ToDo: Deprecate this? + return self.__dict__.get("dihedral_type_") + def __setattr__(self, key, value): if key == "connection_type": - super(Dihedral, self).__setattr__("dihedral_type", value) + super().__setattr__("dihedral_type", value) else: - super(Dihedral, self).__setattr__(key, value) + super().__setattr__(key, value) class Config: fields = { "dihedral_type_": "dihedral_type", - "connection_members_": "connection_members", } alias_to_fields = { "dihedral_type": "dihedral_type_", - "connection_members": "connection_members_", + } + + +class LayeredDihedral(BaseDihedral): + __base_doc__ = """A 4-Partner connection between 4 sites with **multiple** dihedral type associations + + Notes + ----- + This class inherits from BaseDihedral. + """ + + dihedral_types_: Optional[IndexedSet] = Field( + default=None, description="DihedralTypes of this dihedral." + ) + + @property + def dihedral_types(self): + return self.__dict__.get("dihedral_types_") + + @property + def connection_types(self): + # ToDo: Deprecate this? + return self.__dict__.get("dihedral_types_") + + def __setattr__(self, key, value): + if key == "connection_types": + super().__setattr__("dihedral_types", value) + else: + super().__setattr__(key, value) + + @validator("dihedral_types_", pre=True) + def validate_dihedral_types(cls, dihedral_types): + if not isinstance(dihedral_types, Iterable) or isinstance( + dihedral_types, str + ): + raise ValidationError("DihedralTypes should be iterable", cls) + + validate_type(dihedral_types, DihedralType) + return IndexedSet(dihedral_types) + + class Config: + fields = { + "dihedral_types_": "dihedral_types", + } + alias_to_fields = { + "dihedral_types": "dihedral_types_", } diff --git a/gmso/core/improper.py b/gmso/core/improper.py index d27d77617..86dc3e0ce 100644 --- a/gmso/core/improper.py +++ b/gmso/core/improper.py @@ -1,15 +1,17 @@ """Support for improper style connections (4-member connection).""" -from typing import Callable, ClassVar, Optional, Tuple +from typing import Callable, ClassVar, Iterable, Optional, Tuple -from pydantic import Field +from boltons.setutils import IndexedSet +from pydantic import Field, ValidationError, validator from gmso.abc.abstract_connection import Connection from gmso.core.atom import Atom from gmso.core.improper_type import ImproperType +from gmso.utils.misc import validate_type -class Improper(Connection): - __base_doc__ = """sA 4-partner connection between sites. +class BaseImproper(Connection): + __base_doc__ = """A 4-partner connection between sites. This is a subclass of the gmso.Connection superclass. This class has strictly 4 members in its connection_members. @@ -39,21 +41,6 @@ class Improper(Connection): "then the three atoms connected to the central site.", ) - improper_type_: Optional[ImproperType] = Field( - default=None, description="ImproperType of this improper." - ) - - @property - def improper_type(self): - """Return Potential object for this connection if it exists.""" - return self.__dict__.get("improper_type_") - - @property - def connection_type(self): - """Return Potential object for this connection if it exists.""" - # ToDo: Deprecate this? - return self.__dict__.get("improper_type_") - def equivalent_members(self): """Get a set of the equivalent connection member tuples. @@ -106,21 +93,83 @@ def _equivalent_members_hash(self): ) ) + class Config: + """Pydantic configuration to link fields to their public attribute.""" + + fields = { + "connection_members_": "connection_members", + } + alias_to_fields = { + "connection_members": "connection_members_", + } + + +class Improper(BaseImproper): + improper_type_: Optional[ImproperType] = Field( + default=None, description="ImproperType of this improper." + ) + + @property + def improper_type(self): + """Return Potential object for this connection if it exists.""" + return self.__dict__.get("improper_type_") + + @property + def connection_type(self): + """Return Potential object for this connection if it exists.""" + # ToDo: Deprecate this? + return self.__dict__.get("improper_type_") + def __setattr__(self, key, value): """Set attribute override to support connection_type key.""" if key == "connection_type": - super(Improper, self).__setattr__("improper_type", value) + super().__setattr__("improper_type", value) else: - super(Improper, self).__setattr__(key, value) + super().__setattr__(key, value) class Config: """Pydantic configuration to link fields to their public attribute.""" fields = { "improper_type_": "improper_type", - "connection_members_": "connection_members", } alias_to_fields = { "improper_type": "improper_type_", - "connection_members": "connection_members_", + } + + +class LayeredImproper(BaseImproper): + improper_types_: Optional[IndexedSet] = Field( + default=None, description="ImproperTypes of this improper." + ) + + @property + def improper_types(self): + return self.__dict__.get("improper_types_") + + @property + def connection_types(self): + # ToDo: Deprecate this? + return self.__dict__.get("improper_types_") + + def __setattr__(self, key, value): + if key == "connection_types": + super().__setattr__("improper_types_", value) + else: + super().__setattr__(key, value) + + @validator("improper_types_", pre=True, always=True) + def validate_improper_types(cls, improper_types): + if not isinstance(improper_types, Iterable): + raise ValidationError("ImproperTypes should be iterable", cls) + + validate_type(improper_types, ImproperType) + return IndexedSet(improper_types) + + class Config: + fields = { + "improper_types_": "improper_types", + } + alias_to_fields = { + "improper_types": "improper_types_", } diff --git a/gmso/core/topology.py b/gmso/core/topology.py index 37bca1241..cc00c9bfe 100644 --- a/gmso/core/topology.py +++ b/gmso/core/topology.py @@ -13,9 +13,9 @@ from gmso.core.atom_type import AtomType from gmso.core.bond import Bond from gmso.core.bond_type import BondType -from gmso.core.dihedral import Dihedral +from gmso.core.dihedral import BaseDihedral, Dihedral from gmso.core.dihedral_type import DihedralType -from gmso.core.improper import Improper +from gmso.core.improper import BaseImproper, Improper from gmso.core.improper_type import ImproperType from gmso.core.pairpotential_type import PairPotentialType from gmso.core.parametric_potential import ParametricPotential @@ -520,9 +520,9 @@ def add_connection(self, connection, update_types=True): self._bonds.add(connection) if isinstance(connection, Angle): self._angles.add(connection) - if isinstance(connection, Dihedral): + if isinstance(connection, BaseDihedral): self._dihedrals.add(connection) - if isinstance(connection, Improper): + if isinstance(connection, BaseImproper): self._impropers.add(connection) if update_types: self.update_connection_types() @@ -544,48 +544,78 @@ def update_connection_types(self): -------- gmso.Topology.update_atom_types : Update atom types in the topology. """ + # Here an alternative could be using checking instances of LayeredConnection classes + # But to make it generic, this approach works best + get_connection_type = ( + lambda conn: conn.connection_types + if hasattr(conn, "connection_types") + else conn.connection_type + ) for c in self.connections: - if c.connection_type is None: + connection_type_or_types = get_connection_type(c) + if connection_type_or_types is None: warnings.warn( "Non-parametrized Connection {} detected".format(c) ) - elif not isinstance(c.connection_type, ParametricPotential): + continue + elif not isinstance( + connection_type_or_types, (ParametricPotential, IndexedSet) + ): raise GMSOError( "Non-Potential {} found" - "in Connection {}".format(c.connection_type, c) + "in Connection {}".format(connection_type_or_types, c) ) - elif c.connection_type not in self._connection_types: - c.connection_type.topology = self - self._connection_types[c.connection_type] = c.connection_type - if isinstance(c.connection_type, BondType): - self._bond_types[c.connection_type] = c.connection_type - self._bond_types_idx[c.connection_type] = ( - len(self._bond_types) - 1 - ) - if isinstance(c.connection_type, AngleType): - self._angle_types[c.connection_type] = c.connection_type - self._angle_types_idx[c.connection_type] = ( - len(self._angle_types) - 1 - ) - if isinstance(c.connection_type, DihedralType): - self._dihedral_types[c.connection_type] = c.connection_type - self._dihedral_types_idx[c.connection_type] = ( - len(self._dihedral_types) - 1 - ) - if isinstance(c.connection_type, ImproperType): - self._improper_types[c.connection_type] = c.connection_type - self._improper_types_idx[c.connection_type] = ( - len(self._improper_types) - 1 - ) - elif c.connection_type in self.connection_types: - if isinstance(c.connection_type, BondType): - c.connection_type = self._bond_types[c.connection_type] - if isinstance(c.connection_type, AngleType): - c.connection_type = self._angle_types[c.connection_type] - if isinstance(c.connection_type, DihedralType): - c.connection_type = self._dihedral_types[c.connection_type] - if isinstance(c.connection_type, ImproperType): - c.connection_type = self._improper_types[c.connection_type] + if not isinstance(connection_type_or_types, IndexedSet): + connection_type_or_types = [connection_type_or_types] + for connection_type in connection_type_or_types: + if connection_type not in self._connection_types: + connection_type.topology = self + self._connection_types[connection_type] = connection_type + if isinstance(connection_type, BondType): + self._bond_types[connection_type] = connection_type + self._bond_types_idx[connection_type] = ( + len(self._bond_types) - 1 + ) + if isinstance(connection_type, AngleType): + self._angle_types[connection_type] = connection_type + self._angle_types_idx[connection_type] = ( + len(self._angle_types) - 1 + ) + if isinstance(connection_type, DihedralType): + self._dihedral_types[connection_type] = connection_type + self._dihedral_types_idx[connection_type] = ( + len(self._dihedral_types) - 1 + ) + if isinstance(connection_type, ImproperType): + self._improper_types[ + connection_type + ] = c.connection_type + self._improper_types_idx[connection_type] = ( + len(self._improper_types) - 1 + ) + elif connection_type in self.connection_types: + if isinstance(connection_type, BondType): + c.connection_type = self._bond_types[connection_type] + if isinstance(connection_type, AngleType): + c.connection_type = self._angle_types[connection_type] + if isinstance(connection_type, DihedralType): + if c.is_layered(): + c.connection_types.add( + self._dihedral_types[connection_type] + ) + else: + c.connection_type = self._dihedral_types[ + connection_type + ] + if isinstance(connection_type, ImproperType): + if c.is_layered(): + c.connection_types.add( + self._improper_types[connection_type] + ) + else: + c.connection_type = self._improper_types[ + connection_type + ] def add_pairpotentialtype(self, pairpotentialtype, update=True): """add a PairPotentialType to the topology @@ -746,10 +776,10 @@ def is_fully_typed(self, updated=False, group="topology"): angle.angle_type for angle in top._angles ), "dihedrals": lambda top: all( - dihedral.dihedral_type for dihedral in top._dihedrals + self._get_types(dihedral) for dihedral in top._dihedrals ), "impropers": lambda top: all( - improper.improper_type for improper in top._impropers + self._get_types(improper) for improper in top._impropers ), } @@ -833,7 +863,7 @@ def _get_untyped_dihedrals(self): "Return a list of untyped dihedrals" untyped = {"dihedrals": list()} for dihedral in self._dihedrals: - if not dihedral.dihedral_type: + if not self._get_types(dihedral): untyped["dihedrals"].append(dihedral) return untyped @@ -841,10 +871,23 @@ def _get_untyped_impropers(self): "Return a list of untyped impropers" untyped = {"impropers": list()} for improper in self._impropers: - if not improper.improper_type: + if not self._get_types(improper): untyped["impropers"].append(improper) return untyped + def _get_types(self, dihedral_or_improper): + """Get the dihedral/impropertypes for a dihedral/improper in this topology.""" + if not isinstance(dihedral_or_improper, (BaseDihedral, BaseImproper)): + raise TypeError( + f"Expected `dihedral_or_improper` to be either Dihedral or Improper. " + f"Got {type(dihedral_or_improper).__name__} instead." + ) + return ( + dihedral_or_improper.connection_types + if dihedral_or_improper.is_layered() + else dihedral_or_improper.connection_type + ) + def update_angle_types(self): """Use gmso.Topology.update_connection_types to update AngleTypes in the topology. diff --git a/gmso/tests/base_test.py b/gmso/tests/base_test.py index 68faa73e6..4b0562958 100644 --- a/gmso/tests/base_test.py +++ b/gmso/tests/base_test.py @@ -10,7 +10,8 @@ from gmso.core.atom_type import AtomType from gmso.core.bond import Bond from gmso.core.box import Box -from gmso.core.dihedral import Dihedral +from gmso.core.dihedral import Dihedral, LayeredDihedral +from gmso.core.dihedral_type import DihedralType from gmso.core.element import Hydrogen, Oxygen from gmso.core.forcefield import ForceField from gmso.core.improper import Improper @@ -18,6 +19,7 @@ from gmso.core.topology import Topology from gmso.external import from_mbuild, from_parmed from gmso.external.convert_foyer_xml import from_foyer_xml +from gmso.lib.potential_templates import PotentialTemplateLibrary from gmso.tests.utils import get_path from gmso.utils.io import get_fn, has_foyer @@ -501,6 +503,52 @@ def residue_top(self): return top + @pytest.fixture + def ld_top(self): + lib = PotentialTemplateLibrary() + rb_type = DihedralType.from_template( + potential_template=lib["RyckaertBellemansTorsionPotential"], + parameters={ + "c0": 9.28 * u.kJ / u.mol, + "c1": 12.16 * u.kJ / u.mol, + "c2": -13.12 * u.kJ / u.mol, + "c3": -3.06 * u.kJ / u.mol, + "c4": 26.24 * u.kJ / u.mol, + "c5": -31.5 * u.kJ / u.mol, + }, + ) + + periodic_type = DihedralType.from_template( + lib["PeriodicTorsionPotential"], + parameters={ + "k": 1.25 * u.nm, + "phi_eq": 3.14159 * u.rad, + "n": 2 * u.dimensionless, + }, + ) + + top = Topology(name="Topology") + atoms = [Atom(name=f"Atom{i + 1}") for i in range(0, 100)] + dihedrals_group = [ + (atoms[i], atoms[i + 1], atoms[i + 2], atoms[i + 3]) + for i in range(0, 100, 4) + ] + + for j, atom_groups in enumerate(dihedrals_group): + if j % 2 == 0: + dh = Dihedral( + connection_members=atom_groups, + ) + dh.connection_type = rb_type + + else: + dh = LayeredDihedral(connection_members=atom_groups) + dh.connection_types = [rb_type, periodic_type] + top.add_connection(connection=dh) + top.update_topology() + + return top + @pytest.fixture(scope="session") def pentane_ua_mbuild(self): class PentaneUA(mb.Compound): diff --git a/gmso/tests/test_dihedral.py b/gmso/tests/test_dihedrals.py similarity index 53% rename from gmso/tests/test_dihedral.py rename to gmso/tests/test_dihedrals.py index b2aefb721..fc4c4b3d7 100644 --- a/gmso/tests/test_dihedral.py +++ b/gmso/tests/test_dihedrals.py @@ -1,15 +1,16 @@ import pytest +import unyt as u from pydantic import ValidationError from gmso.core.atom import Atom from gmso.core.atom_type import AtomType -from gmso.core.dihedral import Dihedral +from gmso.core.dihedral import Dihedral, LayeredDihedral from gmso.core.dihedral_type import DihedralType from gmso.core.topology import Topology from gmso.tests.base_test import BaseTest -class TestDihedral(BaseTest): +class TestDihedrals(BaseTest): def test_dihedral_nonparametrized(self): atom1 = Atom(name="atom1") atom2 = Atom(name="atom1") @@ -153,3 +154,134 @@ def test_equivalent_members_set(self): tuple(dihedral.connection_members) in dihedral_not_eq.equivalent_members() ) + + def test_layered_dihedrals(self): + atom1 = Atom(name="atom1") + atom2 = Atom(name="atom2") + atom3 = Atom(name="atom3") + atom4 = Atom(name="atom4") + + dihedral_type1 = DihedralType( + name=f"layer1", + expression="kn * (1 + cos(n * a - a0))", + independent_variables="a", + parameters={ + "kn": 1.0 * u.K * u.kb, + "n": 1 * u.dimensionless, + "a0": 30.0 * u.degree, + }, + ) + dihedral_type2 = DihedralType( + name=f"layer2", + expression="kn * (1 + cos(n * a - a0))", + independent_variables="a", + parameters={ + "kn": 1.0 * u.K * u.kb, + "n": 2 * u.dimensionless, + "a0": 30.0 * u.degree, + }, + ) + dihedral_type3 = DihedralType( + name=f"layer3", + expression="kn * (1 + cos(n * a - a0))", + independent_variables="a", + parameters={ + "kn": 1.0 * u.K * u.kb, + "n": 3 * u.dimensionless, + "a0": 30.0 * u.degree, + }, + ) + + connect = LayeredDihedral( + connection_members=[atom1, atom2, atom3, atom4], + dihedral_types=[dihedral_type1, dihedral_type2, dihedral_type3], + name="dihedral_name", + ) + + assert dihedral_type1 in connect.dihedral_types + assert dihedral_type2 in connect.dihedral_types + assert dihedral_type3 in connect.dihedral_types + + assert connect.dihedral_types[0].parameters["n"] == 1 + assert connect.dihedral_types[1].parameters["n"] == 2 + assert connect.dihedral_types[2].parameters["n"] == 3 + + def test_layered_dihedral_duplicate(self): + atom1 = Atom(name="atom1") + atom2 = Atom(name="atom2") + atom3 = Atom(name="atom3") + atom4 = Atom(name="atom4") + + dihedral_type1 = DihedralType( + name=f"layer1", + expression="kn * (1 + cos(n * a - a0))", + independent_variables="a", + parameters={ + "kn": 1.0 * u.K * u.kb, + "n": 1 * u.dimensionless, + "a0": 30.0 * u.degree, + }, + ) + dihedral_type2 = DihedralType( + name=f"layer2", + expression="kn * (1 + cos(n * a - a0))", + independent_variables="a", + parameters={ + "kn": 1.0 * u.K * u.kb, + "n": 2 * u.dimensionless, + "a0": 30.0 * u.degree, + }, + ) + dihedral_type3 = DihedralType( + name=f"layer3", + expression="kn * (1 + cos(n * a - a0))", + independent_variables="a", + parameters={ + "kn": 1.0 * u.K * u.kb, + "n": 3 * u.dimensionless, + "a0": 30.0 * u.degree, + }, + ) + + connect = LayeredDihedral( + connection_members=[atom1, atom2, atom3, atom4], + dihedral_types=[ + dihedral_type1, + dihedral_type2, + dihedral_type3, + dihedral_type3, + ], + name="dihedral_name", + ) + + assert len(connect.dihedral_types) == 3 + + def test_layered_dihedral_validation_error(self): + atom1 = Atom(name="atom1") + atom2 = Atom(name="atom2") + atom3 = Atom(name="atom3") + atom4 = Atom(name="atom4") + + dihedral_type = DihedralType( + name=f"layer3", + expression="kn * (1 + cos(n * a - a0))", + independent_variables="a", + parameters={ + "kn": 1.0 * u.K * u.kb, + "n": 3 * u.dimensionless, + "a0": 30.0 * u.degree, + }, + ) + with pytest.raises(ValidationError): + LayeredDihedral( + connection_members=[atom1, atom2, atom3, atom4], + dihedral_types_=["a1", "a2", "a3", "a4"], + name="dh1", + ) + + with pytest.raises(ValidationError): + LayeredDihedral( + connection_members=[atom1, atom2, atom3, atom4], + dihedral_types=dihedral_type, + name="dh1", + ) diff --git a/gmso/tests/test_improper.py b/gmso/tests/test_impropers.py similarity index 53% rename from gmso/tests/test_improper.py rename to gmso/tests/test_impropers.py index a853391d5..900535e4e 100644 --- a/gmso/tests/test_improper.py +++ b/gmso/tests/test_impropers.py @@ -1,15 +1,16 @@ import pytest +import unyt as u from pydantic import ValidationError from gmso.core.atom import Atom from gmso.core.atom_type import AtomType -from gmso.core.improper import Improper +from gmso.core.improper import Improper, LayeredImproper from gmso.core.improper_type import ImproperType from gmso.core.topology import Topology from gmso.tests.base_test import BaseTest -class TestImproper(BaseTest): +class TestImpropers(BaseTest): def test_improper_nonparametrized(self): atom1 = Atom(name="atom1") atom2 = Atom(name="atom2") @@ -152,3 +153,134 @@ def test_equivalent_members_set(self): tuple(improper.connection_members) in improper_not_eq.equivalent_members() ) + + def test_layered_impropers(self): + atom1 = Atom(name="atom1") + atom2 = Atom(name="atom2") + atom3 = Atom(name="atom3") + atom4 = Atom(name="atom4") + + improper_type1 = ImproperType( + name=f"layer1", + expression="kn * (1 + cos(n * a - a0))", + independent_variables="a", + parameters={ + "kn": 1.0 * u.K * u.kb, + "n": 1 * u.dimensionless, + "a0": 30.0 * u.degree, + }, + ) + improper_type2 = ImproperType( + name=f"layer2", + expression="kn * (1 + cos(n * a - a0))", + independent_variables="a", + parameters={ + "kn": 1.0 * u.K * u.kb, + "n": 2 * u.dimensionless, + "a0": 30.0 * u.degree, + }, + ) + improper_type3 = ImproperType( + name=f"layer3", + expression="kn * (1 + cos(n * a - a0))", + independent_variables="a", + parameters={ + "kn": 1.0 * u.K * u.kb, + "n": 3 * u.dimensionless, + "a0": 30.0 * u.degree, + }, + ) + + connect = LayeredImproper( + connection_members=[atom1, atom2, atom3, atom4], + improper_types=[improper_type1, improper_type2, improper_type3], + name="improper_name", + ) + + assert improper_type1 in connect.improper_types + assert improper_type2 in connect.improper_types + assert improper_type3 in connect.improper_types + + assert connect.improper_types[0].parameters["n"] == 1 + assert connect.improper_types[1].parameters["n"] == 2 + assert connect.improper_types[2].parameters["n"] == 3 + + def test_layered_improper_duplicate(self): + atom1 = Atom(name="atom1") + atom2 = Atom(name="atom2") + atom3 = Atom(name="atom3") + atom4 = Atom(name="atom4") + + improper_type1 = ImproperType( + name=f"layer1", + expression="kn * (1 + cos(n * a - a0))", + independent_variables="a", + parameters={ + "kn": 1.0 * u.K * u.kb, + "n": 1 * u.dimensionless, + "a0": 30.0 * u.degree, + }, + ) + improper_type2 = ImproperType( + name=f"layer2", + expression="kn * (1 + cos(n * a - a0))", + independent_variables="a", + parameters={ + "kn": 1.0 * u.K * u.kb, + "n": 2 * u.dimensionless, + "a0": 30.0 * u.degree, + }, + ) + improper_type3 = ImproperType( + name=f"layer3", + expression="kn * (1 + cos(n * a - a0))", + independent_variables="a", + parameters={ + "kn": 1.0 * u.K * u.kb, + "n": 3 * u.dimensionless, + "a0": 30.0 * u.degree, + }, + ) + + connect = LayeredImproper( + connection_members=[atom1, atom2, atom3, atom4], + improper_types=[ + improper_type1, + improper_type2, + improper_type3, + improper_type3, + ], + name="improper_name", + ) + + assert len(connect.improper_types) == 3 + + def test_layered_improper_validation_error(self): + atom1 = Atom(name="atom1") + atom2 = Atom(name="atom2") + atom3 = Atom(name="atom3") + atom4 = Atom(name="atom4") + + improper_type = ImproperType( + name=f"layer3", + expression="kn * (1 + cos(n * a - a0))", + independent_variables="a", + parameters={ + "kn": 1.0 * u.K * u.kb, + "n": 3 * u.dimensionless, + "a0": 30.0 * u.degree, + }, + ) + with pytest.raises(ValidationError): + LayeredImproper( + connection_members=[atom1, atom2, atom3, atom4], + improper_types_=["a1", "a2", "a3", "a4"], + name="dh1", + ) + + with pytest.raises(ValidationError): + LayeredImproper( + connection_members=[atom1, atom2, atom3, atom4], + improper_types=improper_type, + name="dh1", + ) diff --git a/gmso/tests/test_top.py b/gmso/tests/test_top.py index 518b7be72..389edbab5 100644 --- a/gmso/tests/test_top.py +++ b/gmso/tests/test_top.py @@ -74,7 +74,7 @@ def test_water_top(self, water_system): top.save("water.top") def test_ethane_periodic(self, typed_ethane): - from gmso.core.parametric_potential import ParametricPotential + from gmso.core.dihedral_type import DihedralType from gmso.lib.potential_templates import PotentialTemplateLibrary per_torsion = PotentialTemplateLibrary()["PeriodicTorsionPotential"] @@ -83,7 +83,7 @@ def test_ethane_periodic(self, typed_ethane): "phi_eq": 15 * u.Unit("degree"), "n": 3 * u.Unit("dimensionless"), } - periodic_dihedral_type = ParametricPotential.from_template( + periodic_dihedral_type = DihedralType.from_template( potential_template=per_torsion, parameters=params ) for dihedral in typed_ethane.dihedrals: diff --git a/gmso/tests/test_topology.py b/gmso/tests/test_topology.py index ceeb560e7..c592bf3d4 100644 --- a/gmso/tests/test_topology.py +++ b/gmso/tests/test_topology.py @@ -12,7 +12,7 @@ from gmso.core.bond import Bond from gmso.core.bond_type import BondType from gmso.core.box import Box -from gmso.core.dihedral import Dihedral +from gmso.core.dihedral import Dihedral, LayeredDihedral from gmso.core.dihedral_type import DihedralType from gmso.core.improper import Improper from gmso.core.improper_type import ImproperType @@ -760,6 +760,15 @@ def test_cget_untyped(self, typed_chloroethanol): with pytest.raises(ValueError): clone.get_untyped(group="foo") + def test_top_with_layered_dihedrals(self, ld_top): + assert len(ld_top.dihedral_types) == 2 + assert len(ld_top.dihedrals) == 25 + for j in range(ld_top.n_dihedrals): + if j % 2 == 0: + assert isinstance(ld_top._dihedrals[j], Dihedral) + else: + assert isinstance(ld_top._dihedrals[j], LayeredDihedral) + def test_iter_sites(self, residue_top): for site in residue_top.iter_sites("residue_name", "MY_RES_EVEN"): assert site.residue_name == "MY_RES_EVEN"