Skip to content

Commit

Permalink
Merge pull request #100 from chaitjo/patch-1
Browse files Browse the repository at this point in the history
Fix unweighted EnergyForcesLoss
  • Loading branch information
ilyes319 authored May 24, 2023
2 parents f6d48af + 3655a91 commit 538b03c
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 29 deletions.
2 changes: 0 additions & 2 deletions mace/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
)
from .loss import (
DipoleSingleLoss,
EnergyForcesLoss,
WeightedEnergyForcesDipoleLoss,
WeightedEnergyForcesLoss,
WeightedEnergyForcesStressLoss,
Expand Down Expand Up @@ -87,7 +86,6 @@
"ScaleShiftBOTNet",
"AtomicDipolesMACE",
"EnergyDipolesMACE",
"EnergyForcesLoss",
"WeightedEnergyForcesLoss",
"WeightedForcesLoss",
"WeightedEnergyForcesVirialsLoss",
Expand Down
24 changes: 0 additions & 24 deletions mace/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,30 +77,6 @@ def weighted_mean_squared_error_dipole(ref: Batch, pred: TensorDict) -> torch.Te
# return torch.mean(torch.square((torch.reshape(ref['dipole'], pred["dipole"].shape) - pred['dipole']) / num_atoms)) # []


class EnergyForcesLoss(torch.nn.Module):
def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None:
super().__init__()
self.register_buffer(
"energy_weight",
torch.tensor(energy_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"forces_weight",
torch.tensor(forces_weight, dtype=torch.get_default_dtype()),
)

def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor:
return self.energy_weight * mean_squared_error_energy(
ref, pred
) + self.forces_weight * mean_squared_error_forces(ref, pred)

def __repr__(self):
return (
f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
f"forces_weight={self.forces_weight:.3f})"
)


class WeightedEnergyForcesLoss(torch.nn.Module):
def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None:
super().__init__()
Expand Down
5 changes: 2 additions & 3 deletions scripts/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,8 @@ def main() -> None:
dipole_weight=args.dipole_weight,
)
else:
loss_fn = modules.EnergyForcesLoss(
energy_weight=args.energy_weight, forces_weight=args.forces_weight
)
# Unweighted Energy and Forces loss by default
loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0)
logging.info(loss_fn)

if args.compute_avg_num_neighbors:
Expand Down

0 comments on commit 538b03c

Please sign in to comment.