From 14d53f9cbf1566c26121fcd0ba096569bf99fbed Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 17 Mar 2021 14:03:16 -0600
Subject: [PATCH 1/3] No grad loss saving

---
 nequip/train/trainer.py | 89 +++++++++++++++++++++++------------------
 1 file changed, 50 insertions(+), 39 deletions(-)

diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py
index 3c5d6faf..3efa4c0d 100644
--- a/nequip/train/trainer.py
+++ b/nequip/train/trainer.py
@@ -574,16 +574,17 @@ def batch_step(self, data, n_batches, validation=False):
             self.model.train()
 
-        data = data.to(self.device)
-        data = AtomicData.to_AtomicDataDict(data)
-        if hasattr(self.model, "unscale"):
-            # This means that self.model is RescaleOutputs
-            # this will normalize the targets
-            # in validation (eval mode), it does nothing
-            # in train mode, if normalizes the targets
-            data = self.model.unscale(data)
-
-        out = self.model(data)
+        with torch.no_grad():
+            data = data.to(self.device)
+            data = AtomicData.to_AtomicDataDict(data)
+            if hasattr(self.model, "unscale"):
+                # This means that self.model is RescaleOutputs
+                # this will normalize the targets
+                # in validation (eval mode), it does nothing
+                # in train mode, if normalizes the targets
+                data = self.model.unscale(data)
+
+        out = self.model(data)
 
         # If we're in evaluation mode (i.e. validation), then
         # data's target prop is unnormalized, and out's has been rescaled to be in the same units
@@ -593,7 +594,6 @@ def batch_step(self, data, n_batches, validation=False):
         loss, loss_contrib = self.loss(pred=out, ref=data)
 
         if not validation:
-
             self.optim.zero_grad()
             loss.backward()
             self.optim.step()
@@ -601,30 +601,24 @@ def batch_step(self, data, n_batches, validation=False):
             if self.lr_scheduler_name == "CosineAnnealingWarmRestarts":
                 self.lr_sched.step(self.iepoch + self.ibatch / n_batches)
 
-        mae, mae_contrib = self.loss.mae(pred=out, ref=data)
-        scaled_loss_contrib = {}
-        if hasattr(self.model, "scale"):
+        # save loss stats
+        with torch.no_grad():
+            mae, mae_contrib = self.loss.mae(pred=out, ref=data)
+            scaled_loss_contrib = {}
+            if hasattr(self.model, "scale"):
 
-            for key in mae_contrib:
-                mae_contrib[key] = self.model.scale(
-                    mae_contrib[key], force_process=True, do_shift=False
-                )
-
-            # TO DO, this evetually needs to be removed. no guarantee that a loss is MSE
-            for key in loss_contrib:
+                for key in mae_contrib:
+                    mae_contrib[key] = self.model.scale(
+                        mae_contrib[key], force_process=True, do_shift=False
+                    )
 
-                scaled_loss_contrib[key] = {
-                    k: torch.clone(v) for k, v in loss_contrib[key].items()
-                }
+                # TO DO, this evetually needs to be removed. no guarantee that a loss is MSE
+                for key in loss_contrib:
 
-                scaled_loss_contrib[key] = self.model.scale(
-                    scaled_loss_contrib[key],
-                    force_process=True,
-                    do_shift=False,
-                    do_scale=True,
-                )
+                    scaled_loss_contrib[key] = {
+                        k: torch.clone(v) for k, v in loss_contrib[key].items()
+                    }
 
-                if "mse" in type(self.loss.funcs[key].func).__name__.lower():
                     scaled_loss_contrib[key] = self.model.scale(
                         scaled_loss_contrib[key],
                         force_process=True,
@@ -632,15 +626,32 @@ def batch_step(self, data, n_batches, validation=False):
                         do_scale=True,
                     )
 
-        self.batch_loss = loss
-        self.batch_scaled_loss_contrib = scaled_loss_contrib
-        self.batch_loss_contrib = loss_contrib
-        self.batch_mae = mae
-        self.batch_mae_contrib = mae_contrib
+                    if "mse" in type(self.loss.funcs[key].func).__name__.lower():
+                        scaled_loss_contrib[key] = self.model.scale(
+                            scaled_loss_contrib[key],
+                            force_process=True,
+                            do_shift=False,
+                            do_scale=True,
+                        )
+
+            self.batch_loss = loss.detach()
+            self.batch_scaled_loss_contrib = {
+                k1: {k2: v2.detach() for k2, v2 in v1.items()}
+                for k1, v1 in scaled_loss_contrib.items()
+            }
+            self.batch_loss_contrib = {
+                k1: {k2: v2.detach() for k2, v2 in v1.items()}
+                for k1, v1 in loss_contrib.items()
+            }
+            self.batch_mae = mae.detach()
+            self.batch_mae_contrib = {
+                k1: {k2: v2.detach() for k2, v2 in v1.items()}
+                for k1, v1 in mae_contrib.items()
+            }
 
-        self.end_of_batch_log(validation)
-        for callback in self.end_of_batch_callbacks:
-            callback(self)
+            self.end_of_batch_log(validation)
+            for callback in self.end_of_batch_callbacks:
+                callback(self)
 
     @property
     def early_stop_cond(self):

From f3cb4975cd3537e9be05d170281acb4e1d25a9f9 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 17 Mar 2021 14:12:37 -0600
Subject: [PATCH 2/3] Comments

---
 nequip/train/trainer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py
index 3efa4c0d..a6f35643 100644
--- a/nequip/train/trainer.py
+++ b/nequip/train/trainer.py
@@ -574,6 +574,7 @@ def batch_step(self, data, n_batches, validation=False):
             self.model.train()
 
+        # Do any target rescaling
         with torch.no_grad():
             data = data.to(self.device)
             data = AtomicData.to_AtomicDataDict(data)
             if hasattr(self.model, "unscale"):
@@ -584,7 +585,8 @@ def batch_step(self, data, n_batches, validation=False):
                 # in train mode, if normalizes the targets
                 data = self.model.unscale(data)
 
-        out = self.model(data)
+        # Run model
+        out = self.model(data)
 
         # If we're in evaluation mode (i.e. validation), then
         # data's target prop is unnormalized, and out's has been rescaled to be in the same units

From c3b0408fe650591a9e2a09c0bf41d365a1d60738 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 17 Mar 2021 15:20:01 -0600
Subject: [PATCH 3/3] Remove unnecessary no_grad()

---
 nequip/train/trainer.py | 17 ++++++++---------
 scripts/train.py        |  2 +-
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py
index a6f35643..0157a9e0 100644
--- a/nequip/train/trainer.py
+++ b/nequip/train/trainer.py
@@ -575,15 +575,14 @@ def batch_step(self, data, n_batches, validation=False):
             self.model.train()
 
         # Do any target rescaling
-        with torch.no_grad():
-            data = data.to(self.device)
-            data = AtomicData.to_AtomicDataDict(data)
-            if hasattr(self.model, "unscale"):
-                # This means that self.model is RescaleOutputs
-                # this will normalize the targets
-                # in validation (eval mode), it does nothing
-                # in train mode, if normalizes the targets
-                data = self.model.unscale(data)
+        data = data.to(self.device)
+        data = AtomicData.to_AtomicDataDict(data)
+        if hasattr(self.model, "unscale"):
+            # This means that self.model is RescaleOutputs
+            # this will normalize the targets
+            # in validation (eval mode), it does nothing
+            # in train mode, if normalizes the targets
+            data = self.model.unscale(data)
 
         # Run model
         out = self.model(data)
diff --git a/scripts/train.py b/scripts/train.py
index 65ae8b7b..110679c7 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -39,7 +39,7 @@ def main():
 
     # Get statistics of training dataset
     (
-        (forces_std),
+        (forces_std,),
        (energies_mean, energies_std),
         (allowed_species, Z_count),
     ) = trainer.dataset_train.statistics(
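
Note: the core change in this series is a generic PyTorch pattern rather than anything nequip-specific: keep autograd enabled for the forward/backward pass, but compute per-batch loss statistics under torch.no_grad() and store only detached tensors, so the saved metrics do not keep the computation graph (and its activations) alive between batches. A minimal, self-contained sketch of that pattern follows; the training_step, model, optim, loss_fn, and batch names are hypothetical placeholders, not nequip's API.

import torch

def training_step(model, optim, loss_fn, batch):
    # Forward and backward passes need the autograd graph.
    pred = model(batch["input"])
    loss = loss_fn(pred, batch["target"])

    optim.zero_grad()
    loss.backward()
    optim.step()

    # Metric bookkeeping does not: compute it under no_grad and keep only
    # detached tensors so the graph can be freed after this step.
    with torch.no_grad():
        mae = (pred - batch["target"]).abs().mean()

    return {"loss": loss.detach(), "mae": mae}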