Skip to content

Commit

Permalink
Merge pull request #4 from ACEsuit/averageE0
Browse files Browse the repository at this point in the history
Average e0
  • Loading branch information
ilyes319 authored Jul 2, 2022
2 parents 51e8aca + 7d3cee8 commit f2c0b14
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ To give a specific validation set, use the argument `--valid_file`. To set a lar

To control the model's size, you need to change `--hidden_irreps`. For most applications, the recommended default model size is `--hidden_irreps='256x0e'` (meaning 256 invariant messages) or `--hidden_irreps='128x0e + 128x1o'`. If the model is not accurate enough, you can include higher order features, e.g., `128x0e + 128x1o + 128x2e`, or increase the number of channels to `256`.

It is usually preferred to add the isolated atoms to the training set, rather than reading in their energies through the command line like in the example above. To label them in the training set, set `config_type=IsolatedAtom` in their info fields. If you prefer not to use or do not know the energies of the isolated atoms, you should set `--model=ScaleShiftMACE` and pass a dictionary of 0's, e.g., `--E0s='{1:0.0, 6:0.0}'`.
It is usually preferred to add the isolated atoms to the training set, rather than reading in their energies through the command line like in the example above. To label them in the training set, set `config_type=IsolatedAtom` in their info fields. If you prefer not to use or do not know the energies of the isolated atoms, you can use the option `--E0s="average"`, which estimates the atomic energies using least squares regression.

If the keyword `--swa` is enabled, the energy weight of the loss is increased for the last ~20% of the training epochs (from `--start_swa` epochs). This setting usually helps lower the energy errors.

Expand Down
2 changes: 2 additions & 0 deletions mace/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .utils import (
Configuration,
Configurations,
compute_average_E0s,
config_from_atoms,
config_from_atoms_list,
load_from_xyz,
Expand All @@ -20,4 +21,5 @@
"config_from_atoms",
"config_from_atoms_list",
"AtomicData",
"compute_average_E0s",
]
32 changes: 32 additions & 0 deletions mace/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import ase.io
import numpy as np

from mace.tools import AtomicNumberTable

Vector = np.ndarray # [3,]
Positions = np.ndarray # [..., 3]
Forces = np.ndarray # [..., 3]
Expand Down Expand Up @@ -163,3 +165,33 @@ def load_from_xyz(
forces_key=forces_key,
)
return atomic_energies_dict, configs


def compute_average_E0s(
    collections_train: Configurations, z_table: AtomicNumberTable
) -> dict:
    """
    Estimate one reference energy (E0) per chemical element by least squares.

    Every training configuration contributes one linear equation: the sum over
    the elements in ``z_table.zs`` of (number of atoms of that element) * E0
    should reproduce the configuration's energy. The resulting system is solved
    with ``np.linalg.lstsq``.

    :param collections_train: sequence of configurations; each item must expose
        ``energy`` and ``atomic_numbers`` (compared element-wise against each z)
    :param z_table: table of atomic numbers; ``z_table.zs`` fixes the column
        order of the regression and the keys of the result
    :return: dictionary mapping every atomic number in ``z_table.zs`` to its
        fitted E0, or to 0.0 for every element when the solver raises
        ``np.linalg.LinAlgError``
    """
    zs = z_table.zs
    len_train = len(collections_train)
    # A[i, j] = number of atoms of element zs[j] in configuration i,
    # B[i]    = energy of configuration i.
    A = np.zeros((len_train, len(z_table)))
    B = np.zeros(len_train)
    for i, config in enumerate(collections_train):
        B[i] = config.energy
        for j, z in enumerate(zs):
            A[i, j] = np.count_nonzero(config.atomic_numbers == z)
    # Only the solver call can raise LinAlgError, so keep the try body minimal.
    try:
        E0s = np.linalg.lstsq(A, B, rcond=None)[0]
    except np.linalg.LinAlgError:
        logging.warning(
            "Failed to compute E0s using least squares regression, using the same for all atoms"
        )
        # Fallback: a zero reference energy for every element.
        return {z: 0.0 for z in zs}
    return {z: E0s[i] for i, z in enumerate(zs)}
4 changes: 4 additions & 0 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def train(
):
lowest_loss = np.inf
patience_counter = 0
swa_start = True

if max_grad_norm is not None:
logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}")
Expand Down Expand Up @@ -112,6 +113,9 @@ def train(
if swa is None or epoch < swa.start:
lr_scheduler.step(valid_loss) # Can break if exponential LR, TODO fix that!
else:
if swa_start:
logging.info("Changing loss based on SWA")
swa_start = False
loss_fn = swa.loss_fn
swa.model.update_parameters(model)
swa.scheduler.step()
Expand Down
Empty file modified scripts/run_checks.sh
100644 → 100755
Empty file.
29 changes: 26 additions & 3 deletions scripts/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def get_dataset_from_xyz(
# create list of tuples (config_type, list(Atoms))
test_configs = data.test_config_types(all_test_configs)
logging.info(
f"Loaded {len(all_test_configs)} test configurations from '{train_path}'"
f"Loaded {len(all_test_configs)} test configurations from '{test_path}'"
)
return (
SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs),
Expand All @@ -99,12 +99,21 @@ def main() -> None:
device = tools.init_device(args.device)
tools.set_default_dtype(args.default_dtype)

try:
config_type_weights = ast.literal_eval(args.config_type_weights)
assert isinstance(config_type_weights, dict)
except Exception as e: # pylint: disable=W0703
logging.warning(
f"Config type weights not specified correctly ({e}), using Default"
)
config_type_weights = {"Default": 1.0}

# Data preparation
collections, atomic_energies_dict = get_dataset_from_xyz(
train_path=args.train_file,
valid_path=args.valid_file,
valid_fraction=args.valid_fraction,
config_type_weights=ast.literal_eval(args.config_type_weights),
config_type_weights=config_type_weights,
test_path=args.test_file,
seed=args.seed,
energy_key=args.energy_key,
Expand All @@ -131,7 +140,21 @@ def main() -> None:
logging.info(
"Atomic Energies not in training file, using command line argument E0s"
)
atomic_energies_dict = ast.literal_eval(args.E0s)
if args.E0s.lower() == "average":
logging.info(
"Computing average Atomic Energies using least squares regression"
)
atomic_energies_dict = data.compute_average_E0s(
collections.train, z_table
)
else:
try:
atomic_energies_dict = ast.literal_eval(args.E0s)
assert isinstance(atomic_energies_dict, dict)
except Exception as e:
raise RuntimeError(
f"E0s specified invalidly, error {e} occured"
) from e
else:
raise RuntimeError(
"E0s not found in training file and not specified in command line"
Expand Down

0 comments on commit f2c0b14

Please sign in to comment.