From f721a2fd7aef94bf91bf59529239b7d35700e2b4 Mon Sep 17 00:00:00 2001
From: Wei Kang
Date: Sun, 10 Apr 2022 23:34:18 +0800
Subject: [PATCH] Minor fixes for logging (#296)

* Minor fixes for logging

* Minor fix
---
 .../ASR/pruned_transducer_stateless/train.py | 36 ++++++++++--------
 icefall/utils.py                             | 37 +++++++++++++------
 2 files changed, 45 insertions(+), 28 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
index 17f82e6013..e743106ece 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
@@ -609,21 +609,6 @@ def maybe_log_weights(tag: str):
                 global_step=params.batch_idx_train,
             )

-    def maybe_log_param_relative_changes():
-        if (
-            params.log_diagnostics
-            and tb_writer is not None
-            and params.batch_idx_train % (params.log_interval * 5) == 0
-        ):
-            deltas = optim_step_and_measure_param_change(model, optimizer)
-            tb_writer.add_scalars(
-                "train/relative_param_change_per_minibatch",
-                deltas,
-                global_step=params.batch_idx_train,
-            )
-        else:
-            optimizer.step()
-
     cur_batch_idx = params.get("cur_batch_idx", 0)

     for batch_idx, batch in enumerate(train_dl):
@@ -651,7 +636,26 @@ def maybe_log_param_relative_changes():

         maybe_log_weights("train/param_norms")
         maybe_log_gradients("train/grad_norms")
-        maybe_log_param_relative_changes()
+
+        old_parameters = None
+        if (
+            params.log_diagnostics
+            and tb_writer is not None
+            and params.batch_idx_train % (params.log_interval * 5) == 0
+        ):
+            old_parameters = {
+                n: p.detach().clone() for n, p in model.named_parameters()
+            }
+
+        optimizer.step()
+
+        if old_parameters is not None:
+            deltas = optim_step_and_measure_param_change(model, old_parameters)
+            tb_writer.add_scalars(
+                "train/relative_param_change_per_minibatch",
+                deltas,
+                global_step=params.batch_idx_train,
+            )

         optimizer.zero_grad()

diff --git a/icefall/utils.py b/icefall/utils.py
index c231dbbe47..daccd43466 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -25,15 +25,14 @@
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Iterable, List, TextIO, Optional, Tuple, Union
+from typing import Dict, Iterable, List, TextIO, Tuple, Union

 import k2
 import k2.version
 import kaldialign
 import torch
-import torch.nn as nn
 import torch.distributed as dist
-from torch.cuda.amp import GradScaler
+import torch.nn as nn
 from torch.utils.tensorboard import SummaryWriter

 Pathlike = Union[str, Path]
@@ -758,11 +757,10 @@ def measure_gradient_norms(

 def optim_step_and_measure_param_change(
     model: nn.Module,
-    optimizer: torch.optim.Optimizer,
-    scaler: Optional[GradScaler] = None,
+    old_parameters: Dict[str, nn.parameter.Parameter],
 ) -> Dict[str, float]:
     """
-    Perform model weight update and measure the "relative change in parameters per minibatch."
+    Measure the "relative change in parameters per minibatch."
     It is understood as a ratio between the L2 norm of the difference between the original and
     updated parameters, and the L2 norm of the original parameter. It is given by the formula:

     .. math::
       \begin{aligned}
         \delta = \frac{\Vert\theta - \theta_{new}\Vert}{\Vert\theta\Vert}
       \end{aligned}
-    """
-    param_copy = {n: p.detach().clone() for n, p in model.named_parameters()}
-    if scaler:
-        scaler.step(optimizer)
-    else:
+
+    This function is supposed to be used as follows:
+
+    .. code-block:: python
+
+        old_parameters = {
+            n: p.detach().clone() for n, p in model.named_parameters()
+        }
+
         optimizer.step()
+
+        deltas = optim_step_and_measure_param_change(model, old_parameters)
+
+    Args:
+      model: A torch.nn.Module instance.
+      old_parameters:
+        A Dict of the model's named parameters, cloned before optimizer.step() was called.
+
+    Return:
+      A Dict containing the relative change for each parameter.
+    """
     relative_change = {}
     with torch.no_grad():
         for n, p_new in model.named_parameters():
-            p_orig = param_copy[n]
+            p_orig = old_parameters[n]
             delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
             relative_change[n] = delta.item()
     return relative_change
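
Below is a minimal, self-contained sketch (not part of the patch) of the usage pattern this change establishes: snapshot the parameters, call optimizer.step(), then measure the relative parameter change. The toy model, optimizer, and data are illustrative assumptions; only the call optim_step_and_measure_param_change(model, old_parameters) reflects the updated icefall/utils.py above.

    import torch
    import torch.nn as nn

    # Assumes icefall is installed with this patch applied.
    from icefall.utils import optim_step_and_measure_param_change

    model = nn.Linear(10, 2)  # toy model, illustrative only
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    x, y = torch.randn(8, 10), torch.randn(8, 2)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()

    # Snapshot the parameters before the update, as the new train.py does.
    old_parameters = {
        n: p.detach().clone() for n, p in model.named_parameters()
    }

    optimizer.step()

    # delta = ||theta_old - theta_new|| / ||theta_old|| for each parameter.
    deltas = optim_step_and_measure_param_change(model, old_parameters)
    for name, delta in deltas.items():
        print(f"{name}: {delta:.3e}")

    optimizer.zero_grad()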