From a5693c511566e27f385ccc53c769f848f303f19e Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Thu, 11 May 2023 09:46:43 +0300
Subject: [PATCH] Apply manual ruff fixes

---
 k_diffusion/gns.py               |  4 ++--
 k_diffusion/models/axial_rope.py |  8 +++++++-
 k_diffusion/utils.py             | 21 +++++++++++++++------
 3 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/k_diffusion/gns.py b/k_diffusion/gns.py
index 98d3bc9..ac95468 100644
--- a/k_diffusion/gns.py
+++ b/k_diffusion/gns.py
@@ -5,8 +5,8 @@ class DDPGradientStatsHook:
     def __init__(self, ddp_module):
         try:
             ddp_module.register_comm_hook(self, self._hook_fn)
-        except AttributeError:
-            raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
+        except AttributeError as ae:
+            raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules') from ae
         self._clear_state()
 
     def _clear_state(self):
diff --git a/k_diffusion/models/axial_rope.py b/k_diffusion/models/axial_rope.py
index ba105cb..f8e3b6d 100644
--- a/k_diffusion/models/axial_rope.py
+++ b/k_diffusion/models/axial_rope.py
@@ -91,7 +91,13 @@ def init(shape):
 
 
 class AxialRoPE(nn.Module):
-    def __init__(self, dim, n_heads, start_index=0, freqs_init=freqs_pixel_log(max_freq=10.0)):
+    def __init__(
+        self,
+        dim,
+        n_heads,
+        start_index=0,
+        freqs_init=freqs_pixel_log(max_freq=10.0),  # noqa: B008
+    ):
         super().__init__()
         self.n_heads = n_heads
         self.start_index = start_index
diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py
index 311d961..654dcc6 100644
--- a/k_diffusion/utils.py
+++ b/k_diffusion/utils.py
@@ -180,8 +180,11 @@ def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
 
     def get_lr(self):
         if not self._get_lr_called_within_step:
-            warnings.warn("To get the last learning rate computed by the scheduler, "
-                          "please use `get_last_lr()`.")
+            warnings.warn(
+                "To get the last learning rate computed by the scheduler, "
+                "please use `get_last_lr()`.",
+                stacklevel=1,
+            )
 
         return self._get_closed_form_lr()
 
@@ -221,8 +224,11 @@ def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
 
     def get_lr(self):
         if not self._get_lr_called_within_step:
-            warnings.warn("To get the last learning rate computed by the scheduler, "
-                          "please use `get_last_lr()`.")
+            warnings.warn(
+                "To get the last learning rate computed by the scheduler, "
+                "please use `get_last_lr()`.",
+                stacklevel=1,
+            )
 
         return self._get_closed_form_lr()
 
@@ -253,8 +259,11 @@ def __init__(self, optimizer, warmup=0., last_epoch=-1, verbose=False):
 
     def get_lr(self):
         if not self._get_lr_called_within_step:
-            warnings.warn("To get the last learning rate computed by the scheduler, "
-                          "please use `get_last_lr()`.")
+            warnings.warn(
+                "To get the last learning rate computed by the scheduler, "
+                "please use `get_last_lr()`.",
+                stacklevel=1,
+            )
 
         return self._get_closed_form_lr()