Apply Ruff autofixes
akx committed Jul 26, 2023
1 parent bf3f743 commit b61735a
Showing 16 changed files with 49 additions and 49 deletions.
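
For orientation, here is a small sketch (my own illustration, not part of the commit) pairing the autofix patterns visible in the diff below with the behaviour-preserving form Ruff rewrites them to. The rule codes are inferred from the change patterns, not stated anywhere in the commit, and the variable names are hypothetical stand-ins.

# Illustrative before/after pairs for the kinds of autofixes applied in this commit.
# Rule codes below are my inference from the diff, not stated by the commit itself.

# E713: membership tests -- `not x in y` becomes `x not in y`.
trainer_config = {"accelerator": "gpu"}  # hypothetical stand-in dict
if "devices" not in trainer_config:      # was: if not "devices" in trainer_config:
    pass

# F541: f-strings with no placeholders lose the `f` prefix.
print("sigmas after pruning: ")          # was: print(f"sigmas after pruning: ")

# E731: lambdas bound to a name become `def` statements.
def scale_schedule(scale, sigma):        # was: scale_schedule = lambda scale, sigma: scale
    return scale

# F841: unused assignments keep the call but drop the binding,
# e.g. `o = some_call(...)` becomes a bare `some_call(...)`.

# I001: import blocks are re-sorted and merged, which accounts for the
# reordered `import` lines in most of the files below.
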
10 changes: 5 additions & 5 deletions main.py
@@ -12,11 +12,11 @@
import torch
import torchvision
import wandb
from PIL import Image
from matplotlib import pyplot as plt
from natsort import natsorted
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
@@ -406,7 +406,6 @@ def log_img(self, pl_module, batch, batch_idx, split="train"):
# batch_idx > 5 and
self.max_images > 0
):
logger = type(pl_module.logger)
is_train = pl_module.training
if is_train:
pl_module.eval()
@@ -652,7 +651,7 @@ def init_wandb(save_dir, opt, config, group_name, name_str):

ckpt_resume_path = opt.resume_from_checkpoint

if not "devices" in trainer_config and trainer_config["accelerator"] != "gpu":
if "devices" not in trainer_config and trainer_config["accelerator"] != "gpu":
del trainer_config["accelerator"]
cpu = True
else:
@@ -818,7 +817,7 @@ def init_wandb(save_dir, opt, config, group_name, name_str):
trainer_kwargs["callbacks"] = [
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
]
if not "plugins" in trainer_kwargs:
if "plugins" not in trainer_kwargs:
trainer_kwargs["plugins"] = list()

# cmd line trainer args (which are in trainer_opt) have always priority over config-trainer-args (which are in trainer_kwargs)
@@ -910,11 +909,12 @@ def divein(*args, **kwargs):
trainer.test(model, data)
except RuntimeError as err:
if MULTINODE_HACKS:
import requests
import datetime
import os
import socket

import requests

device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
hostname = socket.gethostname()
ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
26 changes: 12 additions & 14 deletions scripts/demo/streamlit_helpers.py
@@ -1,29 +1,27 @@
import math
import os
from typing import Union, List
from typing import List, Union

import math
import numpy as np
import streamlit as st
import torch
from PIL import Image
from einops import rearrange, repeat
from imwatermark import WatermarkEncoder
from omegaconf import OmegaConf, ListConfig
from omegaconf import ListConfig, OmegaConf
from PIL import Image
from safetensors.torch import load_file as load_safetensors
from torch import autocast
from torchvision import transforms
from torchvision.utils import make_grid
from safetensors.torch import load_file as load_safetensors

from sgm.modules.diffusionmodules.sampling import (
DPMPP2MSampler,
DPMPP2SAncestralSampler,
EulerAncestralSampler,
EulerEDMSampler,
HeunEDMSampler,
EulerAncestralSampler,
DPMPP2SAncestralSampler,
DPMPP2MSampler,
LinearMultistepSampler,
)
from sgm.util import append_dims
from sgm.util import instantiate_from_config
from sgm.util import append_dims, instantiate_from_config


class WatermarkEmbedder:
@@ -74,7 +72,7 @@ def __call__(self, image: torch.Tensor):
@st.cache_resource()
def init_st(version_dict, load_ckpt=True):
state = dict()
if not "model" in state:
if "model" not in state:
config = version_dict["config"]
ckpt = version_dict["ckpt"]

@@ -224,12 +222,12 @@ def __init__(self, discretization, strength: float = 1.0):
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
print("sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
print("prune index:", max(int(self.strength * len(sigmas)), 1))
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
print("sigmas after pruning: ", sigmas)
return sigmas


16 changes: 8 additions & 8 deletions scripts/tests/attention.py
@@ -1,10 +1,10 @@
import torch
import einops
from torch.backends.cuda import SDPBackend
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torch.backends.cuda import SDPBackend

from sgm.modules.attention import SpatialTransformer, BasicTransformerBlock
from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer


def benchmark_attn():
@@ -51,7 +51,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
dtype=dtype,
)

print(f"q/k/v shape:", query.shape, key.shape, value.shape)
print("q/k/v shape:", query.shape, key.shape, value.shape)

# Lets explore the speed of each of the 3 implementations
from torch.backends.cuda import SDPBackend, sdp_kernel
@@ -87,7 +87,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
) as prof:
with record_function("Default detailed stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

print(
@@ -99,7 +99,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
) as prof:
with record_function("Math implmentation stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
@@ -114,7 +114,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
) as prof:
with record_function("FlashAttention stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
@@ -129,7 +129,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
) as prof:
with record_function("EfficientAttention stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


5 changes: 3 additions & 2 deletions scripts/util/detection/nsfw_and_watermark_dectection.py
@@ -1,9 +1,10 @@
import os
import torch

import clip
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
import clip

RESOURCES_ROOT = "scripts/util/detection/"

2 changes: 1 addition & 1 deletion sgm/__init__.py
@@ -1,5 +1,5 @@
from .data import StableDataModuleFromConfig
from .models import AutoencodingEngine, DiffusionEngine
from .util import instantiate_from_config, get_configs_path
from .util import get_configs_path, instantiate_from_config

__version__ = "0.0.1"
4 changes: 2 additions & 2 deletions sgm/data/cifar10.py
@@ -1,7 +1,7 @@
import torchvision
import pytorch_lightning as pl
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class CIFAR10DataDictWrapper(Dataset):
2 changes: 1 addition & 1 deletion sgm/data/dataset.py
@@ -7,7 +7,7 @@

try:
from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError as e:
except ImportError:
print("#" * 100)
print("Datasets not yet available")
print("to enable, we need to add stable-datasets as a submodule")
4 changes: 2 additions & 2 deletions sgm/data/mnist.py
@@ -1,7 +1,7 @@
import torchvision
import pytorch_lightning as pl
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class MNISTDataDictWrapper(Dataset):
5 changes: 2 additions & 3 deletions sgm/models/diffusion.py
@@ -224,9 +224,8 @@ def sample(
):
randn = torch.randn(batch_size, *shape).to(self.device)

denoiser = lambda input, sigma, c: self.denoiser(
self.model, input, sigma, c, **kwargs
)
def denoiser(input, sigma, c):
return self.denoiser(self.model, input, sigma, c, **kwargs)
samples = self.sampler(denoiser, randn, cond, uc=uc)
return samples

2 changes: 1 addition & 1 deletion sgm/modules/diffusionmodules/__init__.py
@@ -1,7 +1,7 @@
from .denoiser import Denoiser
from .discretizer import Discretization
from .loss import StandardDiffusionLoss
from .model import Model, Encoder, Decoder
from .model import Decoder, Encoder, Model
from .openaimodel import UNetModel
from .sampling import BaseDiffusionSampler
from .wrappers import OpenAIWrapper
9 changes: 5 additions & 4 deletions sgm/modules/diffusionmodules/discretizer.py
@@ -1,10 +1,11 @@
import torch
import numpy as np
from functools import partial
from abc import abstractmethod
from functools import partial

import numpy as np
import torch

from ...util import append_zero
from ...modules.diffusionmodules.util import make_beta_schedule
from ...util import append_zero


def generate_roughly_equally_spaced_steps(
3 changes: 2 additions & 1 deletion sgm/modules/diffusionmodules/guiders.py
@@ -11,7 +11,8 @@ class VanillaCFG:
"""

def __init__(self, scale, dyn_thresh_config=None):
scale_schedule = lambda scale, sigma: scale # independent of step
def scale_schedule(scale, sigma):
return scale # independent of step
self.scale_schedule = partial(scale_schedule, scale)
self.dyn_thresh = instantiate_from_config(
default(
2 changes: 1 addition & 1 deletion sgm/modules/diffusionmodules/model.py
@@ -629,7 +629,7 @@ def __init__(
self.tanh_out = tanh_out

# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
(1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
2 changes: 1 addition & 1 deletion sgm/modules/distributions/distributions.py
@@ -1,5 +1,5 @@
import torch
import numpy as np
import torch


class AbstractDistribution:
Expand Down
4 changes: 2 additions & 2 deletions sgm/modules/ema.py
@@ -51,7 +51,7 @@ def forward(self, model):
one_minus_decay * (shadow_params[sname] - m_param[key])
)
else:
assert not key in self.m_name2s_name
assert key not in self.m_name2s_name

def copy_to(self, model):
m_param = dict(model.named_parameters())
@@ -60,7 +60,7 @@ def copy_to(self, model):
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
assert key not in self.m_name2s_name

def store(self, parameters):
"""
2 changes: 1 addition & 1 deletion sgm/util.py
@@ -166,7 +166,7 @@ def count_params(model, verbose=False):


def instantiate_from_config(config):
if not "target" in config:
if "target" not in config:
if config == "__is_first_stage__":
return None
elif config == "__is_unconditional__":
