Skip to content

Commit

Permalink
Merge pull request #443 from ACEsuit/no_stress_per_atom
Browse files Browse the repository at this point in the history
get rid of all stress/n_atoms
  • Loading branch information
ilyes319 authored Jun 6, 2024
2 parents 14ec4ec + d523746 commit 346999c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 54 deletions.
3 changes: 1 addition & 2 deletions mace/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,10 @@ def weighted_mean_squared_stress(ref: Batch, pred: TensorDict) -> torch.Tensor:
# energy: [n_graphs, ]
configs_weight = ref.weight.view(-1, 1, 1) # [n_graphs, ]
configs_stress_weight = ref.stress_weight.view(-1, 1, 1) # [n_graphs, ]
num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1) # [n_graphs,]
return torch.mean(
configs_weight
* configs_stress_weight
* torch.square((ref["stress"] - pred["stress"]) / num_atoms)
* torch.square(ref["stress"] - pred["stress"])
) # []


Expand Down
13 changes: 3 additions & 10 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ def valid_err_log(
)
elif (
log_errors == "PerAtomRMSEstressvirials"
and eval_metrics["rmse_stress_per_atom"] is not None
and eval_metrics["rmse_stress"] is not None
):
error_e = eval_metrics["rmse_e_per_atom"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3
error_stress = eval_metrics["rmse_stress"] * 1e3
logging.info(
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3"
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress={error_stress:.1f} meV / A^3"
)
elif (
log_errors == "PerAtomRMSEstressvirials"
Expand Down Expand Up @@ -405,7 +405,6 @@ def __init__(self, loss_fn: torch.nn.Module):
"stress_computed", default=torch.tensor(0.0), dist_reduce_fx="sum"
)
self.add_state("delta_stress", default=[], dist_reduce_fx="cat")
self.add_state("delta_stress_per_atom", default=[], dist_reduce_fx="cat")
self.add_state(
"virials_computed", default=torch.tensor(0.0), dist_reduce_fx="sum"
)
Expand Down Expand Up @@ -434,10 +433,6 @@ def update(self, batch, output): # pylint: disable=arguments-differ
if output.get("stress") is not None and batch.stress is not None:
self.stress_computed += 1.0
self.delta_stress.append(batch.stress - output["stress"])
self.delta_stress_per_atom.append(
(batch.stress - output["stress"])
/ (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1)
)
if output.get("virials") is not None and batch.virials is not None:
self.virials_computed += 1.0
self.delta_virials.append(batch.virials - output["virials"])
Expand Down Expand Up @@ -480,10 +475,8 @@ def compute(self):
aux["q95_f"] = compute_q95(delta_fs)
if self.stress_computed:
delta_stress = self.convert(self.delta_stress)
delta_stress_per_atom = self.convert(self.delta_stress_per_atom)
aux["mae_stress"] = compute_mae(delta_stress)
aux["rmse_stress"] = compute_rmse(delta_stress)
aux["rmse_stress_per_atom"] = compute_rmse(delta_stress_per_atom)
aux["q95_stress"] = compute_q95(delta_stress)
if self.virials_computed:
delta_virials = self.convert(self.delta_virials)
Expand Down
84 changes: 42 additions & 42 deletions tests/test_run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,30 +107,30 @@ def test_run_train(tmp_path, fitting_configs):
Es.append(at.get_potential_energy())

print("Es", Es)
# from a run on 28/03/2023 on main 88d49f9ed6925dec07d1777043a36e1fe4872ff3
# from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7
ref_Es = [
0.0,
0.0,
-0.03911274694160493,
-0.0913651377675312,
-0.14973695873658766,
-0.0664839502025434,
-0.09968814898703926,
0.1248460531971883,
-0.0647495831154953,
-0.14589298347245963,
0.12918668431788108,
-0.13996496272772996,
-0.053211348522482806,
0.07845141245421094,
-0.08901520083723416,
-0.15467129065263446,
0.007727727865546765,
-0.04502061132025605,
-0.035848783030374,
-0.24410687104937906,
-0.0839034724949955,
-0.14756571357354326,
-0.039181344585828524,
-0.0915223395136733,
-0.14953484236456582,
-0.06662480820063998,
-0.09983737353050133,
0.12477442296789745,
-0.06486086271762856,
-0.1460607988519944,
0.12886334908465508,
-0.14000990081920373,
-0.05319886578958313,
0.07780520158391,
-0.08895480281886901,
-0.15474719614734422,
0.007756765146527644,
-0.044879267197498685,
-0.036065736712447574,
-0.24413743841886623,
-0.0838104612106429,
-0.14751978636626545
]

assert np.allclose(Es, ref_Es)
Expand Down Expand Up @@ -178,30 +178,30 @@ def test_run_train_missing_data(tmp_path, fitting_configs):
Es.append(at.get_potential_energy())

print("Es", Es)
# from a run on 28/03/2023 on main 88d49f9ed6925dec07d1777043a36e1fe4872ff3
# from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7
ref_Es = [
0.0,
0.0,
-0.05449966431966507,
-0.11237663925685797,
0.03914539466246801,
-0.07500800414261456,
-0.13471106701173396,
0.02937255038020199,
-0.0652196693921633,
-0.14946129637190012,
0.19412338220281133,
-0.13546947741234333,
-0.05235148626886153,
-0.04957190959243316,
-0.07081384032242896,
-0.24575839901841345,
-0.0020512332640394916,
-0.038630330106902526,
-0.13621347044601181,
-0.2338465954158298,
-0.11777474787291177,
-0.14895508008918812,
-0.05464025113696155,
-0.11272131295940478,
0.039200919331076826,
-0.07517990972827505,
-0.13504202474582666,
0.0292022872055344,
-0.06541099574579018,
-0.1497824717832886,
0.19397709360828813,
-0.13587609467143014,
-0.05242956276828463,
-0.0504862057364953,
-0.07095795959430119,
-0.2463753796753703,
-0.002031543147676121,
-0.03864918790300681,
-0.13680153117705554,
-0.23418951968636786,
-0.11790833839379238,
-0.14930562311066484
]
assert np.allclose(Es, ref_Es)

Expand Down

0 comments on commit 346999c

Please sign in to comment.