Apply manual ruff fixes
akx committed Sep 26, 2023
1 parent 7bb73e7 commit a5693c5
Showing 3 changed files with 24 additions and 9 deletions.
4 changes: 2 additions & 2 deletions k_diffusion/gns.py
@@ -5,8 +5,8 @@ class DDPGradientStatsHook:
     def __init__(self, ddp_module):
         try:
             ddp_module.register_comm_hook(self, self._hook_fn)
-        except AttributeError:
-            raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
+        except AttributeError as ae:
+            raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules') from ae
         self._clear_state()
 
     def _clear_state(self):
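The gns.py change is exception chaining: re-raising inside an `except` block with `raise ... from ae` (likely ruff rule B904) keeps the original `AttributeError` attached as `__cause__`, so tracebacks show the real failure instead of only the wrapper. A minimal sketch of the effect, using hypothetical names rather than code from this repository:

    def wrap(obj):
        try:
            obj.register_comm_hook(None, None)
        except AttributeError as ae:
            # The ValueError records the original AttributeError as its __cause__.
            raise ValueError('object is not DDP-wrapped') from ae

    try:
        wrap(object())  # a plain object() has no register_comm_hook
    except ValueError as ve:
        print(type(ve.__cause__))  # <class 'AttributeError'>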
8 changes: 7 additions & 1 deletion k_diffusion/models/axial_rope.py
@@ -91,7 +91,13 @@ def init(shape):
 
 
 class AxialRoPE(nn.Module):
-    def __init__(self, dim, n_heads, start_index=0, freqs_init=freqs_pixel_log(max_freq=10.0)):
+    def __init__(
+        self,
+        dim,
+        n_heads,
+        start_index=0,
+        freqs_init=freqs_pixel_log(max_freq=10.0),  # noqa: B008
+    ):
         super().__init__()
         self.n_heads = n_heads
         self.start_index = start_index
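The `# noqa: B008` comment in axial_rope.py acknowledges ruff rule B008 (function call in argument defaults): `freqs_pixel_log(max_freq=10.0)` is evaluated once, when `__init__` is defined, and every instance shares that result. That appears to be the intent here, so the check is suppressed rather than the default restructured. A small sketch of the behaviour the rule warns about, with hypothetical names:

    import time

    def stamp(created_at=time.time()):  # would trigger B008: evaluated once, at definition time
        return created_at

    a = stamp()
    b = stamp()
    assert a == b  # both calls reuse the value captured when the function was defined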
21 changes: 15 additions & 6 deletions k_diffusion/utils.py
@@ -180,8 +180,11 @@ def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
 
     def get_lr(self):
         if not self._get_lr_called_within_step:
-            warnings.warn("To get the last learning rate computed by the scheduler, "
-                          "please use `get_last_lr()`.")
+            warnings.warn(
+                "To get the last learning rate computed by the scheduler, "
+                "please use `get_last_lr()`.",
+                stacklevel=1,
+            )
 
         return self._get_closed_form_lr()
 
@@ -221,8 +224,11 @@ def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
 
     def get_lr(self):
         if not self._get_lr_called_within_step:
-            warnings.warn("To get the last learning rate computed by the scheduler, "
-                          "please use `get_last_lr()`.")
+            warnings.warn(
+                "To get the last learning rate computed by the scheduler, "
+                "please use `get_last_lr()`.",
+                stacklevel=1,
+            )
 
         return self._get_closed_form_lr()
 
@@ -253,8 +259,11 @@ def __init__(self, optimizer, warmup=0., last_epoch=-1, verbose=False):
 
     def get_lr(self):
         if not self._get_lr_called_within_step:
-            warnings.warn("To get the last learning rate computed by the scheduler, "
-                          "please use `get_last_lr()`.")
+            warnings.warn(
+                "To get the last learning rate computed by the scheduler, "
+                "please use `get_last_lr()`.",
+                stacklevel=1,
+            )
 
         return self._get_closed_form_lr()
 
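The three utils.py hunks pass an explicit `stacklevel` to `warnings.warn` (ruff's B028 flags calls that omit it, which is presumably the trigger here). `stacklevel=1` is the default and attributes the warning to the `warnings.warn` line itself; a larger value walks up the call stack so the warning points at the caller. A minimal sketch with hypothetical names:

    import warnings

    def deprecated_helper():
        # stacklevel=2 attributes the warning to whoever called deprecated_helper.
        warnings.warn('deprecated_helper is deprecated', DeprecationWarning, stacklevel=2)

    def user_code():
        deprecated_helper()  # the reported filename/line points here, not inside the helper

    warnings.simplefilter('always')
    user_code()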
