diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py index 82a57a47..48ea435d 100644 --- a/mace/modules/__init__.py +++ b/mace/modules/__init__.py @@ -21,7 +21,6 @@ ) from .loss import ( DipoleSingleLoss, - EnergyForcesLoss, WeightedEnergyForcesDipoleLoss, WeightedEnergyForcesLoss, WeightedEnergyForcesStressLoss, @@ -86,7 +85,6 @@ "ScaleShiftBOTNet", "AtomicDipolesMACE", "EnergyDipolesMACE", - "EnergyForcesLoss", "WeightedEnergyForcesLoss", "WeightedForcesLoss", "WeightedEnergyForcesVirialsLoss", diff --git a/mace/modules/loss.py b/mace/modules/loss.py index 18fc3921..9f163bec 100644 --- a/mace/modules/loss.py +++ b/mace/modules/loss.py @@ -77,30 +77,6 @@ def weighted_mean_squared_error_dipole(ref: Batch, pred: TensorDict) -> torch.Te # return torch.mean(torch.square((torch.reshape(ref['dipole'], pred["dipole"].shape) - pred['dipole']) / num_atoms)) # [] -class EnergyForcesLoss(torch.nn.Module): - def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None: - super().__init__() - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - - def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: - return self.energy_weight * mean_squared_error_energy( - ref, pred - ) + self.forces_weight * mean_squared_error_forces(ref, pred) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f})" - ) - - class WeightedEnergyForcesLoss(torch.nn.Module): def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None: super().__init__() diff --git a/scripts/eval_configs.py b/scripts/eval_configs.py index 587fe2dc..4f45aece 100644 --- a/scripts/eval_configs.py +++ b/scripts/eval_configs.py @@ -12,26 +12,14 @@ import torch from mace import data -from mace.tools import torch_geometric, utils, torch_tools +from mace.tools import torch_geometric, torch_tools, utils def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument( - "--configs", - help="path to XYZ configurations", - required=True - ) - parser.add_argument( - "--model", - help="path to model", - required=True - ) - parser.add_argument( - "--output", - help="output path", - required=True - ) + parser.add_argument("--configs", help="path to XYZ configurations", required=True) + parser.add_argument("--model", help="path to model", required=True) + parser.add_argument("--output", help="output path", required=True) parser.add_argument( "--device", help="select device", @@ -46,12 +34,7 @@ def parse_args() -> argparse.Namespace: choices=["float32", "float64"], default="float64", ) - parser.add_argument( - "--batch_size", - help="batch size", - type=int, - default=64 - ) + parser.add_argument("--batch_size", help="batch size", type=int, default=64) parser.add_argument( "--compute_stress", help="compute stress", @@ -89,7 +72,9 @@ def main(): data_loader = torch_geometric.dataloader.DataLoader( dataset=[ - data.AtomicData.from_config(config, z_table=z_table, cutoff=float(model.r_max)) + data.AtomicData.from_config( + config, z_table=z_table, cutoff=float(model.r_max) + ) for config in configs ], batch_size=args.batch_size, @@ -114,7 +99,9 @@ def main(): contributions_list.append(torch_tools.to_numpy(output["contributions"])) forces = np.split( - torch_tools.to_numpy(output["forces"]), indices_or_sections=batch.ptr[1:], axis=0 + torch_tools.to_numpy(output["forces"]), + indices_or_sections=batch.ptr[1:], + axis=0, ) forces_collection.append(forces[:-1]) # drop last as its emtpy @@ -122,7 +109,7 @@ def main(): forces_list = [ forces for forces_list in forces_collection for forces in forces_list ] - assert len(atoms_list) == len(energies) == len(forces_list) + assert len(atoms_list) == len(energies) == len(forces_list) if args.compute_stress: stresses = np.concatenate(stresses_list, axis=0) assert len(atoms_list) == stresses.shape[0] diff --git a/scripts/run_train.py b/scripts/run_train.py index 5554f121..9b8d50ef 100644 --- a/scripts/run_train.py +++ b/scripts/run_train.py @@ -5,10 +5,10 @@ ########################################################################################### import ast +import json import logging from pathlib import Path from typing import Optional -import json import numpy as np import torch.nn.functional @@ -178,9 +178,8 @@ def main() -> None: dipole_weight=args.dipole_weight, ) else: - loss_fn = modules.EnergyForcesLoss( - energy_weight=args.energy_weight, forces_weight=args.forces_weight - ) + # Unweighted Energy and Forces loss by default + loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0) logging.info(loss_fn) if args.compute_avg_num_neighbors: @@ -469,6 +468,7 @@ def main() -> None: if args.wandb: logging.info("Using Weights and Biases for logging") import wandb + wandb_config = {} args_dict = vars(args) args_dict_json = json.dumps(args_dict)