[transformer v2] Load + evaluate FID for OpenAI guided-diffusion models #84

Open · wants to merge 7 commits into base: transformer-model-v2
19 changes: 19 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,19 @@
{
    "configurations": [
        {
            "name": "Python: Compute FID OpenAI",
            "type": "python",
            "request": "launch",
            "program": "train.py",
            "console": "integratedTerminal",
            "justMyCode": true,
            "args": [
                "--config", "configs/config_guided_diffusion_imagenet.json",
                "--resume-inference", "/nvme1/ml-weights/256x256_diffusion.pt",
                "--evaluate-only",
                "--evaluate-n", "4",
                "--start-method", "fork"
            ]
        }
    ]
}
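This launch configuration simply drives train.py in evaluate-only mode against the released 256x256 guided-diffusion checkpoint. Outside VS Code, the same run corresponds roughly to the following invocation (the checkpoint path is the machine-specific one hard-coded above):

python train.py \
    --config configs/config_guided_diffusion_imagenet.json \
    --resume-inference /nvme1/ml-weights/256x256_diffusion.pt \
    --evaluate-only \
    --evaluate-n 4 \
    --start-method fork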
27 changes: 27 additions & 0 deletions configs/config_guided_diffusion_imagenet.json
@@ -0,0 +1,27 @@
{
"model": {
"type": "guided_diffusion",
"config": {
"attention_resolutions": "32, 16, 8",
"class_cond": true,
"diffusion_steps": 1000,
"image_size": 256,
"learn_sigma": true,
"noise_schedule": "linear",
"num_channels": 256,
"num_head_channels": 64,
"num_res_blocks": 2,
"resblock_updown": true,
"use_fp16": true,
"use_scale_shift_norm": true,
"use_torch_sdp_attention": true
},
"input_channels": 3,
"input_size": [256, 256]
},
"dataset": {
"type": "imagefolder-class",
"location": "/nvme1/ml-data/ImageNet/train",
"num_classes": 1000
}
}
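The "config" block under "model" is a set of overrides on top of guided-diffusion's own defaults; the construct_diffusion_model helper added in kdiff_trainer/load_diffusion_model.py below performs exactly this merge. A minimal sketch of what that amounts to (use_torch_sdp_attention is presumably a key accepted by the forked guided-diffusion pinned in requirements.oai.txt, not by upstream):

import json

from guided_diffusion import script_util

# Read the model overrides from the config file added above.
with open("configs/config_guided_diffusion_imagenet.json") as f:
    overrides = json.load(f)["model"]["config"]

# Merge them on top of guided-diffusion's defaults, then build model + diffusion.
cfg = script_util.model_and_diffusion_defaults()
cfg.update(overrides)
model, diffusion = script_util.create_model_and_diffusion(**cfg)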
19 changes: 19 additions & 0 deletions k_diffusion/external.py
@@ -167,6 +167,25 @@ def forward(self, input, sigma, **kwargs):
        return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip


class OpenAIVDenoiser(DiscreteVDDPMDenoiser):
    """A wrapper for OpenAI v-objective diffusion models."""

    def __init__(
        self, model, diffusion, quantize=False, has_learned_sigmas=True, device="cpu"
    ):
        alphas_cumprod = torch.tensor(
            diffusion.alphas_cumprod, device=device, dtype=torch.float32
        )
        super().__init__(model, alphas_cumprod, quantize=quantize)
        self.has_learned_sigmas = has_learned_sigmas

    def get_v(self, *args, **kwargs):
        model_output = self.inner_model(*args, **kwargs)
        if self.has_learned_sigmas:
            return model_output.chunk(2, dim=1)[0]
        return model_output


class CompVisVDenoiser(DiscreteVDDPMDenoiser):
    """A wrapper for CompVis diffusion models that output v."""

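OpenAIVDenoiser mirrors the existing OpenAIDenoiser, but treats the model output as v rather than eps and drops the learned-sigma channels when has_learned_sigmas is set. A minimal sampling sketch, assuming a class-conditional v-objective guided-diffusion model has already been constructed and loaded; the step count, sampler choice, and class labels are illustrative:

import torch
import k_diffusion as K

# model, diffusion: a guided-diffusion UNetModel and SpacedDiffusion, already loaded.
denoiser = K.external.OpenAIVDenoiser(model, diffusion, device="cuda")

# Start from pure noise at sigma_max and anneal down a 50-step Karras schedule.
x = torch.randn(4, 3, 256, 256, device="cuda") * denoiser.sigma_max
sigmas = K.sampling.get_sigmas_karras(50, denoiser.sigma_min, denoiser.sigma_max, device="cuda")
classes = torch.randint(0, 1000, (4,), device="cuda")
samples = K.sampling.sample_dpmpp_2m(denoiser, x, sigmas, extra_args={"y": classes})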
5 changes: 5 additions & 0 deletions k_diffusion/utils.py
@@ -336,6 +336,11 @@ def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('in
    return u.logit().mul(scale).add(loc).exp().to(dtype)


def rand_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
    """Draws samples from a uniform distribution."""
    return (stratified_with_settings(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value)


def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
    """Draws samples from a log-uniform distribution."""
    min_value = math.log(min_value)
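rand_uniform fills the gap next to the existing rand_log_uniform and rand_log_logistic helpers, delegating to stratified_with_settings for the base draw. A brief usage sketch; the batch size and sigma range are illustrative:

from k_diffusion import utils

# One sigma per batch element, drawn uniformly from [1e-2, 80).
sigma = utils.rand_uniform([16], 1e-2, 80.0, device="cuda")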
44 changes: 44 additions & 0 deletions kdiff_trainer/load_diffusion_model.py
@@ -0,0 +1,44 @@
from pathlib import Path
import torch
import safetensors.torch
from typing import Literal, Dict, NamedTuple
import k_diffusion as K

from guided_diffusion import script_util
from guided_diffusion.unet import UNetModel
from guided_diffusion.respace import SpacedDiffusion

class ModelAndDiffusion(NamedTuple):
    model: UNetModel
    diffusion: SpacedDiffusion

def construct_diffusion_model(config_overrides: Dict = {}) -> ModelAndDiffusion:
    model_config = script_util.model_and_diffusion_defaults()
    if config_overrides:
        model_config.update(config_overrides)
    model, diffusion = script_util.create_model_and_diffusion(**model_config)
    return ModelAndDiffusion(model, diffusion)

def load_diffusion_model(
    model_path: str,
    model: UNetModel,
):
    if Path(model_path).suffix == ".safetensors":
        safetensors.torch.load_model(model, model_path)
    else:
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
    if model.dtype is torch.float16:
        model.convert_to_fp16()

def wrap_diffusion_model(
    model: UNetModel,
    diffusion: SpacedDiffusion,
    device="cpu",
    model_type: Literal['eps', 'v'] = "eps"
):
    if model_type == "eps":
        return K.external.OpenAIDenoiser(model, diffusion, device=device)
    elif model_type == "v":
        return K.external.OpenAIVDenoiser(model, diffusion, device=device)
    else:
        raise ValueError(f"Unknown model type {model_type}")
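Together, the three helpers cover the full path from config to sampler-ready denoiser: construct the architecture from config overrides, load weights from a .pt or .safetensors checkpoint, then wrap the result for k-diffusion. A hedged end-to-end sketch, using the config and checkpoint paths from the launch configuration above; the released 256x256 checkpoint is an eps model, so it takes the default model_type:

import json

from kdiff_trainer.load_diffusion_model import (
    construct_diffusion_model, load_diffusion_model, wrap_diffusion_model,
)

with open("configs/config_guided_diffusion_imagenet.json") as f:
    overrides = json.load(f)["model"]["config"]

# Build the architecture, load the checkpoint, then wrap it for k-diffusion sampling.
model, diffusion = construct_diffusion_model(overrides)
load_diffusion_model("/nvme1/ml-weights/256x256_diffusion.pt", model)
model.eval().requires_grad_(False).to("cuda")

denoiser = wrap_diffusion_model(model, diffusion, device="cuda", model_type="eps")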
1 change: 1 addition & 0 deletions requirements.oai.txt
@@ -0,0 +1 @@
guided-diffusion @ git+https://github.com/Birch-san/guided-diffusion.git@torch-sdp#egg=guided-diffusion
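These OpenAI-specific extras sit alongside the base requirements and are presumably installed separately, e.g. with pip install -r requirements.oai.txt, so that the forked, torch-SDP-enabled guided_diffusion package is importable by kdiff_trainer/load_diffusion_model.py.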