Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Debug and clean up evaluation code #72

Merged
merged 1 commit into from
Feb 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 64 additions & 26 deletions scripts/eval_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,27 @@
import numpy as np
import torch

from mace import data, tools
from mace.tools import torch_geometric
from mace import data
from mace.tools import torch_geometric, utils, torch_tools


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--configs", help="path to XYZ configurations", required=True)
parser.add_argument("--model", help="path to model", required=True)
parser.add_argument("--output", help="output path", required=True)
parser.add_argument(
"--configs",
help="path to XYZ configurations",
required=True
)
parser.add_argument(
"--model",
help="path to model",
required=True
)
parser.add_argument(
"--output",
help="output path",
required=True
)
parser.add_argument(
"--device",
help="select device",
Expand All @@ -34,33 +46,50 @@ def parse_args() -> argparse.Namespace:
choices=["float32", "float64"],
default="float64",
)
parser.add_argument("--batch_size", help="batch size", type=int, default=64)
parser.add_argument(
"--no_contributions",
help="model does not output energy contributions ",
"--batch_size",
help="batch size",
type=int,
default=64
)
parser.add_argument(
"--compute_stress",
help="compute stress",
action="store_true",
default=True,
default=False,
)
parser.add_argument(
"--return_contributions",
help="model outputs energy contributions for each body order, only suppported for MACE, not ScaleShiftMACE",
action="store_true",
default=False,
)
parser.add_argument(
"--info_prefix",
help="prefix for energy, forces and stress keys",
type=str,
default="MACE_",
)
return parser.parse_args()


def main():
args = parse_args()
tools.set_default_dtype(args.default_dtype)
device = tools.init_device(args.device)
torch_tools.set_default_dtype(args.default_dtype)
device = torch_tools.init_device(args.device)

# Load model
model = torch.load(f=args.model, map_location=device)
model = torch.load(f=args.model, map_location=args.device)

# Load data and prepare input
atoms_list = ase.io.read(args.configs, format="extxyz", index=":")
atoms_list = ase.io.read(args.configs, index=":")
configs = [data.config_from_atoms(atoms) for atoms in atoms_list]

z_table = tools.AtomicNumberTable([int(z) for z in model.atomic_numbers])
z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers])

data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(config, z_table=z_table, cutoff=model.r_max)
data.AtomicData.from_config(config, z_table=z_table, cutoff=float(model.r_max))
for config in configs
],
batch_size=args.batch_size,
Expand All @@ -71,39 +100,48 @@ def main():
# Collect data
energies_list = []
contributions_list = []
stresses_list = []
forces_collection = []

for batch in data_loader:
batch = batch.to(device)
output = model(batch, training=False)
energies_list.append(tools.to_numpy(output["energy"]))
output = model(batch.to_dict(), compute_stress=args.compute_stress)
energies_list.append(torch_tools.to_numpy(output["energy"]))
if args.compute_stress:
stresses_list.append(torch_tools.to_numpy(output["stress"]))

if not args.no_contributions:
contributions_list.append(tools.to_numpy(output["contributions"]))
if args.return_contributions:
contributions_list.append(torch_tools.to_numpy(output["contributions"]))

forces = np.split(
tools.to_numpy(output["forces"]), indices_or_sections=batch.ptr[1:], axis=0
torch_tools.to_numpy(output["forces"]), indices_or_sections=batch.ptr[1:], axis=0
)
forces_collection.append(forces[:-1]) # drop last as its emtpy

energies = np.concatenate(energies_list, axis=0)
forces_list = [
forces for forces_list in forces_collection for forces in forces_list
]
assert len(atoms_list) == len(energies) == len(forces_list)
assert len(atoms_list) == len(energies) == len(forces_list)
if args.compute_stress:
stresses = np.concatenate(stresses_list, axis=0)
assert len(atoms_list) == stresses.shape[0]

if not args.no_contributions:
if args.return_contributions:
contributions = np.concatenate(contributions_list, axis=0)
assert len(atoms_list) == contributions.shape[0]

# Store data in atoms objects
for i, (atoms, energy, forces) in enumerate(zip(atoms_list, energies, forces_list)):
atoms.calc = None # crucial
atoms.info["energy"] = energy
atoms.arrays["forces"] = forces
atoms.info[args.info_prefix + "energy"] = energy
atoms.arrays[args.info_prefix + "forces"] = forces

if args.compute_stress:
atoms.info[args.info_prefix + "stress"] = stresses[i]

if not args.no_contributions:
atoms.info["contributions"] = contributions[i]
if args.return_contributions:
atoms.info[args.info_prefix + "BO_contributions"] = contributions[i]

# Write atoms to output path
ase.io.write(args.output, images=atoms_list, format="extxyz")
Expand Down