Average e0 #4

Merged 7 commits on Jul 2, 2022
README.md: 2 changes (1 addition, 1 deletion)

@@ -75,7 +75,7 @@ To give a specific validation set, use the argument `--valid_file`. To set a lar

To control the model's size, you need to change `--hidden_irreps`. For most applications, the recommended default model size is `--hidden_irreps='256x0e'` (meaning 256 invariant messages) or `--hidden_irreps='128x0e + 128x1o'`. If the model is not accurate enough, you can include higher order features, e.g., `128x0e + 128x1o + 128x2e`, or increase the number of channels to `256`.

It is usually preferred to add the isolated atoms to the training set, rather than reading in their energies through the command line like in the example above. To label them in the training set, set `config_type=IsolatedAtom` in their info fields. If you prefer not to use or do not know the energies of the isolated atoms, you should set `--model=ScaleShiftMACE` and pass a dictionary of 0's, e.g., `--E0s='{1:0.0, 6:0.0}'`.
It is usually preferred to add the isolated atoms to the training set, rather than reading in their energies through the command line like in the example above. To label them in the training set, set `config_type=IsolatedAtom` in their info fields. If you prefer not to use or do not know the energies of the isolated atoms, you can use the option `--E0s="average"`, which estimates the atomic energies using least squares regression.

If the keyword `--swa` is enabled, the energy weight of the loss is increased for the last ~20% of the training epochs (from `--start_swa` epochs). This setting usually helps lower the energy errors.

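As a rough sketch of the estimation behind `--E0s="average"` (implemented by `compute_average_E0s` further down in this diff): each training energy is modelled as a sum of per-element reference energies, and the E0s are the least-squares solution of the resulting linear system. The numbers below are illustrative only, not real data.

```python
import numpy as np

# A[i, j] counts atoms of element j in configuration i; B[i] is that
# configuration's total energy, so the model is B ≈ A @ E0s.
A = np.array(
    [
        [2.0, 1.0],  # e.g. a CH2 fragment: 2 x H, 1 x C
        [4.0, 1.0],  # e.g. CH4
        [0.0, 2.0],  # e.g. C2
    ]
)
B = np.array([-40.0, -54.0, -61.0])  # made-up energies for illustration

E0s, *_ = np.linalg.lstsq(A, B, rcond=None)
print(dict(zip((1, 6), E0s)))  # per-element E0 estimates keyed by atomic number
```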
mace/data/__init__.py: 2 changes (2 additions, 0 deletions)
@@ -3,6 +3,7 @@
from .utils import (
    Configuration,
    Configurations,
    compute_average_E0s,
    config_from_atoms,
    config_from_atoms_list,
    load_from_xyz,
@@ -20,4 +21,5 @@
    "config_from_atoms",
    "config_from_atoms_list",
    "AtomicData",
    "compute_average_E0s",
]
mace/data/utils.py: 32 changes (32 additions, 0 deletions)
@@ -6,6 +6,8 @@
import ase.io
import numpy as np

from mace.tools import AtomicNumberTable

Vector = np.ndarray # [3,]
Positions = np.ndarray # [..., 3]
Forces = np.ndarray # [..., 3]
@@ -163,3 +165,33 @@ def load_from_xyz(
        forces_key=forces_key,
    )
    return atomic_energies_dict, configs


def compute_average_E0s(
    collections_train: Configurations, z_table: AtomicNumberTable
) -> dict:
    """
    Compute the average interaction energy of each chemical element by least squares
    regression over the training set; returns a dictionary of E0s keyed by atomic number.
    """
    len_train = len(collections_train)
    len_zs = len(z_table)
    A = np.zeros((len_train, len_zs))
    B = np.zeros(len_train)
    for i in range(len_train):
        B[i] = collections_train[i].energy
        for j, z in enumerate(z_table.zs):
            A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z)
    try:
        E0s = np.linalg.lstsq(A, B, rcond=None)[0]
        atomic_energies_dict = {}
        for i, z in enumerate(z_table.zs):
            atomic_energies_dict[z] = E0s[i]
    except np.linalg.LinAlgError:
        logging.warning(
            "Failed to compute E0s using least squares regression, using the same value for all atoms"
        )
        atomic_energies_dict = {}
        for i, z in enumerate(z_table.zs):
            atomic_energies_dict[z] = 0.0
    return atomic_energies_dict
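A hedged usage sketch for the new helper, assuming the `mace.data` / `mace.tools` entry points used elsewhere in this diff (`load_from_xyz`, `get_atomic_number_table_from_zs`) keep the signatures seen in `scripts/run_train.py`; the file path is a placeholder.

```python
from mace import data, tools

# Read training configurations from an XYZ file (path is illustrative).
_, configs = data.load_from_xyz(
    file_path="train.xyz",
    config_type_weights={"Default": 1.0},
)

# Build the element table from the training set, then estimate E0s by least squares.
z_table = tools.get_atomic_number_table_from_zs(
    z for config in configs for z in config.atomic_numbers
)
e0s = data.compute_average_E0s(configs, z_table)  # {atomic_number: E0}
print(e0s)
```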
mace/tools/train.py: 4 changes (4 additions, 0 deletions)
@@ -43,6 +43,7 @@ def train(
):
    lowest_loss = np.inf
    patience_counter = 0
    swa_start = True

    if max_grad_norm is not None:
        logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}")
@@ -112,6 +113,9 @@
        if swa is None or epoch < swa.start:
            lr_scheduler.step(valid_loss)  # Can break if exponential LR, TODO fix that!
        else:
            if swa_start:
                logging.info("Changing loss based on SWA")
                swa_start = False
            loss_fn = swa.loss_fn
            swa.model.update_parameters(model)
            swa.scheduler.step()
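For context on the SWA branch above: `swa.model.update_parameters(model)` and `swa.scheduler.step()` follow PyTorch's stochastic weight averaging utilities. A minimal standalone sketch of that pattern (not MACE's training loop; the model, epoch counts, and learning rates are illustrative):

```python
import torch
from torch.optim.swa_utils import AveragedModel, SWALR

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
swa_model = AveragedModel(model)               # running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=1e-3)  # LR schedule used once SWA starts

n_epochs, start_epoch = 100, 80  # switch over for the last ~20% of epochs
for epoch in range(n_epochs):
    # ... one ordinary training epoch on `model` would go here ...
    if epoch >= start_epoch:
        swa_model.update_parameters(model)  # fold current weights into the average
        swa_scheduler.step()
```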
scripts/run_checks.sh: file mode changed 100644 → 100755 (no content changes)
scripts/run_train.py: 29 changes (26 additions, 3 deletions)
@@ -76,7 +76,7 @@ def get_dataset_from_xyz(
    # create list of tuples (config_type, list(Atoms))
    test_configs = data.test_config_types(all_test_configs)
    logging.info(
        f"Loaded {len(all_test_configs)} test configurations from '{train_path}'"
        f"Loaded {len(all_test_configs)} test configurations from '{test_path}'"
    )
    return (
        SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs),
@@ -107,12 +107,21 @@ def main() -> None:
    device = tools.init_device(args.device)
    tools.set_default_dtype(args.default_dtype)

    try:
        config_type_weights = ast.literal_eval(args.config_type_weights)
        assert isinstance(config_type_weights, dict)
    except Exception as e:  # pylint: disable=W0703
        logging.warning(
            f"Config type weights not specified correctly ({e}), using Default"
        )
        config_type_weights = {"Default": 1.0}

    # Data preparation
    collections, atomic_energies_dict = get_dataset_from_xyz(
        train_path=args.train_file,
        valid_path=args.valid_file,
        valid_fraction=args.valid_fraction,
        config_type_weights=config_type_weights,
        test_path=args.test_file,
        seed=args.seed,
        energy_key=args.energy_key,
@@ -139,7 +148,21 @@
            logging.info(
                "Atomic Energies not in training file, using command line argument E0s"
            )
            if args.E0s.lower() == "average":
                logging.info(
                    "Computing average Atomic Energies using least squares regression"
                )
                atomic_energies_dict = data.compute_average_E0s(
                    collections.train, z_table
                )
            else:
                try:
                    atomic_energies_dict = ast.literal_eval(args.E0s)
                    assert isinstance(atomic_energies_dict, dict)
                except Exception as e:
                    raise RuntimeError(
                        f"E0s specified invalidly, error {e} occurred"
                    ) from e
        else:
            raise RuntimeError(
                "E0s not found in training file and not specified in command line"
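Taken together, the branch above accepts `--E0s` either as the string `"average"` or as a Python dict literal. A small standalone sketch of the same parsing logic (hypothetical helper, not part of the repo):

```python
import ast
from typing import Union


def parse_e0s(e0s_arg: str) -> Union[str, dict]:
    """Return "average" or an {atomic_number: energy} dict, mirroring the logic above."""
    if e0s_arg.lower() == "average":
        return "average"  # the caller then estimates E0s via data.compute_average_E0s
    try:
        e0s = ast.literal_eval(e0s_arg)
        assert isinstance(e0s, dict)
        return e0s
    except Exception as e:
        raise RuntimeError(f"E0s specified invalidly, error {e} occurred") from e


print(parse_e0s("average"))
print(parse_e0s("{1: 0.0, 6: 0.0}"))  # as in the README example above
```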