diff --git a/README.md b/README.md index e672713a..3dcf3e2e 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,7 @@ python ./mace/scripts/run_train.py \ To give a specific validation set, use the argument `--valid_file`. To set a larger batch size for evaluating the validation set, specify `--valid_batch_size`. -To control the model's size, you need to change `--hidden_irreps`. For most applications, the recommended default model size is `--hidden_irreps='256x0e'` (meaning 256 invariant messages) or `--hidden_irreps='128x0e + 128x1o'`. If the model is not accurate enough, you can include higher order features, e.g., `128x0e + 128x1o + 128x2e`, or increase the number of channels to `256`. +To control the model's size, you need to change `--hidden_irreps`. For most applications, the recommended default model size is `--hidden_irreps='256x0e'` (meaning 256 invariant messages) or `--hidden_irreps='128x0e + 128x1o'`. If the model is not accurate enough, you can include higher order features, e.g., `128x0e + 128x1o + 128x2e`, or increase the number of channels to `256`. It is also possible to specify the model using the `--num_channels=128` and `--max_L=1`keys. It is usually preferred to add the isolated atoms to the training set, rather than reading in their energies through the command line like in the example above. To label them in the training set, set `config_type=IsolatedAtom` in their info fields. If you prefer not to use or do not know the energies of the isolated atoms, you can use the option `--E0s="average"` which estimates the atomic energies using least squares regression. @@ -105,6 +105,53 @@ python3 ./mace/scripts/eval_configs.py \ You can run our [Colab tutorial](https://colab.research.google.com/drive/1D6EtMUjQPey_GkuxUAbPgld6_9ibIa-V?authuser=1#scrollTo=Z10787RE1N8T) to quickly get started with MACE. +## On-line data loading for large datasets + +If you have a large dataset that might not fit into the GPU memory it is recommended to preprocess the data on a CPU and use on-line dataloading for training the model. To preprocess your dataset specified as an xyz file run the `preprocess_data.py` script. An example is given here: + +```sh +mkdir processed_data +python ./mace/scripts/preprocess_data.py \ + --train_file="/path/to/train_large.xyz" \ + --valid_fraction=0.05 \ + --test_file="/path/to/test_large.xyz" \ + --atomic_numbers="[1, 6, 7, 8, 9, 15, 16, 17, 35, 53]" \ + --r_max=4.5 \ + --h5_prefix="processed_data/" \ + --compute_statistics \ + --E0s="average" \ + --num_workers=32 \ + --seed=123 \ +``` + +To see all options and a little description of them run `python ./mace/scripts/preprocess_data.py --help` . The script will create a number of HDF5 files in the `processed_data` folder which can be used for training. There wiull be one file for trainin, one for validation and a separate one for each `config_type` in the test set. To train the model use the `run_train.py` script as follows: + +```sh +python ./mace/scripts/run_train.py \ + --name="MACE_on_big_data" \ + --num_workers=16 \ + --train_file="./processed_data/train.h5" \ + --valid_file="./processed_data/valid.h5" \ + --test_dir="./processed_data" \ + --statistics_file="./processed_data/statistics.json" \ + --model="ScaleShiftMACE" \ + --num_interactions=2 \ + --num_channels=128 \ + --max_L=1 \ + --correlation=3 \ + --batch_size=32 \ + --valid_batch_size=32 \ + --max_num_epochs=100 \ + --swa \ + --start_swa=60 \ + --ema \ + --ema_decay=0.99 \ + --amsgrad \ + --error_table='PerAtomMAE' \ + --device=cuda \ + --seed=123 \ +``` + ## Weights and Biases for experiment tracking If you would like to use MACE with Weights and Biases to log your experiments simply install with diff --git a/mace/data/__init__.py b/mace/data/__init__.py index 0d0c9bf2..88e79c77 100644 --- a/mace/data/__init__.py +++ b/mace/data/__init__.py @@ -9,7 +9,9 @@ load_from_xyz, random_train_valid_split, test_config_types, + save_dataset_as_HDF5, ) +from .hdf5_dataset import HDF5Dataset __all__ = [ "get_neighborhood", @@ -22,4 +24,6 @@ "config_from_atoms_list", "AtomicData", "compute_average_E0s", + "save_dataset_as_HDF5", + "HDF5Dataset", ] diff --git a/mace/data/hdf5_dataset.py b/mace/data/hdf5_dataset.py new file mode 100644 index 00000000..107c5ddb --- /dev/null +++ b/mace/data/hdf5_dataset.py @@ -0,0 +1,59 @@ +import h5py +import torch +from torch.utils.data import Dataset +from mace.data import AtomicData + +class HDF5Dataset(Dataset): + def __init__(self, file, **kwargs): + super(HDF5Dataset, self).__init__() + # it might be dangerous to open the file here + # move opening of file to __getitem__? + self.file = h5py.File(file, 'r') + self.length = len(self.file.keys()) + # self.file = file + # self.length = len(h5py.File(file, 'r').keys()) + + def __len__(self): + return self.length + + def __getitem__(self, index): + # file = h5py.File(self.file, 'r') + # grp = file["config_" + str(index)] + grp = self.file["config_" + str(index)] + edge_index = grp['edge_index'][()] + positions = grp['positions'][()] + shifts = grp['shifts'][()] + unit_shifts = grp['unit_shifts'][()] + cell = grp['cell'][()] + node_attrs = grp['node_attrs'][()] + weight = grp['weight'][()] + energy_weight = grp['energy_weight'][()] + forces_weight = grp['forces_weight'][()] + stress_weight = grp['stress_weight'][()] + virials_weight = grp['virials_weight'][()] + forces = grp['forces'][()] + energy = grp['energy'][()] + stress = grp['stress'][()] + virials = grp['virials'][()] + dipole = grp['dipole'][()] + charges = grp['charges'][()] + return AtomicData( + edge_index = torch.tensor(edge_index, dtype=torch.long), + positions = torch.tensor(positions, dtype=torch.get_default_dtype()), + shifts = torch.tensor(shifts, dtype=torch.get_default_dtype()), + unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()), + cell=torch.tensor(cell, dtype=torch.get_default_dtype()), + node_attrs=torch.tensor(node_attrs, dtype=torch.get_default_dtype()), + weight=torch.tensor(weight, dtype=torch.get_default_dtype()), + energy_weight=torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + forces_weight=torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + stress_weight=torch.tensor(stress_weight, dtype=torch.get_default_dtype()), + virials_weight=torch.tensor(virials_weight, dtype=torch.get_default_dtype()), + forces=torch.tensor(forces, dtype=torch.get_default_dtype()), + energy=torch.tensor(energy, dtype=torch.get_default_dtype()), + stress=torch.tensor(stress, dtype=torch.get_default_dtype()), + virials=torch.tensor(virials, dtype=torch.get_default_dtype()), + dipole=torch.tensor(dipole, dtype=torch.get_default_dtype()), + charges=torch.tensor(charges, dtype=torch.get_default_dtype()), + ) + \ No newline at end of file diff --git a/mace/data/utils.py b/mace/data/utils.py index 710d5f8f..2999d0e5 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -5,6 +5,8 @@ ########################################################################################### import logging +import h5py +from multiprocessing import Pool from dataclasses import dataclass from typing import Dict, List, Optional, Sequence, Tuple @@ -148,6 +150,9 @@ def config_from_atoms( if virials is None: virials = np.zeros((3, 3)) virials_weight = 0.0 + if dipole is None: + dipole = np.zeros(3) + #dipoles_weight = 0.0 return Configuration( atomic_numbers=atomic_numbers, @@ -265,3 +270,28 @@ def compute_average_E0s( for i, z in enumerate(z_table.zs): atomic_energies_dict[z] = 0.0 return atomic_energies_dict + +def save_dataset_as_HDF5( + dataset:List, out_name: str + ) -> None: + with h5py.File(out_name, 'w') as f: + for i, data in enumerate(dataset): + grp = f.create_group(f'config_{i}') + grp["num_nodes"] = data.num_nodes + grp["edge_index"] = data.edge_index + grp["positions"] = data.positions + grp["shifts"] = data.shifts + grp["unit_shifts"] = data.unit_shifts + grp["cell"] = data.cell + grp["node_attrs"] = data.node_attrs + grp["weight"] = data.weight + grp["energy_weight"] = data.energy_weight + grp["forces_weight"] = data.forces_weight + grp["stress_weight"] = data.stress_weight + grp["virials_weight"] = data.virials_weight + grp["forces"] = data.forces + grp["energy"] = data.energy + grp["stress"] = data.stress + grp["virials"] = data.virials + grp["dipole"] = data.dipole + grp["charges"] = data.charges diff --git a/mace/tools/__init__.py b/mace/tools/__init__.py index 3fcd4537..82aeaa6b 100644 --- a/mace/tools/__init__.py +++ b/mace/tools/__init__.py @@ -1,4 +1,4 @@ -from .arg_parser import build_default_arg_parser +from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser from .cg import U_matrix_real from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState from .torch_tools import ( @@ -64,4 +64,5 @@ "cartesian_to_spherical", "voigt_to_matrix", "init_wandb", + "build_preprocess_arg_parser" ] diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index dd13656a..e0d09883 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -84,7 +84,9 @@ def build_default_arg_parser() -> argparse.ArgumentParser: ], ) parser.add_argument( - "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 + "--r_max", help="distance cutoff (in Ang)", + type=float, + default=5.0 ) parser.add_argument( "--num_radial_basis", @@ -193,11 +195,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: # Dataset parser.add_argument( - "--train_file", help="Training set xyz file", type=str, required=True + "--train_file", help="Training set file, format is .xyz or .h5", type=str, + required=True, ) parser.add_argument( "--valid_file", - help="Validation set xyz file", + help="Validation set .xyz or .h5 file", default=None, type=str, required=False, @@ -211,8 +214,55 @@ def build_default_arg_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--test_file", - help="Test set xyz file", + help="Test set .xyz pt .h5 file", + type=str, + ) + parser.add_argument( + "--test_dir", + help="Path to directory with test files named as test_*.h5", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--num_workers", + help="Number of workers for data loading", + type=int, + default=0, + ) + parser.add_argument( + "--pin_memory", + help="Pin memory for data loading", + default=True, + type=bool, + ) + parser.add_argument( + "--atomic_numbers", + help="List of atomic numbers", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--mean", + help="Mean energy per atom of training set", + type=float, + default=None, + required=False, + ) + parser.add_argument( + "--std", + help="Standard deviation of force components in the training set", + type=float, + default=None, + required=False, + ) + parser.add_argument( + "--statistics_file", + help="json file containing statistics of training set", type=str, + default=None, + required=False, ) parser.add_argument( "--E0s", @@ -471,6 +521,143 @@ def build_default_arg_parser() -> argparse.ArgumentParser: ) return parser +def build_preprocess_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--train_file", + help="Training set h5 file", + type=str, + default=None, + required=True, + ) + parser.add_argument( + "--valid_file", + help="Training set xyz file", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--valid_fraction", + help="Fraction of training set used for validation", + type=float, + default=0.1, + required=False, + ) + parser.add_argument( + "--test_file", + help="Test set xyz file", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--h5_prefix", + help="Prefix for h5 files when saving", + type=str, + default="", + ) + parser.add_argument( + "--r_max", help="distance cutoff (in Ang)", + type=float, + default=5.0 + ) + parser.add_argument( + "--config_type_weights", + help="String of dictionary containing the weights for each config type", + type=str, + default='{"Default":1.0}', + ) + parser.add_argument( + "--energy_key", + help="Key of reference energies in training xyz", + type=str, + default="energy", + ) + parser.add_argument( + "--forces_key", + help="Key of reference forces in training xyz", + type=str, + default="forces", + ) + parser.add_argument( + "--virials_key", + help="Key of reference virials in training xyz", + type=str, + default="virials", + ) + parser.add_argument( + "--stress_key", + help="Key of reference stress in training xyz", + type=str, + default="stress", + ) + parser.add_argument( + "--dipole_key", + help="Key of reference dipoles in training xyz", + type=str, + default="dipole", + ) + parser.add_argument( + "--charges_key", + help="Key of atomic charges in training xyz", + type=str, + default="charges", + ) + parser.add_argument( + "--atomic_numbers", + help="List of atomic numbers", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--compute_statistics", + help="Compute statistics for the dataset", + action="store_true", + default=False, + ) + parser.add_argument( + "--batch_size", + help="batch size to compute average number of neighbours", + type=int, + default=16, + ) + + parser.add_argument( + "--scaling", + help="type of scaling to the output", + type=str, + default="rms_forces_scaling", + choices=["std_scaling", "rms_forces_scaling", "no_scaling"], + ) + parser.add_argument( + "--E0s", + help="Dictionary of isolated atom energies", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--shuffle", + help="Shuffle the training dataset", + type=bool, + default=True, + ) + parser.add_argument( + "--num_workers", + help="Number of threads to use for data writing", + type=int, + default=1, + ) + parser.add_argument( + "--seed", + help="Random seed for splitting training and validation sets", + type=int, + default=123, + ) + return parser + def check_float_or_none(value: str) -> Optional[float]: try: diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index b5a2c992..208151ad 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -6,12 +6,14 @@ import dataclasses import logging +import ast +import os from typing import Dict, List, Optional, Tuple import torch from prettytable import PrettyTable -from mace import data +from mace import data, modules from mace.data import AtomicData from mace.tools import AtomicNumberTable, evaluate, torch_geometric @@ -97,13 +99,118 @@ def get_dataset_from_xyz( atomic_energies_dict, ) +def get_config_type_weights(ct_weights): + """ + Get config type weights from command line argument + """ + try: + config_type_weights = ast.literal_eval(ct_weights) + assert isinstance(config_type_weights, dict) + except Exception as e: # pylint: disable=W0703 + logging.warning( + f"Config type weights not specified correctly ({e}), using Default" + ) + config_type_weights = {"Default": 1.0} + return config_type_weights + +def get_atomic_energies(E0s, train_collection, z_table)->dict: + if E0s is not None: + logging.info( + "Atomic Energies not in training file, using command line argument E0s" + ) + if E0s.lower() == "average": + logging.info( + "Computing average Atomic Energies using least squares regression" + ) + # catch if colections.train not defined above + try: + assert train_collection is not None + atomic_energies_dict = data.compute_average_E0s( + train_collection, z_table + ) + except Exception as e: + raise RuntimeError( + f"Could not compute average E0s if no training xyz given, error {e} occured" + ) from e + else: + try: + atomic_energies_dict = ast.literal_eval(E0s) + assert isinstance(atomic_energies_dict, dict) + except Exception as e: + raise RuntimeError( + f"E0s specified invalidly, error {e} occured" + ) from e + else: + raise RuntimeError( + "E0s not found in training file and not specified in command line" + ) + return atomic_energies_dict + +def get_loss_fn(loss: str, + energy_weight: float, + forces_weight: float, + stress_weight: float, + virials_weight: float, + dipole_weight: float, + dipole_only: bool, + compute_dipole: bool) -> torch.nn.Module: + if loss == "weighted": + loss_fn = modules.WeightedEnergyForcesLoss( + energy_weight=energy_weight, forces_weight=forces_weight + ) + elif loss == "forces_only": + loss_fn = modules.WeightedForcesLoss(forces_weight=forces_weight) + elif loss == "virials": + loss_fn = modules.WeightedEnergyForcesVirialsLoss( + energy_weight=energy_weight, + forces_weight=forces_weight, + virials_weight=virials_weight, + ) + elif loss == "stress": + loss_fn = modules.WeightedEnergyForcesStressLoss( + energy_weight=energy_weight, + forces_weight=forces_weight, + stress_weight=stress_weight, + ) + elif loss == "dipole": + assert ( + dipole_only is True + ), "dipole loss can only be used with AtomicDipolesMACE model" + loss_fn = modules.DipoleSingleLoss( + dipole_weight=dipole_weight, + ) + elif loss == "energy_forces_dipole": + assert dipole_only is False and compute_dipole is True + loss_fn = modules.WeightedEnergyForcesDipoleLoss( + energy_weight=energy_weight, + forces_weight=forces_weight, + dipole_weight=dipole_weight, + ) + else: + loss_fn = modules.EnergyForcesLoss( + energy_weight=energy_weight, forces_weight=forces_weight + ) + return loss_fn + +def get_files_with_suffix(dir_path:str, suffix:str)-> List[str]: + return [os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith(suffix)] + +def custom_key(key): + """ + Helper function to sort the keys of the data loader dictionary + to ensure that the training set, and validation set + are evaluated first + """ + if key == 'train': + return (0, key) + elif key == 'valid': + return (1, key) + else: + return (2, key) def create_error_table( table_type: str, - all_collections: list, - z_table: AtomicNumberTable, - r_max: float, - valid_batch_size: int, + all_data_loaders: dict, model: torch.nn.Module, loss_fn: torch.nn.Module, output_args: Dict[str, bool], @@ -170,17 +277,9 @@ def create_error_table( "RMSE MU / mDebye / atom", "rel MU RMSE %", ] - for name, subset in all_collections: - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - AtomicData.from_config(config, z_table=z_table, cutoff=r_max) - for config in subset - ], - batch_size=valid_batch_size, - shuffle=False, - drop_last=False, - ) - + + for name in sorted(all_data_loaders, key=custom_key): + data_loader = all_data_loaders[name] logging.info(f"Evaluating {name} ...") _, metrics = evaluate( model, diff --git a/scripts/preprocess_data.py b/scripts/preprocess_data.py new file mode 100644 index 00000000..9519e929 --- /dev/null +++ b/scripts/preprocess_data.py @@ -0,0 +1,155 @@ +# This file loads an xyz dataset and prepares +# new hdf5 file that is ready for training with on-the-fly dataloading + +import logging +import ast +import numpy as np +import json +import random + +from ase.io import read +import torch + +from mace import tools, data +from mace.data.utils import save_dataset_as_HDF5 +from mace.tools.scripts_utils import (get_dataset_from_xyz, + get_atomic_energies) +from mace.tools import torch_geometric +from mace.modules import compute_avg_num_neighbors, scaling_classes + +def compute_statistics(train_loader: torch.utils.data.DataLoader, + scaling: str, + atomic_energies: np.ndarray): + """ + Compute the average number of neighbors and the mean energy and standard + deviation of the force components""" + avg_num_neighbors = compute_avg_num_neighbors(train_loader) + mean, std = scaling_classes[scaling](train_loader, atomic_energies) + return avg_num_neighbors, mean, std + +def main(): + """ + This script loads an xyz dataset and prepares + new hdf5 file that is ready for training with on-the-fly dataloading + """ + + args = tools.build_preprocess_arg_parser().parse_args() + + # Setup + tools.set_seeds(args.seed) + random.seed(args.seed) + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s %(levelname)-8s %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[logging.StreamHandler()] + ) + + try: + config_type_weights = ast.literal_eval(args.config_type_weights) + assert isinstance(config_type_weights, dict) + except Exception as e: # pylint: disable=W0703 + logging.warning( + f"Config type weights not specified correctly ({e}), using Default" + ) + config_type_weights = {"Default": 1.0} + + # Data preparation + collections, atomic_energies_dict = get_dataset_from_xyz( + train_path=args.train_file, + valid_path=args.valid_file, + valid_fraction=args.valid_fraction, + config_type_weights=config_type_weights, + test_path=args.test_file, + seed=args.seed, + energy_key=args.energy_key, + forces_key=args.forces_key, + stress_key=args.stress_key, + virials_key=args.virials_key, + dipole_key=args.dipole_key, + charges_key=args.charges_key, + ) + + # Atomic number table + # yapf: disable + if args.atomic_numbers is None: + z_table = tools.get_atomic_number_table_from_zs( + z + for configs in (collections.train, collections.valid) + for config in configs + for z in config.atomic_numbers + ) + else: + logging.info("Using atomic numbers from command line argument") + zs_list = ast.literal_eval(args.atomic_numbers) + assert isinstance(zs_list, list) + z_table = tools.get_atomic_number_table_from_zs(zs_list) + + logging.info("Preparing training set") + training_set = [data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max) + for config in collections.train] + if args.shuffle: + random.shuffle(training_set) + + save_dataset_as_HDF5(training_set, + args.h5_prefix + "train.h5", + args.num_workers) + + if args.compute_statistics: + # Compute statistics + logging.info("Computing statistics") + if len(atomic_energies_dict) == 0: + atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table) + atomic_energies: np.ndarray = np.array( + [atomic_energies_dict[z] for z in z_table.zs] + ) + logging.info(f"Atomic energies: {atomic_energies.tolist()}") + train_loader = torch_geometric.dataloader.DataLoader( + training_set, + batch_size=args.batch_size, + shuffle=False, + drop_last=False, + ) + avg_num_neighbors, mean, std = compute_statistics( + train_loader, args.scaling, atomic_energies + ) + logging.info(f"Average number of neighbors: {avg_num_neighbors}") + logging.info(f"Mean: {mean}") + logging.info(f"Standard deviation: {std}") + + # save the statistics as a json + statistics = { + "atomic_energies": atomic_energies_dict, + "avg_num_neighbors": avg_num_neighbors, + "mean": mean, + "std": std, + "atomic_numbers": z_table.zs, + "r_max": args.r_max, + } + with open(args.h5_prefix + "statistics.json", "w") as f: + json.dump(statistics, f) + + logging.info("Preparing validation set") + valid_set = [data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max) + for config in collections.valid] + if args.shuffle: + random.shuffle(valid_set) + + save_dataset_as_HDF5(valid_set, + args.h5_prefix + "valid.h5", + args.num_workers) + + if args.test_file is not None: + logging.info("Preparing test sets") + for name, subset in collections.tests: + test_set = [data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max) + for config in subset] + save_dataset_as_HDF5(test_set, + args.h5_prefix + name + "_test.h5", + args.num_workers) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/run_train.py b/scripts/run_train.py index b48af269..6dbca02c 100644 --- a/scripts/run_train.py +++ b/scripts/run_train.py @@ -9,6 +9,7 @@ from pathlib import Path from typing import Optional import json +import os import numpy as np import torch.nn.functional @@ -19,8 +20,13 @@ import mace from mace import data, modules, tools from mace.tools import torch_geometric -from mace.tools.scripts_utils import create_error_table, get_dataset_from_xyz - +from mace.tools.scripts_utils import (create_error_table, + get_dataset_from_xyz, + get_atomic_energies, + get_config_type_weights, + get_loss_fn, + get_files_with_suffix) +from mace.data import HDF5Dataset def main() -> None: args = tools.build_default_arg_parser().parse_args() @@ -37,46 +43,78 @@ def main() -> None: device = tools.init_device(args.device) tools.set_default_dtype(args.default_dtype) - try: - config_type_weights = ast.literal_eval(args.config_type_weights) - assert isinstance(config_type_weights, dict) - except Exception as e: # pylint: disable=W0703 - logging.warning( - f"Config type weights not specified correctly ({e}), using Default" - ) - config_type_weights = {"Default": 1.0} + config_type_weights = get_config_type_weights(args.config_type_weights) + + if args.statistics_file is not None: + with open(args.statistics_file, "r") as f: + statistics = json.load(f) + logging.info("Using statistics json file") + args.r_max = statistics["r_max"] + args.atomic_numbers = str(statistics["atomic_numbers"]) + args.mean = statistics["mean"] + args.std = statistics["std"] + args.avg_num_neighbors = statistics["avg_num_neighbors"] + args.compute_avg_num_neighbors = False + parsed_atomic_energies_dict = statistics["atomic_energies"] + str_atomic_energies_dict = {} + for key, value in parsed_atomic_energies_dict.items(): + str_atomic_energies_dict[int(key)] = value + args.E0s = str(str_atomic_energies_dict) # Data preparation - collections, atomic_energies_dict = get_dataset_from_xyz( - train_path=args.train_file, - valid_path=args.valid_file, - valid_fraction=args.valid_fraction, - config_type_weights=config_type_weights, - test_path=args.test_file, - seed=args.seed, - energy_key=args.energy_key, - forces_key=args.forces_key, - stress_key=args.stress_key, - virials_key=args.virials_key, - dipole_key=args.dipole_key, - charges_key=args.charges_key, - ) - - logging.info( - f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, " - f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]" - ) + if args.train_file.endswith(".xyz"): + if args.valid_file is not None: + assert args.valid_file.endswith(".xyz"), "valid_file if given must be same format as train_file" + collections, atomic_energies_dict = get_dataset_from_xyz( + train_path=args.train_file, + valid_path=args.valid_file, + valid_fraction=args.valid_fraction, + config_type_weights=config_type_weights, + test_path=args.test_file, + seed=args.seed, + energy_key=args.energy_key, + forces_key=args.forces_key, + stress_key=args.stress_key, + virials_key=args.virials_key, + dipole_key=args.dipole_key, + charges_key=args.charges_key, + ) + logging.info( + f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, " + f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]" + ) + elif args.train_file.endswith(".h5"): + atomic_energies_dict = None + else: + raise RuntimeError( + f"train_file must be either .xyz or .h5, got {args.train_file}" + ) + # Atomic number table # yapf: disable - z_table = tools.get_atomic_number_table_from_zs( - z - for configs in (collections.train, collections.valid) - for config in configs - for z in config.atomic_numbers - ) + if args.atomic_numbers is None: + assert args.train_file.endswith(".xyz"), "Must specify atomic_numbers when using .h5 train_file input" + z_table = tools.get_atomic_number_table_from_zs( + z + for configs in (collections.train, collections.valid) + for config in configs + for z in config.atomic_numbers + ) + else: + logging.info("Using atomic numbers from command line argument") + zs_list = ast.literal_eval(args.atomic_numbers) + assert isinstance(zs_list, list) + z_table = tools.get_atomic_number_table_from_zs(zs_list) # yapf: enable logging.info(z_table) + + if atomic_energies_dict is None or len(atomic_energies_dict) == 0: + if args.train_file.endswith(".xyz"): + atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table) + else: + atomic_energies_dict = get_atomic_energies(args.E0s, None, z_table) + if args.model == "AtomicDipolesMACE": atomic_energies = None dipole_only = True @@ -96,91 +134,63 @@ def main() -> None: else: compute_energy = True compute_dipole = False - if atomic_energies_dict is None or len(atomic_energies_dict) == 0: - if args.E0s is not None: - logging.info( - "Atomic Energies not in training file, using command line argument E0s" - ) - if args.E0s.lower() == "average": - logging.info( - "Computing average Atomic Energies using least squares regression" - ) - atomic_energies_dict = data.compute_average_E0s( - collections.train, z_table - ) - else: - try: - atomic_energies_dict = ast.literal_eval(args.E0s) - assert isinstance(atomic_energies_dict, dict) - except Exception as e: - raise RuntimeError( - f"E0s specified invalidly, error {e} occured" - ) from e - else: - raise RuntimeError( - "E0s not found in training file and not specified in command line" - ) + atomic_energies: np.ndarray = np.array( [atomic_energies_dict[z] for z in z_table.zs] ) logging.info(f"Atomic energies: {atomic_energies.tolist()}") - train_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) - for config in collections.train - ], - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - ) - valid_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) - for config in collections.valid - ], - batch_size=args.valid_batch_size, - shuffle=False, - drop_last=False, - ) - - loss_fn: torch.nn.Module - if args.loss == "weighted": - loss_fn = modules.WeightedEnergyForcesLoss( - energy_weight=args.energy_weight, forces_weight=args.forces_weight - ) - elif args.loss == "forces_only": - loss_fn = modules.WeightedForcesLoss(forces_weight=args.forces_weight) - elif args.loss == "virials": - loss_fn = modules.WeightedEnergyForcesVirialsLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - virials_weight=args.virials_weight, - ) - elif args.loss == "stress": - loss_fn = modules.WeightedEnergyForcesStressLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - stress_weight=args.stress_weight, - ) - elif args.loss == "dipole": - assert ( - dipole_only is True - ), "dipole loss can only be used with AtomicDipolesMACE model" - loss_fn = modules.DipoleSingleLoss( - dipole_weight=args.dipole_weight, + if args.train_file.endswith(".xyz"): + # TODO remove code duplication here + train_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) + for config in collections.train + ], + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + num_workers=args.num_workers, ) - elif args.loss == "energy_forces_dipole": - assert dipole_only is False and compute_dipole is True - loss_fn = modules.WeightedEnergyForcesDipoleLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - dipole_weight=args.dipole_weight, + valid_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) + for config in collections.valid + ], + batch_size=args.valid_batch_size, + shuffle=False, + drop_last=False, + num_workers=args.num_workers, ) else: - loss_fn = modules.EnergyForcesLoss( - energy_weight=args.energy_weight, forces_weight=args.forces_weight - ) + training_set_processed = HDF5Dataset(args.train_file) + train_loader = torch_geometric.dataloader.DataLoader( + training_set_processed, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + num_workers=args.num_workers, + pin_memory=args.pin_memory) + + validation_set_processed = HDF5Dataset(args.valid_file) + valid_loader = torch_geometric.dataloader.DataLoader( + validation_set_processed, + batch_size=args.valid_batch_size, + shuffle=False, + drop_last=False, + num_workers=args.num_workers, + pin_memory=args.pin_memory) + + loss_fn: torch.nn.Module = get_loss_fn( + args.loss, + args.energy_weight, + args.forces_weight, + args.stress_weight, + args.virials_weight, + args.dipole_weight, + dipole_only, + compute_dipole, + ) logging.info(loss_fn) if args.compute_avg_num_neighbors: @@ -218,6 +228,7 @@ def main() -> None: if args.max_L > 2: args.hidden_irreps += f" + {args.num_channels:d}x3o" logging.info(f"Hidden irreps: {args.hidden_irreps}") + model_config = dict( r_max=args.r_max, num_bessel=args.num_radial_basis, @@ -234,14 +245,13 @@ def main() -> None: model: torch.nn.Module + if args.scaling == "no_scaling": + args.std = 1.0 + logging.info("No scaling selected") + elif args.mean is None or args.std is None: + args.mean, args.std = modules.scaling_classes[args.scaling](train_loader, atomic_energies) + if args.model == "MACE": - if args.scaling == "no_scaling": - std = 1.0 - logging.info("No scaling selected") - else: - mean, std = modules.scaling_classes[args.scaling]( - train_loader, atomic_energies - ) model = modules.ScaleShiftMACE( **model_config, correlation=args.correlation, @@ -250,29 +260,27 @@ def main() -> None: "RealAgnosticInteractionBlock" ], MLP_irreps=o3.Irreps(args.MLP_irreps), - atomic_inter_scale=std, + atomic_inter_scale=args.std, atomic_inter_shift=0.0, ) elif args.model == "ScaleShiftMACE": - mean, std = modules.scaling_classes[args.scaling](train_loader, atomic_energies) model = modules.ScaleShiftMACE( **model_config, correlation=args.correlation, gate=modules.gate_dict[args.gate], interaction_cls_first=modules.interaction_classes[args.interaction_first], MLP_irreps=o3.Irreps(args.MLP_irreps), - atomic_inter_scale=std, - atomic_inter_shift=mean, + atomic_inter_scale=args.std, + atomic_inter_shift=args.mean, ) elif args.model == "ScaleShiftBOTNet": - mean, std = modules.scaling_classes[args.scaling](train_loader, atomic_energies) model = modules.ScaleShiftBOTNet( **model_config, gate=modules.gate_dict[args.gate], interaction_cls_first=modules.interaction_classes[args.interaction_first], MLP_irreps=o3.Irreps(args.MLP_irreps), - atomic_inter_scale=std, - atomic_inter_shift=mean, + atomic_inter_scale=args.std, + atomic_inter_shift=args.mean, ) elif args.model == "BOTNet": model = modules.BOTNet( @@ -508,13 +516,40 @@ def main() -> None: log_wandb=args.wandb, ) - # Evaluation on test datasets logging.info("Computing metrics for training, validation, and test sets") - - all_collections = [ - ("train", collections.train), - ("valid", collections.valid), - ] + collections.tests + all_data_loaders = { + "train": train_loader, + "valid": valid_loader, + } + if args.train_file.endswith(".xyz"): + for name, subset in collections.tests: + test_set = [data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max) + for config in subset] + test_loader = torch_geometric.dataloader.DataLoader( + test_set, + batch_size=args.valid_batch_size, + shuffle=False, + num_workers=args.num_workers, + drop_last=False, + ) + all_data_loaders[name] = test_loader + else: + # get all test paths + test_files = get_files_with_suffix( + args.test_dir, "_test.h5" + ) + for test_file in test_files: + test_set = HDF5Dataset(test_file) + test_loader = torch_geometric.dataloader.DataLoader( + test_set, + batch_size=args.valid_batch_size, + shuffle=False, + drop_last=False, + num_workers=args.num_workers, + pin_memory=args.pin_memory) + test_file_name = os.path.splitext(os.path.basename(test_file))[0] + all_data_loaders[test_file_name] = test_loader for swa_eval in swas: epoch = checkpoint_handler.load_latest( @@ -527,10 +562,7 @@ def main() -> None: table = create_error_table( table_type=args.error_table, - all_collections=all_collections, - z_table=z_table, - r_max=args.r_max, - valid_batch_size=args.valid_batch_size, + all_data_loaders=all_data_loaders, model=model, loss_fn=loss_fn, output_args=output_args, @@ -556,6 +588,5 @@ def main() -> None: logging.info("Done") - if __name__ == "__main__": main()