From aeadfd5d5ef39476397edb633135517df8148338 Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Sun, 2 Oct 2022 18:48:25 -0400 Subject: [PATCH 1/4] Initial attempt to add stress to MACECalculator. Changes (mainly fixing bugs in testing for compute_virials vs. compute_stress) are necessary, but not sufficient, since it's not working yet. Test is only an outline. --- mace/calculators/mace.py | 17 ++++--- mace/modules/models.py | 4 +- mace/modules/utils.py | 6 +-- tests/test_calculator.py | 95 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 112 insertions(+), 10 deletions(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 7f90e705..aba7d043 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -15,7 +15,7 @@ class MACECalculator(Calculator): """MACE ASE Calculator""" - implemented_properties = ["energy", "forces"] + implemented_properties = ["energy", "forces", "stress"] def __init__( self, @@ -67,15 +67,18 @@ def calculate(self, atoms=None, properties=["energy"], system_changes=all_change batch = next(iter(data_loader)).to(self.device) # predict + extract data - out = self.model(batch) - forces = out["forces"].detach().cpu().numpy() + out = self.model(batch, compute_stress=True) energy = out["energy"].detach().cpu().item() + forces = out["forces"].detach().cpu().numpy() + stress = out["stress"].detach().cpu().numpy() # store results self.results = { "energy": energy * self.energy_units_to_eV, # force has units eng / len: "forces": forces * (self.energy_units_to_eV / self.length_units_to_A), + # force has units eng / len^3: + "stress": stress * (self.energy_units_to_eV / self.length_units_to_A**3), } @@ -154,6 +157,7 @@ class EnergyDipoleMACECalculator(Calculator): implemented_properties = [ "energy", "forces", + "stress", "dipole", ] @@ -212,9 +216,10 @@ def calculate(self, atoms=None, properties=["dipole"], system_changes=all_change batch = next(iter(data_loader)).to(self.device) # predict + extract data - out = self.model(batch) - forces = out["forces"].detach().cpu().numpy() + out = self.model(batch, compute_stress=True) energy = out["energy"].detach().cpu().item() + forces = out["forces"].detach().cpu().numpy() + stress = out["stress"].detach().cpu().numpy() dipole = out["dipole"].detach().cpu().numpy() # store results @@ -222,5 +227,7 @@ def calculate(self, atoms=None, properties=["dipole"], system_changes=all_change "energy": energy * self.energy_units_to_eV, # force has units eng / len: "forces": forces * (self.energy_units_to_eV / self.length_units_to_A), + # stress has units eng / len: + "stress": stress * (self.energy_units_to_eV / self.length_units_to_A**3), "dipole": dipole, } diff --git a/mace/modules/models.py b/mace/modules/models.py index cb05a770..9c99a137 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -153,7 +153,7 @@ def forward( # Setup data.positions.requires_grad = True displacement = None - if compute_virials: + if compute_virials or compute_stress: data.positions, data.shifts, displacement = get_symmetric_displacement( positions=data.positions, unit_shifts=data.unit_shifts, @@ -245,7 +245,7 @@ def forward( # Setup data.positions.requires_grad = True displacement = None - if compute_virials: + if compute_virials or compute_stress: data.positions, data.shifts, displacement = get_symmetric_displacement( positions=data.positions, unit_shifts=data.unit_shifts, diff --git a/mace/modules/utils.py b/mace/modules/utils.py index 46bf3006..d1026314 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -130,7 +130,8 @@ def get_outputs( compute_virials: bool = True, compute_stress: bool = True, ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: - if compute_force and compute_virials: + if compute_virials or compute_stress: + # forces come for free forces, virials, stress = compute_forces_virials( energy=energy, positions=positions, @@ -139,13 +140,12 @@ def get_outputs( compute_stress=compute_stress, training=training, ) - elif compute_force and not compute_stress: + elif compute_force: forces, virials, stress = ( compute_forces(energy=energy, positions=positions, training=training), None, None, ) - stress = None else: forces, virials, stress = (None, None, None) return forces, virials, stress diff --git a/tests/test_calculator.py b/tests/test_calculator.py index e69de29b..6b2afc28 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -0,0 +1,95 @@ +import sys +import os +import subprocess + +import pytest + +from pathlib import Path +pytest_mace_dir = Path(__file__).parent.parent +run_train = Path(__file__).parent.parent / "scripts" / "run_train.py" + +import numpy as np + +import ase.io +from ase.atoms import Atoms +from ase.constraints import ExpCellFilter +from ase.calculators.test import gradient_test + +from mace.calculators.mace import MACECalculator + +water = Atoms(numbers=[8, 1, 1], positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], cell=[4]*3, pbc = [True]*3) +fitting_configs = [Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6]*3), + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6]*3)] +fitting_configs[0].info["REF_energy"] = 0.0 +fitting_configs[0].info["config_type"] = "IsolatedAtom" +fitting_configs[1].info["REF_energy"] = 0.0 +fitting_configs[1].info["config_type"] = "IsolatedAtom" + +np.random.seed(5) +for _ in range(20): + c = water.copy() + c.positions += np.random.normal(0.1, size=c.positions.shape) + c.info["REF_energy"] = np.random.normal(0.1) + c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) + c.info["REF_stress"] = np.random.normal(0.1, size=6) + fitting_configs.append(c) + +@pytest.fixture +def trained_model(tmp_path): + _mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "128x0e", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress" + } + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ + run_env["PYTHONPATH"] = str(pytest_mace_dir) + ":" + os.environ["PYTHONPATH"] + + cmd = sys.executable + " " + str(run_train) + " " + " ".join([(f"--{k}={v}" if v is not None else f"--{k}") for k, v in mace_params.items()]) + + p = subprocess.run(cmd.split(), env=run_env) + + assert p.returncode == 0 + + return MACECalculator(tmp_path / "MACE.model", device="cpu") + + +def test_calculator(trained_model): + at = fitting_configs[0] + at.calc = trained_model + print("BOB", at.get_potential_energy()) + print("BOB", at.get_forces()) + print("BOB", at.get_stress()) + + # at_wrapped = ExpCellFilter(at) + # grad_qual = gradient_test(at_wrapped) + grad_qual = gradient_test(at) + + print("BOB", grad_qual) From 1d1b8a3e77626e61529a40a9f222980c02b3c8c4 Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Sun, 2 Oct 2022 21:07:12 -0400 Subject: [PATCH 2/4] MACECalculator now working with stress, and passing gradient test. Required adding free_energy property, extracting stress from first element of 1x3x3 numpy array, and explicitly dealing with missing virials which can happen when pbc is False. --- mace/calculators/mace.py | 35 +++++++++++++------- mace/modules/utils.py | 2 +- tests/test_calculator.py | 69 ++++++++++++++++++++++++---------------- 3 files changed, 66 insertions(+), 40 deletions(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index aba7d043..d63da28e 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -15,7 +15,7 @@ class MACECalculator(Calculator): """MACE ASE Calculator""" - implemented_properties = ["energy", "forces", "stress"] + implemented_properties = ["energy", "free_energy", "forces", "stress"] def __init__( self, @@ -41,7 +41,7 @@ def __init__( torch_tools.set_default_dtype(default_dtype) # pylint: disable=dangerous-default-value - def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes): + def calculate(self, atoms=None, properties=None, system_changes=all_changes): """ Calculate properties. :param atoms: ase.Atoms object @@ -70,17 +70,23 @@ def calculate(self, atoms=None, properties=["energy"], system_changes=all_change out = self.model(batch, compute_stress=True) energy = out["energy"].detach().cpu().item() forces = out["forces"].detach().cpu().numpy() - stress = out["stress"].detach().cpu().numpy() # store results + E = energy * self.energy_units_to_eV self.results = { - "energy": energy * self.energy_units_to_eV, + "energy": E, + "free_energy": E, # force has units eng / len: "forces": forces * (self.energy_units_to_eV / self.length_units_to_A), - # force has units eng / len^3: - "stress": stress * (self.energy_units_to_eV / self.length_units_to_A**3), } + # even though compute_stress is True, stress can be none if pbc is False + # not sure if correct ASE thing is to have no dict key, or dict key with value None + if out["stress"] is not None: + stress = out["stress"].detach().cpu().numpy() + # stress has units eng / len^3: + self.results["stress"] = (stress * (self.energy_units_to_eV / self.length_units_to_A**3))[0] + class DipoleMACECalculator(Calculator): """MACE ASE Calculator for predicting dipoles""" @@ -116,7 +122,7 @@ def __init__( torch_tools.set_default_dtype(default_dtype) # pylint: disable=dangerous-default-value - def calculate(self, atoms=None, properties=["dipole"], system_changes=all_changes): + def calculate(self, atoms=None, properties=None, system_changes=all_changes): """ Calculate properties. :param atoms: ase.Atoms object @@ -156,6 +162,7 @@ class EnergyDipoleMACECalculator(Calculator): implemented_properties = [ "energy", + "free_energy", "forces", "stress", "dipole", @@ -190,7 +197,7 @@ def __init__( torch_tools.set_default_dtype(default_dtype) # pylint: disable=dangerous-default-value - def calculate(self, atoms=None, properties=["dipole"], system_changes=all_changes): + def calculate(self, atoms=None, properties=None, system_changes=all_changes): """ Calculate properties. :param atoms: ase.Atoms object @@ -219,15 +226,21 @@ def calculate(self, atoms=None, properties=["dipole"], system_changes=all_change out = self.model(batch, compute_stress=True) energy = out["energy"].detach().cpu().item() forces = out["forces"].detach().cpu().numpy() - stress = out["stress"].detach().cpu().numpy() dipole = out["dipole"].detach().cpu().numpy() # store results + E = energy * self.energy_units_to_eV self.results = { - "energy": energy * self.energy_units_to_eV, + "energy": E, + "free_energy": E, # force has units eng / len: "forces": forces * (self.energy_units_to_eV / self.length_units_to_A), # stress has units eng / len: - "stress": stress * (self.energy_units_to_eV / self.length_units_to_A**3), "dipole": dipole, } + + # even though compute_stress is True, stress can be none if pbc is False + # not sure if correct ASE thing is to have no dict key, or dict key with value None + if out["stress"] is not None: + stress = out["stress"].detach().cpu().numpy() + self.results["stress"] = (stress * (self.energy_units_to_eV / self.length_units_to_A**3))[0] diff --git a/mace/modules/utils.py b/mace/modules/utils.py index d1026314..3e31e402 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -57,7 +57,7 @@ def compute_forces_virials( allow_unused=True, ) stress = None - if compute_stress: + if compute_stress and virials is not None: cell = cell.view(-1, 3, 3) volume = torch.einsum( "zi,zi->z", diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 6b2afc28..c7f38008 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -17,25 +17,29 @@ from mace.calculators.mace import MACECalculator -water = Atoms(numbers=[8, 1, 1], positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], cell=[4]*3, pbc = [True]*3) -fitting_configs = [Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6]*3), - Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6]*3)] -fitting_configs[0].info["REF_energy"] = 0.0 -fitting_configs[0].info["config_type"] = "IsolatedAtom" -fitting_configs[1].info["REF_energy"] = 0.0 -fitting_configs[1].info["config_type"] = "IsolatedAtom" - -np.random.seed(5) -for _ in range(20): - c = water.copy() - c.positions += np.random.normal(0.1, size=c.positions.shape) - c.info["REF_energy"] = np.random.normal(0.1) - c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) - c.info["REF_stress"] = np.random.normal(0.1, size=6) - fitting_configs.append(c) - -@pytest.fixture -def trained_model(tmp_path): +@pytest.fixture(scope="module") +def fitting_configs(): + water = Atoms(numbers=[8, 1, 1], positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], cell=[4]*3, pbc = [True]*3) + fit_configs = [Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6]*3), + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6]*3)] + fit_configs[0].info["REF_energy"] = 0.0 + fit_configs[0].info["config_type"] = "IsolatedAtom" + fit_configs[1].info["REF_energy"] = 0.0 + fit_configs[1].info["config_type"] = "IsolatedAtom" + + np.random.seed(5) + for _ in range(20): + c = water.copy() + c.positions += np.random.normal(0.1, size=c.positions.shape) + c.info["REF_energy"] = np.random.normal(0.1) + c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) + c.info["REF_stress"] = np.random.normal(0.1, size=6) + fit_configs.append(c) + + return fit_configs + +@pytest.fixture(scope="module") +def trained_model(tmp_path_factory, fitting_configs): _mace_params = { "name": "MACE", "valid_fraction": 0.05, @@ -61,6 +65,8 @@ def trained_model(tmp_path): "stress_key": "REF_stress" } + tmp_path = tmp_path_factory.mktemp("run_") + ase.io.write(tmp_path / "fit.xyz", fitting_configs) mace_params = _mace_params.copy() @@ -81,15 +87,22 @@ def trained_model(tmp_path): return MACECalculator(tmp_path / "MACE.model", device="cpu") -def test_calculator(trained_model): - at = fitting_configs[0] +def test_calculator_forces(fitting_configs, trained_model): + at = fitting_configs[2].copy() + at.calc = trained_model + + # test just forces + grads = gradient_test(at) + + assert np.allclose(grads[0], grads[1]) + + +def test_calculator_stress(fitting_configs, trained_model): + at = fitting_configs[2].copy() at.calc = trained_model - print("BOB", at.get_potential_energy()) - print("BOB", at.get_forces()) - print("BOB", at.get_stress()) - # at_wrapped = ExpCellFilter(at) - # grad_qual = gradient_test(at_wrapped) - grad_qual = gradient_test(at) + # test forces and stress + at_wrapped = ExpCellFilter(at) + grads = gradient_test(at_wrapped) - print("BOB", grad_qual) + assert np.allclose(grads[0], grads[1]) From 1a04b2ee8457e3e6a09a55cf87800f64c4453ad2 Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Mon, 17 Oct 2022 16:16:21 -0400 Subject: [PATCH 3/4] Fix bug (and general cleanup) in predicted virial shape when all atoms are outside cutoff --- mace/modules/utils.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/mace/modules/utils.py b/mace/modules/utils.py index 3e31e402..a12433ec 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -65,21 +65,14 @@ def compute_forces_virials( torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1), ).unsqueeze(-1) stress = virials / volume.view(-1, 1, 1) - if forces is None and virials is None: + if forces is None: logging.warning("Gradient is None, padded with zeros") - return ( - torch.zeros_like(positions), - torch.zeros_like(positions).expand(1, 1, 3), - None, - ) - if forces is not None and virials is None: - logging.warning("Virial is None, padded with zeros") - return -1 * forces, torch.zeros_like(positions).expand(1, 1, 3), None - if forces is None and virials is not None: + forces = torch.zeros_like(positions) + if virials is None: logging.warning("Virial is None, padded with zeros") - return torch.zeros_like(positions), -1 * virials, None - return -1 * forces, -1 * virials, stress + virials = torch.zeros((1, 3, 3)) + return -1 * forces, -1 * virials, stress def get_symmetric_displacement( positions: torch.Tensor, From 3bbd5591ad33e7667a7d5b1b5954d706b6c2ad89 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 18 Oct 2022 19:35:40 +0200 Subject: [PATCH 4/4] linting and formatting --- mace/calculators/mace.py | 8 ++++-- mace/modules/utils.py | 1 + tests/test_calculator.py | 57 ++++++++++++++++++++++++++-------------- tests/test_run_train.py | 5 ++-- 4 files changed, 48 insertions(+), 23 deletions(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index d63da28e..e7451005 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -85,7 +85,9 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): if out["stress"] is not None: stress = out["stress"].detach().cpu().numpy() # stress has units eng / len^3: - self.results["stress"] = (stress * (self.energy_units_to_eV / self.length_units_to_A**3))[0] + self.results["stress"] = ( + stress * (self.energy_units_to_eV / self.length_units_to_A**3) + )[0] class DipoleMACECalculator(Calculator): @@ -243,4 +245,6 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): # not sure if correct ASE thing is to have no dict key, or dict key with value None if out["stress"] is not None: stress = out["stress"].detach().cpu().numpy() - self.results["stress"] = (stress * (self.energy_units_to_eV / self.length_units_to_A**3))[0] + self.results["stress"] = ( + stress * (self.energy_units_to_eV / self.length_units_to_A**3) + )[0] diff --git a/mace/modules/utils.py b/mace/modules/utils.py index a12433ec..94b5495c 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -74,6 +74,7 @@ def compute_forces_virials( return -1 * forces, -1 * virials, stress + def get_symmetric_displacement( positions: torch.Tensor, unit_shifts: torch.Tensor, diff --git a/tests/test_calculator.py b/tests/test_calculator.py index c7f38008..dc140494 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -1,27 +1,34 @@ -import sys import os import subprocess - -import pytest - +import sys from pathlib import Path -pytest_mace_dir = Path(__file__).parent.parent -run_train = Path(__file__).parent.parent / "scripts" / "run_train.py" -import numpy as np +import pytest import ase.io +import numpy as np from ase.atoms import Atoms -from ase.constraints import ExpCellFilter from ase.calculators.test import gradient_test +from ase.constraints import ExpCellFilter from mace.calculators.mace import MACECalculator -@pytest.fixture(scope="module") -def fitting_configs(): - water = Atoms(numbers=[8, 1, 1], positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], cell=[4]*3, pbc = [True]*3) - fit_configs = [Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6]*3), - Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6]*3)] +pytest_mace_dir = Path(__file__).parent.parent +run_train = Path(__file__).parent.parent / "scripts" / "run_train.py" + + +@pytest.fixture(scope="module", name="fitting_configs") +def fitting_configs_fixture(): + water = Atoms( + numbers=[8, 1, 1], + positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], + cell=[4] * 3, + pbc=[True] * 3, + ) + fit_configs = [ + Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), + ] fit_configs[0].info["REF_energy"] = 0.0 fit_configs[0].info["config_type"] = "IsolatedAtom" fit_configs[1].info["REF_energy"] = 0.0 @@ -38,8 +45,9 @@ def fitting_configs(): return fit_configs -@pytest.fixture(scope="module") -def trained_model(tmp_path_factory, fitting_configs): + +@pytest.fixture(scope="module", name="trained_model") +def trained_model_fixture(tmp_path_factory, fitting_configs): _mace_params = { "name": "MACE", "valid_fraction": 0.05, @@ -62,7 +70,7 @@ def trained_model(tmp_path_factory, fitting_configs): "loss": "stress", "energy_key": "REF_energy", "forces_key": "REF_forces", - "stress_key": "REF_stress" + "stress_key": "REF_stress", } tmp_path = tmp_path_factory.mktemp("run_") @@ -78,9 +86,20 @@ def trained_model(tmp_path_factory, fitting_configs): run_env = os.environ run_env["PYTHONPATH"] = str(pytest_mace_dir) + ":" + os.environ["PYTHONPATH"] - cmd = sys.executable + " " + str(run_train) + " " + " ".join([(f"--{k}={v}" if v is not None else f"--{k}") for k, v in mace_params.items()]) - - p = subprocess.run(cmd.split(), env=run_env) + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) assert p.returncode == 0 diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 2188a56e..ec574d55 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -204,6 +204,7 @@ def test_run_train_missing_data(tmp_path, fitting_configs): ] assert np.allclose(Es, ref_Es) + def test_run_train_no_stress(tmp_path, fitting_configs): del fitting_configs[5].info["REF_energy"] del fitting_configs[6].arrays["REF_forces"] @@ -270,6 +271,6 @@ def test_run_train_no_stress(tmp_path, fitting_configs): -0.06608313576078294, -0.36358220540264646, -0.12097397940768086, - 0.002021055463491156 - ] + 0.002021055463491156, + ] assert np.allclose(Es, ref_Es)