From 3c5a77764fee2cbe48cc989dc7d614f95d178c7e Mon Sep 17 00:00:00 2001 From: "D.P. Kovacs" Date: Tue, 7 Feb 2023 10:34:46 +0000 Subject: [PATCH] debug and clean up evaluate code --- scripts/eval_configs.py | 90 +++++++++++++++++++++++++++++------------ 1 file changed, 64 insertions(+), 26 deletions(-) diff --git a/scripts/eval_configs.py b/scripts/eval_configs.py index b2ac5e49..587fe2dc 100644 --- a/scripts/eval_configs.py +++ b/scripts/eval_configs.py @@ -11,15 +11,27 @@ import numpy as np import torch -from mace import data, tools -from mace.tools import torch_geometric +from mace import data +from mace.tools import torch_geometric, utils, torch_tools def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument("--configs", help="path to XYZ configurations", required=True) - parser.add_argument("--model", help="path to model", required=True) - parser.add_argument("--output", help="output path", required=True) + parser.add_argument( + "--configs", + help="path to XYZ configurations", + required=True + ) + parser.add_argument( + "--model", + help="path to model", + required=True + ) + parser.add_argument( + "--output", + help="output path", + required=True + ) parser.add_argument( "--device", help="select device", @@ -34,33 +46,50 @@ def parse_args() -> argparse.Namespace: choices=["float32", "float64"], default="float64", ) - parser.add_argument("--batch_size", help="batch size", type=int, default=64) parser.add_argument( - "--no_contributions", - help="model does not output energy contributions ", + "--batch_size", + help="batch size", + type=int, + default=64 + ) + parser.add_argument( + "--compute_stress", + help="compute stress", action="store_true", - default=True, + default=False, + ) + parser.add_argument( + "--return_contributions", + help="model outputs energy contributions for each body order, only suppported for MACE, not ScaleShiftMACE", + action="store_true", + default=False, + ) + parser.add_argument( + "--info_prefix", + help="prefix for energy, forces and stress keys", + type=str, + default="MACE_", ) return parser.parse_args() def main(): args = parse_args() - tools.set_default_dtype(args.default_dtype) - device = tools.init_device(args.device) + torch_tools.set_default_dtype(args.default_dtype) + device = torch_tools.init_device(args.device) # Load model - model = torch.load(f=args.model, map_location=device) + model = torch.load(f=args.model, map_location=args.device) # Load data and prepare input - atoms_list = ase.io.read(args.configs, format="extxyz", index=":") + atoms_list = ase.io.read(args.configs, index=":") configs = [data.config_from_atoms(atoms) for atoms in atoms_list] - z_table = tools.AtomicNumberTable([int(z) for z in model.atomic_numbers]) + z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers]) data_loader = torch_geometric.dataloader.DataLoader( dataset=[ - data.AtomicData.from_config(config, z_table=z_table, cutoff=model.r_max) + data.AtomicData.from_config(config, z_table=z_table, cutoff=float(model.r_max)) for config in configs ], batch_size=args.batch_size, @@ -71,18 +100,21 @@ def main(): # Collect data energies_list = [] contributions_list = [] + stresses_list = [] forces_collection = [] for batch in data_loader: batch = batch.to(device) - output = model(batch, training=False) - energies_list.append(tools.to_numpy(output["energy"])) + output = model(batch.to_dict(), compute_stress=args.compute_stress) + energies_list.append(torch_tools.to_numpy(output["energy"])) + if args.compute_stress: + stresses_list.append(torch_tools.to_numpy(output["stress"])) - if not args.no_contributions: - contributions_list.append(tools.to_numpy(output["contributions"])) + if args.return_contributions: + contributions_list.append(torch_tools.to_numpy(output["contributions"])) forces = np.split( - tools.to_numpy(output["forces"]), indices_or_sections=batch.ptr[1:], axis=0 + torch_tools.to_numpy(output["forces"]), indices_or_sections=batch.ptr[1:], axis=0 ) forces_collection.append(forces[:-1]) # drop last as its emtpy @@ -90,20 +122,26 @@ def main(): forces_list = [ forces for forces_list in forces_collection for forces in forces_list ] - assert len(atoms_list) == len(energies) == len(forces_list) + assert len(atoms_list) == len(energies) == len(forces_list) + if args.compute_stress: + stresses = np.concatenate(stresses_list, axis=0) + assert len(atoms_list) == stresses.shape[0] - if not args.no_contributions: + if args.return_contributions: contributions = np.concatenate(contributions_list, axis=0) assert len(atoms_list) == contributions.shape[0] # Store data in atoms objects for i, (atoms, energy, forces) in enumerate(zip(atoms_list, energies, forces_list)): atoms.calc = None # crucial - atoms.info["energy"] = energy - atoms.arrays["forces"] = forces + atoms.info[args.info_prefix + "energy"] = energy + atoms.arrays[args.info_prefix + "forces"] = forces + + if args.compute_stress: + atoms.info[args.info_prefix + "stress"] = stresses[i] - if not args.no_contributions: - atoms.info["contributions"] = contributions[i] + if args.return_contributions: + atoms.info[args.info_prefix + "BO_contributions"] = contributions[i] # Write atoms to output path ase.io.write(args.output, images=atoms_list, format="extxyz")