From 0842e7c5cea983e51e58e3275ce83a40c67539ce Mon Sep 17 00:00:00 2001
From: Noam Bernstein
Date: Tue, 4 Jun 2024 16:51:24 -0400
Subject: [PATCH 1/3] get rid of all stress/n_atoms

---
 mace/tools/train.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/mace/tools/train.py b/mace/tools/train.py
index 32231acf..9306a171 100644
--- a/mace/tools/train.py
+++ b/mace/tools/train.py
@@ -60,13 +60,13 @@ def valid_err_log(
         )
     elif (
         log_errors == "PerAtomRMSEstressvirials"
-        and eval_metrics["rmse_stress_per_atom"] is not None
+        and eval_metrics["rmse_stress"] is not None
     ):
         error_e = eval_metrics["rmse_e_per_atom"] * 1e3
         error_f = eval_metrics["rmse_f"] * 1e3
-        error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3
+        error_stress = eval_metrics["rmse_stress"] * 1e3
         logging.info(
-            f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3"
+            f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress={error_stress:.1f} meV / A^3"
         )
     elif (
         log_errors == "PerAtomRMSEstressvirials"
@@ -405,7 +405,6 @@ def __init__(self, loss_fn: torch.nn.Module):
             "stress_computed", default=torch.tensor(0.0), dist_reduce_fx="sum"
         )
         self.add_state("delta_stress", default=[], dist_reduce_fx="cat")
-        self.add_state("delta_stress_per_atom", default=[], dist_reduce_fx="cat")
         self.add_state(
             "virials_computed", default=torch.tensor(0.0), dist_reduce_fx="sum"
         )
@@ -434,10 +433,6 @@ def update(self, batch, output):  # pylint: disable=arguments-differ
         if output.get("stress") is not None and batch.stress is not None:
             self.stress_computed += 1.0
             self.delta_stress.append(batch.stress - output["stress"])
-            self.delta_stress_per_atom.append(
-                (batch.stress - output["stress"])
-                / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1)
-            )
         if output.get("virials") is not None and batch.virials is not None:
             self.virials_computed += 1.0
             self.delta_virials.append(batch.virials - output["virials"])
@@ -480,10 +475,8 @@ def compute(self):
             aux["q95_f"] = compute_q95(delta_fs)
         if self.stress_computed:
             delta_stress = self.convert(self.delta_stress)
-            delta_stress_per_atom = self.convert(self.delta_stress_per_atom)
             aux["mae_stress"] = compute_mae(delta_stress)
             aux["rmse_stress"] = compute_rmse(delta_stress)
-            aux["rmse_stress_per_atom"] = compute_rmse(delta_stress_per_atom)
             aux["q95_stress"] = compute_q95(delta_stress)
         if self.virials_computed:
             delta_virials = self.convert(self.delta_virials)

From a7b32da7a2fa8da54f62c280d8c51a02fe3abc71 Mon Sep 17 00:00:00 2001
From: JamesDarby
Date: Mon, 3 Jun 2024 22:46:28 +0000
Subject: [PATCH 2/3] remove n_atoms factor

---
 mace/modules/loss.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mace/modules/loss.py b/mace/modules/loss.py
index d1f8becd..b3421ef5 100644
--- a/mace/modules/loss.py
+++ b/mace/modules/loss.py
@@ -31,11 +31,10 @@ def weighted_mean_squared_stress(ref: Batch, pred: TensorDict) -> torch.Tensor:
     # energy: [n_graphs, ]
     configs_weight = ref.weight.view(-1, 1, 1)  # [n_graphs, ]
     configs_stress_weight = ref.stress_weight.view(-1, 1, 1)  # [n_graphs, ]
-    num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1)  # [n_graphs,]
     return torch.mean(
         configs_weight
         * configs_stress_weight
-        * torch.square((ref["stress"] - pred["stress"]) / num_atoms)
+        * torch.square(ref["stress"] - pred["stress"])
     )  # []
 
 

From d52374616dc0e7ab61c140a72b58bea0f8211155 Mon Sep 17 00:00:00 2001
From: James Darby
Date: Tue, 4 Jun 2024 11:23:55 +0100
Subject: [PATCH 3/3] updated tests

---
 tests/test_run_train.py | 84 ++++++++++++++++++++---------------------
 1 file changed, 42 insertions(+), 42 deletions(-)

diff --git a/tests/test_run_train.py b/tests/test_run_train.py
index 7dca8919..7109744d 100644
--- a/tests/test_run_train.py
+++ b/tests/test_run_train.py
@@ -107,30 +107,30 @@ def test_run_train(tmp_path, fitting_configs):
         Es.append(at.get_potential_energy())
 
     print("Es", Es)
-    # from a run on 28/03/2023 on main 88d49f9ed6925dec07d1777043a36e1fe4872ff3
+    # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7
     ref_Es = [
         0.0,
         0.0,
-        -0.03911274694160493,
-        -0.0913651377675312,
-        -0.14973695873658766,
-        -0.0664839502025434,
-        -0.09968814898703926,
-        0.1248460531971883,
-        -0.0647495831154953,
-        -0.14589298347245963,
-        0.12918668431788108,
-        -0.13996496272772996,
-        -0.053211348522482806,
-        0.07845141245421094,
-        -0.08901520083723416,
-        -0.15467129065263446,
-        0.007727727865546765,
-        -0.04502061132025605,
-        -0.035848783030374,
-        -0.24410687104937906,
-        -0.0839034724949955,
-        -0.14756571357354326,
+        -0.039181344585828524,
+        -0.0915223395136733,
+        -0.14953484236456582,
+        -0.06662480820063998,
+        -0.09983737353050133,
+        0.12477442296789745,
+        -0.06486086271762856,
+        -0.1460607988519944,
+        0.12886334908465508,
+        -0.14000990081920373,
+        -0.05319886578958313,
+        0.07780520158391,
+        -0.08895480281886901,
+        -0.15474719614734422,
+        0.007756765146527644,
+        -0.044879267197498685,
+        -0.036065736712447574,
+        -0.24413743841886623,
+        -0.0838104612106429,
+        -0.14751978636626545
     ]
 
     assert np.allclose(Es, ref_Es)
@@ -178,30 +178,30 @@ def test_run_train_missing_data(tmp_path, fitting_configs):
         Es.append(at.get_potential_energy())
 
     print("Es", Es)
-    # from a run on 28/03/2023 on main 88d49f9ed6925dec07d1777043a36e1fe4872ff3
+    # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7
     ref_Es = [
         0.0,
         0.0,
-        -0.05449966431966507,
-        -0.11237663925685797,
-        0.03914539466246801,
-        -0.07500800414261456,
-        -0.13471106701173396,
-        0.02937255038020199,
-        -0.0652196693921633,
-        -0.14946129637190012,
-        0.19412338220281133,
-        -0.13546947741234333,
-        -0.05235148626886153,
-        -0.04957190959243316,
-        -0.07081384032242896,
-        -0.24575839901841345,
-        -0.0020512332640394916,
-        -0.038630330106902526,
-        -0.13621347044601181,
-        -0.2338465954158298,
-        -0.11777474787291177,
-        -0.14895508008918812,
+        -0.05464025113696155,
+        -0.11272131295940478,
+        0.039200919331076826,
+        -0.07517990972827505,
+        -0.13504202474582666,
+        0.0292022872055344,
+        -0.06541099574579018,
+        -0.1497824717832886,
+        0.19397709360828813,
+        -0.13587609467143014,
+        -0.05242956276828463,
+        -0.0504862057364953,
+        -0.07095795959430119,
+        -0.2463753796753703,
+        -0.002031543147676121,
+        -0.03864918790300681,
+        -0.13680153117705554,
+        -0.23418951968636786,
+        -0.11790833839379238,
+        -0.14930562311066484
     ]
 
     assert np.allclose(Es, ref_Es)
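A minimal standalone sketch of the stress loss term after this series, assuming explicit tensor arguments instead of the Batch/TensorDict interface (the names and the small numeric example are illustrative, not taken from the patch): the squared residual is computed on the raw stress tensors, with no per-atom normalization, since stress is an intensive (per-volume) quantity.

import torch

def weighted_mean_squared_stress_sketch(
    ref_stress: torch.Tensor,      # [n_graphs, 3, 3] reference stress tensors
    pred_stress: torch.Tensor,     # [n_graphs, 3, 3] predicted stress tensors
    configs_weight: torch.Tensor,  # [n_graphs] per-configuration weights
    stress_weight: torch.Tensor,   # [n_graphs] per-configuration stress weights
) -> torch.Tensor:
    # Weighted mean of squared raw stress residuals; the former division by
    # the number of atoms per configuration is gone.
    return torch.mean(
        configs_weight.view(-1, 1, 1)
        * stress_weight.view(-1, 1, 1)
        * torch.square(ref_stress - pred_stress)
    )

# Example: two configurations with uniform weights.
ref = torch.zeros(2, 3, 3)
pred = 0.01 * torch.eye(3).expand(2, 3, 3)
print(weighted_mean_squared_stress_sketch(ref, pred, torch.ones(2), torch.ones(2)))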