Skip to content

Commit

Permalink
Merge branch 'master' into lukas/more-optimizers
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasturcani authored Jul 31, 2024
2 parents e2d21e1 + 96ca837 commit 1b44392
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/calculators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Calculators
OrcaEnergy <_autosummary/stko.OrcaEnergy>
RmsdCalculator <_autosummary/stko.RmsdCalculator>
RmsdMappedCalculator <_autosummary/stko.RmsdMappedCalculator>
KabschRmsdCalculator <_autosummary/stko.KabschRmsdCalculator>
ShapeCalculator <_autosummary/stko.ShapeCalculator>
TorsionCalculator <_autosummary/stko.TorsionCalculator>
ConstructedMoleculeTorsionCalculator <_autosummary/stko.ConstructedMoleculeTorsionCalculator>
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies = [
"rdkit==2023.9.5", # remove pin when type issues are resolved
"stk",
"networkx",
"rmsd",
]
requires-python = ">=3.11"
dynamic = ["version"]
Expand Down Expand Up @@ -133,5 +134,6 @@ module = [
"openff.*",
"openmm.*",
"openmmforcefields.*",
"rmsd.*",
]
ignore_missing_imports = true
2 changes: 2 additions & 0 deletions src/stko/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from stko._internal.calculators.results.xtb_results import XTBResults
from stko._internal.calculators.rmsd_calculators import (
KabschRmsdCalculator,
RmsdCalculator,
RmsdMappedCalculator,
)
Expand Down Expand Up @@ -134,6 +135,7 @@
"XTBResults",
"RmsdCalculator",
"RmsdMappedCalculator",
"KabschRmsdCalculator",
"ShapeCalculator",
"OrcaResults",
"PlanarityResults",
Expand Down
59 changes: 59 additions & 0 deletions src/stko/_internal/calculators/rmsd_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import stk
from rmsd import kabsch_rmsd
from scipy.spatial.distance import cdist

from stko._internal.calculators.results.rmsd_results import RmsdResults
Expand Down Expand Up @@ -221,3 +222,61 @@ def calculate(self, mol: stk.Molecule) -> float:
)
mol = mol.with_centroid(np.array((0, 0, 0)))
return self._calculate_rmsd(mol)


class KabschRmsdCalculator:
"""Calculates the root mean square distance between molecules.
This calculator uses the rmsd package with the default settings and no
reordering.
See Also:
* rmsd https://github.com/charnley/rmsd
This calculator will only work if the two molecules are the same
and have the same atom ordering.
Parameters:
initial_molecule:
The :class:`stk.Molecule` to calculate RMSD from.
Examples:
.. code-block:: python
import stk
import stko
bb1 = stk.BuildingBlock('C1CCCCC1')
calculator = stko.KabschRmsdCalculator(bb1)
results = calculator.get_results(stk.UFF().optimize(bb1))
rmsd = results.get_rmsd()
"""

def __init__(self, initial_molecule: stk.Molecule) -> None:
self._initial_molecule = initial_molecule

def _calculate_rmsd(self, mol: stk.Molecule) -> float:
p_coord = self._initial_molecule.get_position_matrix()
q_coord = mol.get_position_matrix()
return kabsch_rmsd(p_coord, q_coord)

def calculate(self, mol: stk.Molecule) -> float:
self._initial_molecule = self._initial_molecule.with_centroid(
position=np.array((0, 0, 0)),
)
mol = mol.with_centroid(np.array((0, 0, 0)))
return self._calculate_rmsd(mol)

def get_results(self, mol: stk.Molecule) -> RmsdResults:
"""Calculate the RMSD between `mol` and the initial molecule.
Parameters:
mol:
The :class:`stk.Molecule` to calculate RMSD to.
Returns:
The RMSD between the molecules.
"""
return RmsdResults(self.calculate(mol))
26 changes: 26 additions & 0 deletions tests/calculators/rmsd/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class CaseData:
mol1: stk.Molecule
mol2: stk.Molecule
rmsd: float
kabsch_rmsd: float


_optimizer = stko.UFF()
Expand Down Expand Up @@ -43,6 +44,7 @@ class CaseData:
mol1=_cc_molecule,
mol2=_cc_molecule.with_centroid(np.array((4, 0, 0))),
rmsd=0.0,
kabsch_rmsd=0.0,
),
CaseData(
mol1=_cc_molecule,
Expand All @@ -52,6 +54,7 @@ class CaseData:
)
),
rmsd=0.0,
kabsch_rmsd=0.0,
),
CaseData(
mol1=_cc_molecule,
Expand All @@ -61,31 +64,37 @@ class CaseData:
)
),
rmsd=1.0,
kabsch_rmsd=1.0,
),
CaseData(
mol1=stk.BuildingBlock("NCCN"),
mol2=_optimizer.optimize(stk.BuildingBlock("NCCN")),
rmsd=0.24492870054279647,
kabsch_rmsd=0.188295954166067,
),
CaseData(
mol1=stk.BuildingBlock("CCCCCC"),
mol2=stk.BuildingBlock("CCCCCC"),
rmsd=0.0,
kabsch_rmsd=0.0,
),
CaseData(
mol1=stk.BuildingBlock("CCCCCC"),
mol2=_optimizer.optimize(stk.BuildingBlock("CCCCCC")),
rmsd=0.35636491354918015,
kabsch_rmsd=0.35044001253075,
),
CaseData(
mol1=stk.BuildingBlock("c1ccccc1"),
mol2=_optimizer.optimize(stk.BuildingBlock("c1ccccc1")),
rmsd=0.02936762392637932,
kabsch_rmsd=0.02936762392637932,
),
CaseData(
mol1=_polymer,
mol2=_optimizer.optimize(_polymer),
rmsd=2.1485735050384,
kabsch_rmsd=1.786251608496134,
),
],
)
Expand All @@ -101,26 +110,31 @@ def case_data(request: pytest.FixtureRequest) -> CaseData:
mol1=stk.BuildingBlock("NCCN"),
mol2=_optimizer.optimize(stk.BuildingBlock("NCCN")),
rmsd=0.20811702035676308,
kabsch_rmsd=0.20811702035676308,
),
CaseData(
mol1=stk.BuildingBlock("CCCCCC"),
mol2=_optimizer.optimize(stk.BuildingBlock("CCCCCC")),
rmsd=0.22563756374632568,
kabsch_rmsd=0.22563756374632568,
),
CaseData(
mol1=stk.BuildingBlock("CCCCCC"),
mol2=stk.BuildingBlock("CCCCCC"),
rmsd=0.0,
kabsch_rmsd=0.0,
),
CaseData(
mol1=stk.BuildingBlock("c1ccccc1"),
mol2=_optimizer.optimize(stk.BuildingBlock("c1ccccc1")),
rmsd=0.029156836455717483,
kabsch_rmsd=0.029156836455717483,
),
CaseData(
mol1=_polymer,
mol2=_optimizer.optimize(_polymer),
rmsd=1.792856412415046,
kabsch_rmsd=1.792856412415046,
),
],
)
Expand All @@ -136,11 +150,13 @@ def ignore_h_case_data(request: pytest.FixtureRequest) -> CaseData:
mol1=stk.BuildingBlock("NCCN"),
mol2=stk.BuildingBlock("CCCCCC"),
rmsd=0.0,
kabsch_rmsd=0.0,
),
CaseData(
mol1=stk.BuildingBlock("CCCCCC"),
mol2=stk.BuildingBlock("c1ccccc1"),
rmsd=0.0,
kabsch_rmsd=0.0,
),
],
)
Expand All @@ -156,6 +172,7 @@ def different_case_data(request: pytest.FixtureRequest) -> CaseData:
mol1=_polymer,
mol2=_polymer.with_canonical_atom_ordering(),
rmsd=0.0,
kabsch_rmsd=0.0,
),
CaseData(
mol1=stk.BuildingBlock(
Expand All @@ -171,6 +188,7 @@ def different_case_data(request: pytest.FixtureRequest) -> CaseData:
),
).with_canonical_atom_ordering(),
rmsd=0.0,
kabsch_rmsd=0.0,
),
],
)
Expand All @@ -186,6 +204,7 @@ def ordering_case_data(request: pytest.FixtureRequest) -> CaseData:
mol1=_cc_molecule,
mol2=_cc_molecule.with_centroid(np.array((4, 0, 0))),
rmsd=0.0,
kabsch_rmsd=0.0,
),
CaseData(
mol1=_cc_molecule,
Expand All @@ -195,6 +214,7 @@ def ordering_case_data(request: pytest.FixtureRequest) -> CaseData:
)
),
rmsd=0.0,
kabsch_rmsd=0.0,
),
CaseData(
mol1=_cc_molecule,
Expand All @@ -204,6 +224,7 @@ def ordering_case_data(request: pytest.FixtureRequest) -> CaseData:
)
),
rmsd=1.0,
kabsch_rmsd=1.0,
),
CaseData(
mol1=stk.BuildingBlock("NCCN"),
Expand All @@ -215,6 +236,7 @@ def ordering_case_data(request: pytest.FixtureRequest) -> CaseData:
)
.with_displacement(np.array((2, 0, 1))),
rmsd=1.1309858484314543,
kabsch_rmsd=1.1309858484314543,
),
CaseData(
mol1=stk.BuildingBlock("CCCCCC"),
Expand All @@ -226,21 +248,25 @@ def ordering_case_data(request: pytest.FixtureRequest) -> CaseData:
)
.with_displacement(np.array((0, 0, 1))),
rmsd=0.5943193981905652,
kabsch_rmsd=0.5943193981905652,
),
CaseData(
mol1=stk.BuildingBlock("NCCN"),
mol2=stk.BuildingBlock("NCCCN"),
rmsd=0.8832914099448816,
kabsch_rmsd=0.8832914099448816,
),
CaseData(
mol1=stk.BuildingBlock("NCOCN"),
mol2=stk.BuildingBlock("NCCN"),
rmsd=1.2678595995702466,
kabsch_rmsd=1.2678595995702466,
),
CaseData(
mol1=stk.BuildingBlock("NCCN"),
mol2=stk.BuildingBlock("NCOCN"),
rmsd=1.3921770318522637,
kabsch_rmsd=1.3921770318522637,
),
],
)
Expand Down
7 changes: 7 additions & 0 deletions tests/calculators/rmsd/test_rmsd_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ def test_rmsd(case_data: CaseData) -> None:
assert np.isclose(test_rmsd, case_data.rmsd, atol=1e-4)


def test_kabsch_rmsd(case_data: CaseData) -> None:
calculator = stko.KabschRmsdCalculator(case_data.mol1)
results = calculator.get_results(case_data.mol2)
test_rmsd = results.get_rmsd()
assert np.isclose(test_rmsd, case_data.kabsch_rmsd, atol=1e-4)


def test_rmsd_ignore_hydrogens(ignore_h_case_data: CaseData) -> None:
calculator = stko.RmsdCalculator(
ignore_h_case_data.mol1, ignore_hydrogens=True
Expand Down

0 comments on commit 1b44392

Please sign in to comment.