-
Notifications
You must be signed in to change notification settings - Fork 181
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
36 changed files
with
3,138 additions
and
405 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,4 +15,6 @@ build/ | |
|
||
# IDE | ||
.idea/ | ||
.vscode/ | ||
.vscode/ | ||
logs/MACE_run-5.log | ||
*.txt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,9 +20,12 @@ conda create --name mace_env | |
conda activate mace_env | ||
|
||
# Install PyTorch | ||
conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c conda-forge | ||
conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia | ||
|
||
# Clone and install MACE (and all required packages), use token if still private repo | ||
# (optional) Install MACE's dependencies from Conda as well | ||
conda install numpy scipy matplotlib ase opt_einsum prettytable pandas e3nn | ||
|
||
# Clone and install MACE (and all required packages) | ||
git clone [email protected]:ACEsuit/mace.git | ||
pip install ./mace | ||
``` | ||
|
@@ -85,6 +88,10 @@ The precision can be changed using the keyword ``--default_dtype``, the default | |
|
||
The keywords ``--batch_size`` and ``--max_num_epochs`` should be adapted based on the size of the training set. The batch size should be increased when the number of training data increases, and the number of epochs should be decreased. An heuristic for initial settings, is to consider the number of gradient update constant to 200 000, which can be computed as $\text{max-num-epochs}*\frac{\text{num-configs-training}}{\text{batch-size}}$. | ||
|
||
The code can handle training set with heterogeneous labels, for example containing both bulk structures with stress and isolated molecules. In this example, to make the code ignore stress on molecules, append to your molecules configuration a ``config_stress_weight = 0.0``. | ||
|
||
To use Apple Silicon GPU acceleration make sure to install the latest PyTorch version and specify ``--device=mps``. | ||
|
||
### Evaluation | ||
|
||
To evaluate your MACE model on an XYZ file, run the `eval_configs.py`: | ||
|
@@ -100,6 +107,16 @@ python3 ./mace/scripts/eval_configs.py \ | |
|
||
You can run our [Colab tutorial](https://colab.research.google.com/drive/1D6EtMUjQPey_GkuxUAbPgld6_9ibIa-V?authuser=1#scrollTo=Z10787RE1N8T) to quickly get started with MACE. | ||
|
||
## Weights and Biases for experiment tracking | ||
|
||
If you would like to use MACE with Weights and Biases to log your experiments simply install with | ||
|
||
```sh | ||
pip install ./mace[wandb] | ||
``` | ||
|
||
And specify the necessary keyword arguments (`--wandb`, `--wandb_project`, `--wandb_entity`, `--wandb_name`, `--wandb_log_hypers`) | ||
|
||
## Development | ||
|
||
We use `black`, `isort`, `pylint`, and `mypy`. | ||
|
@@ -126,7 +143,6 @@ editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho}, | |
year={2022}, | ||
url={https://openreview.net/forum?id=YPpSngE-ZU} | ||
} | ||
@misc{Batatia2022Design, | ||
title = {The Design Space of E(3)-Equivariant Atom-Centered Interatomic Potentials}, | ||
author = {Batatia, Ilyes and Batzner, Simon and Kov{\'a}cs, D{\'a}vid P{\'e}ter and Musaelian, Albert and Simm, Gregor N. C. and Drautz, Ralf and Ortner, Christoph and Kozinsky, Boris and Cs{\'a}nyi, G{\'a}bor}, | ||
|
@@ -147,4 +163,4 @@ For bugs or feature requests, please use [GitHub Issues](https://github.com/ACEs | |
|
||
## License | ||
|
||
MACE is published and distributed under the [MIT license](MIT.md). | ||
MACE is published and distributed under the [Academic Software License v1.0 ](ASL.md). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,9 @@ | ||
from .mace import MACECalculator | ||
from .lammps_mace import LAMMPS_MACE | ||
from .mace import DipoleMACECalculator, EnergyDipoleMACECalculator, MACECalculator | ||
|
||
__all__ = ["MACECalculator"] | ||
__all__ = [ | ||
"MACECalculator", | ||
"DipoleMACECalculator", | ||
"EnergyDipoleMACECalculator", | ||
"LAMMPS_MACE", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from typing import Dict, List, Optional | ||
|
||
import torch | ||
from e3nn.util.jit import compile_mode | ||
|
||
from mace.modules.utils import get_outputs | ||
from mace.tools.scatter import scatter_sum | ||
|
||
|
||
@compile_mode("script") | ||
class LAMMPS_MACE(torch.nn.Module): | ||
def __init__(self, model): | ||
super().__init__() | ||
self.model = model | ||
self.register_buffer("atomic_numbers", model.atomic_numbers) | ||
self.register_buffer("r_max", model.r_max) | ||
self.register_buffer("num_interactions", model.num_interactions) | ||
|
||
def forward( | ||
self, | ||
data: Dict[str, torch.Tensor], | ||
mask_ghost: torch.Tensor, | ||
compute_force: bool = True, | ||
compute_virials: bool = False, | ||
compute_stress: bool = False, | ||
) -> Dict[str, Optional[torch.Tensor]]: | ||
num_graphs = data["ptr"].numel() - 1 | ||
compute_displacement = False | ||
if compute_virials or compute_stress: | ||
compute_displacement = True | ||
|
||
out = self.model( | ||
data, | ||
training=False, | ||
compute_force=False, | ||
compute_virials=False, | ||
compute_stress=False, | ||
compute_displacement=compute_displacement, | ||
) | ||
node_energy = out["node_energy"] | ||
if node_energy is None: | ||
return {"energy": None, "forces": None, "virials": None, "stress": None} | ||
displacement = out["displacement"] | ||
virials: Optional[torch.Tensor] = torch.zeros_like(data["cell"]) | ||
stress: Optional[torch.Tensor] = torch.zeros_like(data["cell"]) | ||
if mask_ghost is not None and displacement is not None: | ||
# displacement.requires_grad_(True) # For some reason torchscript needs that. | ||
node_energy_ghost = node_energy * mask_ghost | ||
total_energy_ghost = scatter_sum( | ||
src=node_energy_ghost, index=data["batch"], dim=-1, dim_size=num_graphs | ||
) | ||
grad_outputs: List[Optional[torch.Tensor]] = [ | ||
torch.ones_like(total_energy_ghost) | ||
] | ||
virials = torch.autograd.grad( | ||
outputs=[total_energy_ghost], | ||
inputs=[displacement], | ||
grad_outputs=grad_outputs, | ||
retain_graph=True, | ||
create_graph=True, | ||
allow_unused=True, | ||
)[0] | ||
|
||
if virials is not None: | ||
virials = -1 * virials | ||
cell = data["cell"].view(-1, 3, 3) | ||
volume = torch.einsum( | ||
"zi,zi->z", | ||
cell[:, 0, :], | ||
torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1), | ||
).unsqueeze(-1) | ||
stress = virials / volume.view(-1, 1, 1) | ||
else: | ||
virials = torch.zeros_like(displacement) | ||
|
||
total_energy = scatter_sum( | ||
src=node_energy, index=data["batch"], dim=-1, dim_size=num_graphs | ||
) | ||
|
||
forces, _, _ = get_outputs( | ||
energy=total_energy, | ||
positions=data["positions"], | ||
displacement=displacement, | ||
cell=data["cell"], | ||
training=False, | ||
compute_force=compute_force, | ||
compute_virials=False, | ||
compute_stress=False, | ||
) | ||
|
||
return { | ||
"energy": total_energy, | ||
"node_energy": node_energy, | ||
"forces": forces, | ||
"virials": virials, | ||
"stress": stress, | ||
} |
Oops, something went wrong.