From 51a5310cb144b1f3903a9a9993a4cdf571a04d3a Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Tue, 17 Oct 2023 22:25:33 +0100 Subject: [PATCH 1/7] support loading OpenAI guided-diffusion models for evaluation. reduce overhead of evaluate_only mode (skip construction/accelerator-preparation of trainable model, optimizer, schedules) --- .vscode/launch.json | 40 ++++ configs/config_guided_diffusion_imagenet.json | 26 +++ kdiff_trainer/load_diffusion_model.py | 62 ++++++ requirements.oai.txt | 1 + train.py | 208 +++++++++++------- 5 files changed, 259 insertions(+), 78 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 configs/config_guided_diffusion_imagenet.json create mode 100644 kdiff_trainer/load_diffusion_model.py create mode 100644 requirements.oai.txt diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..3c09b57 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,40 @@ +{ + "configurations": [ + { + "name": "Python: Oxford Flowers (shifted window)", + "type": "python", + "request": "launch", + "program": "train.py", + "console": "integratedTerminal", + "justMyCode": true, + "subProcess": false, + "args": [ + "--config", "configs/config_oxford_flowers_shifted_window.json", + "--out-root", "out", + "--output-to-subdir", + "--name", "flowers_demo_001", + "--evaluate-n", "0", + "--batch-size", "32", + "--sample-n", "36", + "--mixed-precision", "bf16", + "--demo-img-compress", + "--font", "./kdiff_trainer/font/DejaVuSansMono.ttf", + ], + }, + { + "name": "Python: Compute FID OpenAI", + "type": "python", + "request": "launch", + "program": "train.py", + "console": "integratedTerminal", + "justMyCode": true, + "args": [ + "--config", "configs/config_guided_diffusion_imagenet.json", + "--resume-inference", "/nvme1/ml-weights/256x256_diffusion.pt", + "--evaluate-only", + "--evaluate-n", "4", + "--start-method", "fork" + ] + } + ] +} \ No newline at end of file diff --git a/configs/config_guided_diffusion_imagenet.json b/configs/config_guided_diffusion_imagenet.json new file mode 100644 index 0000000..5af7c37 --- /dev/null +++ b/configs/config_guided_diffusion_imagenet.json @@ -0,0 +1,26 @@ +{ + "model": { + "type": "guided_diffusion", + "config": { + "attention_resolutions": "32, 16, 8", + "class_cond": true, + "diffusion_steps": 1000, + "image_size": 256, + "learn_sigma": true, + "noise_schedule": "linear", + "num_channels": 256, + "num_head_channels": 64, + "num_res_blocks": 2, + "resblock_updown": true, + "use_fp16": true, + "use_scale_shift_norm": true + }, + "input_channels": 3, + "input_size": [256, 256] + }, + "dataset": { + "type": "imagefolder-class", + "location": "/nvme1/ml-data/ImageNet/train", + "num_classes": 1000 + } +} \ No newline at end of file diff --git a/kdiff_trainer/load_diffusion_model.py b/kdiff_trainer/load_diffusion_model.py new file mode 100644 index 0000000..3457de9 --- /dev/null +++ b/kdiff_trainer/load_diffusion_model.py @@ -0,0 +1,62 @@ +from pathlib import Path +import torch +import safetensors +from typing import Literal, Dict, NamedTuple +import k_diffusion as K + +from guided_diffusion import script_util +from guided_diffusion.unet import UNetModel +from guided_diffusion.respace import SpacedDiffusion + +class OpenAIVDenoiser(K.external.DiscreteVDDPMDenoiser): + """A wrapper for OpenAI v objective diffusion models.""" + + def __init__( + self, model, diffusion, quantize=False, has_learned_sigmas=True, device="cpu" + ): + alphas_cumprod = torch.tensor( + diffusion.alphas_cumprod, device=device, 
dtype=torch.float32 + ) + super().__init__(model, alphas_cumprod, quantize=quantize) + self.has_learned_sigmas = has_learned_sigmas + + def get_v(self, *args, **kwargs): + model_output = self.inner_model(*args, **kwargs) + if self.has_learned_sigmas: + return model_output.chunk(2, dim=1)[0] + return model_output + +class ModelAndDiffusion(NamedTuple): + model: UNetModel + diffusion: SpacedDiffusion + +def construct_diffusion_model(config_overrides: Dict = {}) -> ModelAndDiffusion: + model_config = script_util.model_and_diffusion_defaults() + if config_overrides: + model_config.update(config_overrides) + model, diffusion = script_util.create_model_and_diffusion(**model_config) + return ModelAndDiffusion(model, diffusion) + +def load_diffusion_model( + model_path: str, + model: UNetModel, +): + if Path(model_path).suffix == ".safetensors": + safetensors.torch.load_model(model, model_path) + else: + model.load_state_dict(torch.load(model_path, map_location="cpu")) + if model.dtype is torch.float16: + model.convert_to_fp16() + +def wrap_diffusion_model( + model: UNetModel, + diffusion: SpacedDiffusion, + device="cpu", + model_type: Literal['eps', 'v'] = "eps" +): + if model_type == "eps": + return K.external.OpenAIDenoiser(model, diffusion, device=device) + elif model_type == "v": + return OpenAIVDenoiser(model, diffusion, device=device) + else: + raise ValueError(f"Unknown model type {model_type}") \ No newline at end of file diff --git a/requirements.oai.txt b/requirements.oai.txt new file mode 100644 index 0000000..3d1156e --- /dev/null +++ b/requirements.oai.txt @@ -0,0 +1 @@ +guided-diffusion @ git+https://github.com/crowsonkb/guided-diffusion#egg=guided-diffusion \ No newline at end of file diff --git a/train.py b/train.py index 6d85de1..f8e611e 100755 --- a/train.py +++ b/train.py @@ -18,9 +18,12 @@ from torch import distributed as dist from torch import multiprocessing as mp from torch import optim +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler from torch.utils import data from torchvision import datasets, transforms, utils from tqdm.auto import tqdm +from typing import Optional import k_diffusion as K @@ -104,12 +107,17 @@ def main(): except AttributeError: pass + do_train = not args.evaluate_only + config = K.config.load_config(args.config) model_config = config['model'] dataset_config = config['dataset'] - opt_config = config['optimizer'] - sched_config = config['lr_sched'] - ema_sched_config = config['ema_sched'] + if do_train: + opt_config = config['optimizer'] + sched_config = config['lr_sched'] + ema_sched_config = config['ema_sched'] + else: + opt_config = sched_config = ema_sched_config = None # TODO: allow non-square input sizes assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1] @@ -131,15 +139,26 @@ def main(): demo_gen = torch.Generator().manual_seed(torch.randint(-2 ** 63, 2 ** 63 - 1, ()).item()) elapsed = 0.0 - inner_model = K.config.make_model(config) - inner_model_ema = deepcopy(inner_model) + if model_config['type'] == 'guided_diffusion': + from kdiff_trainer.load_diffusion_model import construct_diffusion_model + # can't easily put this into K.config.make_model; would change return type and introduce dependency + model_, guided_diff = construct_diffusion_model(model_config['config']) + else: + model_ = K.config.make_model(config) + guided_diff = None + + if do_train: + inner_model, inner_model_ema = model_, deepcopy(model_) + else: + inner_model, inner_model_ema = None, 
model_ + del model_ if args.compile: - inner_model.compile() + (inner_model or inner_model_ema).compile() # inner_model_ema.compile() if accelerator.is_main_process: - print(f'Parameters: {K.utils.n_params(inner_model):,}') + print(f'Parameters: {K.utils.n_params((inner_model or inner_model_ema)):,}') # If logging to wandb, initialize the run use_wandb = accelerator.is_main_process and args.wandb_project @@ -147,52 +166,57 @@ def main(): import wandb log_config = vars(args) log_config['config'] = config - log_config['parameters'] = K.utils.n_params(inner_model) + log_config['parameters'] = K.utils.n_params((inner_model or inner_model_ema)) wandb.init(project=args.wandb_project, entity=args.wandb_entity, group=args.wandb_group, config=log_config, save_code=True) - lr = opt_config['lr'] if args.lr is None else args.lr - groups = inner_model.param_groups(lr) - if opt_config['type'] == 'adamw': - opt = optim.AdamW(groups, - lr=lr, - betas=tuple(opt_config['betas']), - eps=opt_config['eps'], - weight_decay=opt_config['weight_decay']) - elif opt_config['type'] == 'adam8bit': - import bitsandbytes as bnb - opt = bnb.optim.Adam8bit(groups, - lr=lr, - betas=tuple(opt_config['betas']), - eps=opt_config['eps'], - weight_decay=opt_config['weight_decay']) - elif opt_config['type'] == 'sgd': - opt = optim.SGD(groups, - lr=lr, - momentum=opt_config.get('momentum', 0.), - nesterov=opt_config.get('nesterov', False), - weight_decay=opt_config.get('weight_decay', 0.)) - else: - raise ValueError('Invalid optimizer type') - - if sched_config['type'] == 'inverse': - sched = K.utils.InverseLR(opt, - inv_gamma=sched_config['inv_gamma'], - power=sched_config['power'], - warmup=sched_config['warmup']) - elif sched_config['type'] == 'exponential': - sched = K.utils.ExponentialLR(opt, - num_steps=sched_config['num_steps'], - decay=sched_config['decay'], - warmup=sched_config['warmup']) - elif sched_config['type'] == 'constant': - sched = K.utils.ConstantLRWithWarmup(opt, warmup=sched_config['warmup']) - else: - raise ValueError('Invalid schedule type') + if do_train: + lr = opt_config['lr'] if args.lr is None else args.lr + groups = inner_model.param_groups(lr) + if opt_config['type'] == 'adamw': + opt = optim.AdamW(groups, + lr=lr, + betas=tuple(opt_config['betas']), + eps=opt_config['eps'], + weight_decay=opt_config['weight_decay']) + elif opt_config['type'] == 'adam8bit': + import bitsandbytes as bnb + opt = bnb.optim.Adam8bit(groups, + lr=lr, + betas=tuple(opt_config['betas']), + eps=opt_config['eps'], + weight_decay=opt_config['weight_decay']) + elif opt_config['type'] == 'sgd': + opt = optim.SGD(groups, + lr=lr, + momentum=opt_config.get('momentum', 0.), + nesterov=opt_config.get('nesterov', False), + weight_decay=opt_config.get('weight_decay', 0.)) + else: + raise ValueError('Invalid optimizer type') + + if sched_config['type'] == 'inverse': + sched = K.utils.InverseLR(opt, + inv_gamma=sched_config['inv_gamma'], + power=sched_config['power'], + warmup=sched_config['warmup']) + elif sched_config['type'] == 'exponential': + sched = K.utils.ExponentialLR(opt, + num_steps=sched_config['num_steps'], + decay=sched_config['decay'], + warmup=sched_config['warmup']) + elif sched_config['type'] == 'constant': + sched = K.utils.ConstantLRWithWarmup(opt, warmup=sched_config['warmup']) + else: + raise ValueError('Invalid schedule type') - assert ema_sched_config['type'] == 'inverse' - ema_sched = K.utils.EMAWarmup(power=ema_sched_config['power'], - max_value=ema_sched_config['max_value']) - ema_stats = {} + assert 
ema_sched_config['type'] == 'inverse' + ema_sched = K.utils.EMAWarmup(power=ema_sched_config['power'], + max_value=ema_sched_config['max_value']) + ema_stats = {} + else: + opt: Optional[Optimizer] = None + sched: Optional[LRScheduler] = None + ema_sched: Optional[K.utils.EMAWarmup] = None tf = transforms.Compose([ transforms.Resize(size[0], interpolation=transforms.InterpolationMode.BICUBIC), @@ -238,22 +262,38 @@ def main(): train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, drop_last=True, num_workers=args.num_workers, persistent_workers=True, pin_memory=True) - inner_model, inner_model_ema, opt, train_dl = accelerator.prepare(inner_model, inner_model_ema, opt, train_dl) - if use_wandb: - wandb.watch(inner_model) + if do_train: + inner_model, inner_model_ema, opt, train_dl = accelerator.prepare(inner_model, inner_model_ema, opt, train_dl) + if use_wandb: + wandb.watch(inner_model) + else: + inner_model_ema, train_dl = accelerator.prepare(inner_model_ema, train_dl) + if accelerator.num_processes == 1: args.gns = False - if args.gns: + if args.gns and do_train: gns_stats_hook = K.gns.DDPGradientStatsHook(inner_model) gns_stats = K.gns.GradientNoiseScale() else: gns_stats = None - sigma_min = model_config['sigma_min'] - sigma_max = model_config['sigma_max'] - sample_density = K.config.make_sample_density(model_config) - - model = K.config.make_denoiser_wrapper(config)(inner_model) - model_ema = K.config.make_denoiser_wrapper(config)(inner_model_ema) + if guided_diff is None: + sigma_min = model_config['sigma_min'] + sigma_max = model_config['sigma_max'] + sample_density = K.config.make_sample_density(model_config) + if do_train: + model = K.config.make_denoiser_wrapper(config)(inner_model) + model_ema = K.config.make_denoiser_wrapper(config)(inner_model_ema) + else: + from kdiff_trainer.load_diffusion_model import wrap_diffusion_model + if do_train: + model = wrap_diffusion_model(inner_model, guided_diff, device=accelerator.device) + model_ema = wrap_diffusion_model(inner_model_ema, guided_diff, device=accelerator.device) + sigma_min = model_ema.sigma_min.item() + sigma_max = model_ema.sigma_max.item() + # TODO: not sure what this needs to be for guided diffusion + sample_density = None + if not do_train: + model_ema.requires_grad_(False).eval() state_path = Path(f'{args.name}_state.json') @@ -266,18 +306,19 @@ def main(): if accelerator.is_main_process: print(f'Resuming from {ckpt_path}...') ckpt = torch.load(ckpt_path, map_location='cpu') - unwrap(model.inner_model).load_state_dict(ckpt['model']) unwrap(model_ema.inner_model).load_state_dict(ckpt['model_ema']) - opt.load_state_dict(ckpt['opt']) - sched.load_state_dict(ckpt['sched']) - ema_sched.load_state_dict(ckpt['ema_sched']) - ema_stats = ckpt.get('ema_stats', ema_stats) - epoch = ckpt['epoch'] + 1 - step = ckpt['step'] + 1 - if args.gns and ckpt.get('gns_stats', None) is not None: - gns_stats.load_state_dict(ckpt['gns_stats']) - demo_gen.set_state(ckpt['demo_gen']) - elapsed = ckpt.get('elapsed', 0.0) + if do_train: + unwrap(model.inner_model).load_state_dict(ckpt['model']) + opt.load_state_dict(ckpt['opt']) + sched.load_state_dict(ckpt['sched']) + ema_sched.load_state_dict(ckpt['ema_sched']) + ema_stats = ckpt.get('ema_stats', ema_stats) + epoch = ckpt['epoch'] + 1 + step = ckpt['step'] + 1 + if args.gns and ckpt.get('gns_stats', None) is not None: + gns_stats.load_state_dict(ckpt['gns_stats']) + demo_gen.set_state(ckpt['demo_gen']) + elapsed = ckpt.get('elapsed', 0.0) del ckpt else: @@ -285,6 +326,8 @@ def 
main(): step = 0 if args.reset_ema: + if not do_train: + raise ValueError("Training is disabled (this can happen as a result of options such as --evaluate-only). Accordingly we did not construct a trainable model, and consequently cannot load the EMA model's weights onto said trainable model. Disable --reset-ema, or enable training.") unwrap(model.inner_model).load_state_dict(unwrap(model_ema.inner_model).state_dict()) ema_sched = K.utils.EMAWarmup(power=ema_sched_config['power'], max_value=ema_sched_config['max_value']) @@ -293,12 +336,19 @@ def main(): if args.resume_inference: if accelerator.is_main_process: print(f'Loading {args.resume_inference}...') - ckpt = safetorch.load_file(args.resume_inference) - unwrap(model.inner_model).load_state_dict(ckpt) - unwrap(model_ema.inner_model).load_state_dict(ckpt) - del ckpt + if guided_diff is None: + ckpt = safetorch.load_file(args.resume_inference) + if do_train: + unwrap(model.inner_model).load_state_dict(ckpt) + unwrap(model_ema.inner_model).load_state_dict(ckpt) + del ckpt + else: + from kdiff_trainer.load_diffusion_model import load_diffusion_model + if do_train: + load_diffusion_model(args.resume_inference, unwrap(model.inner_model)) + load_diffusion_model(args.resume_inference, unwrap(model_ema.inner_model)) - evaluate_enabled = args.evaluate_every > 0 and args.evaluate_n > 0 + evaluate_enabled = args.evaluate_every > 0 and args.evaluate_n > 0 or args.evaluate_only metrics_log = None if evaluate_enabled: if args.evaluate_with == 'inception': @@ -412,9 +462,11 @@ def save(): wandb.save(filename) if args.evaluate_only: - if not evaluate_enabled: - raise ValueError('--evaluate-only requested but evaluation is disabled') + if args.evaluate_n < 1: + raise ValueError('--evaluate-only requested but evaluate_n is less than 1') evaluate() + if accelerator.is_main_process: + tqdm.write('Finished evaluating!') return losses_since_last_print = [] From 5ee74be90bc1a50acc9167a0a67716c1c537a1eb Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Tue, 17 Oct 2023 22:25:53 +0100 Subject: [PATCH 2/7] prefer inference_mode context for demo and eval, due to lower overhead --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index f8e611e..68636f1 100755 --- a/train.py +++ b/train.py @@ -382,7 +382,7 @@ def cfg_model_fn(x, sigma, class_cond): return cfg_model_fn return model - @torch.no_grad() + @torch.inference_mode() # note: inference_mode is lower-overhead than no_grad but disables forward-mode AD @K.utils.eval_mode(model_ema) def demo(): if accelerator.is_main_process: @@ -407,7 +407,7 @@ def demo(): if use_wandb: wandb.log({'demo_grid': wandb.Image(filename)}, step=step) - @torch.no_grad() + @torch.inference_mode() # note: inference_mode is lower-overhead than no_grad but disables forward-mode AD @K.utils.eval_mode(model_ema) def evaluate(): if not evaluate_enabled: From bd8372ec947c6bc9e485cce186a3aa0fdb31ffa6 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Tue, 17 Oct 2023 22:50:51 +0100 Subject: [PATCH 3/7] sample density for OpenAI models --- k_diffusion/utils.py | 5 +++++ train.py | 11 +++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py index 946f2da..6f3f7e5 100644 --- a/k_diffusion/utils.py +++ b/k_diffusion/utils.py @@ -336,6 +336,11 @@ def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('in return u.logit().mul(scale).add(loc).exp().to(dtype) +def rand_uniform(shape, min_value, max_value, device='cpu', 
dtype=torch.float32): + """Draws samples from a uniform distribution.""" + return (stratified_with_settings(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value) + + def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32): """Draws samples from an log-uniform distribution.""" min_value = math.log(min_value) diff --git a/train.py b/train.py index 68636f1..0df784c 100755 --- a/train.py +++ b/train.py @@ -279,20 +279,19 @@ def main(): if guided_diff is None: sigma_min = model_config['sigma_min'] sigma_max = model_config['sigma_max'] - sample_density = K.config.make_sample_density(model_config) if do_train: + sample_density = K.config.make_sample_density(model_config) model = K.config.make_denoiser_wrapper(config)(inner_model) model_ema = K.config.make_denoiser_wrapper(config)(inner_model_ema) else: from kdiff_trainer.load_diffusion_model import wrap_diffusion_model - if do_train: - model = wrap_diffusion_model(inner_model, guided_diff, device=accelerator.device) model_ema = wrap_diffusion_model(inner_model_ema, guided_diff, device=accelerator.device) sigma_min = model_ema.sigma_min.item() sigma_max = model_ema.sigma_max.item() - # TODO: not sure what this needs to be for guided diffusion - sample_density = None - if not do_train: + if do_train: + sample_density = partial(K.utils.rand_uniform, min_value=0, max_value=guided_diff.num_timesteps-1) + model = wrap_diffusion_model(inner_model, guided_diff, device=accelerator.device) + else: model_ema.requires_grad_(False).eval() state_path = Path(f'{args.name}_state.json') From 61f52c65641b066d699f563ab1bb4dec924b2061 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Tue, 17 Oct 2023 23:08:25 +0100 Subject: [PATCH 4/7] move OpenAIVDenoiser into k-diffusion library --- k_diffusion/external.py | 19 +++++++++++++++++++ kdiff_trainer/load_diffusion_model.py | 20 +------------------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/k_diffusion/external.py b/k_diffusion/external.py index 79b51ce..dd0649c 100644 --- a/k_diffusion/external.py +++ b/k_diffusion/external.py @@ -167,6 +167,25 @@ def forward(self, input, sigma, **kwargs): return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip +class OpenAIVDenoiser(DiscreteVDDPMDenoiser): + """A wrapper for OpenAI v objective diffusion models.""" + + def __init__( + self, model, diffusion, quantize=False, has_learned_sigmas=True, device="cpu" + ): + alphas_cumprod = torch.tensor( + diffusion.alphas_cumprod, device=device, dtype=torch.float32 + ) + super().__init__(model, alphas_cumprod, quantize=quantize) + self.has_learned_sigmas = has_learned_sigmas + + def get_v(self, *args, **kwargs): + model_output = self.inner_model(*args, **kwargs) + if self.has_learned_sigmas: + return model_output.chunk(2, dim=1)[0] + return model_output + + class CompVisVDenoiser(DiscreteVDDPMDenoiser): """A wrapper for CompVis diffusion models that output v.""" diff --git a/kdiff_trainer/load_diffusion_model.py b/kdiff_trainer/load_diffusion_model.py index 3457de9..0f4c6f5 100644 --- a/kdiff_trainer/load_diffusion_model.py +++ b/kdiff_trainer/load_diffusion_model.py @@ -8,24 +8,6 @@ from guided_diffusion.unet import UNetModel from guided_diffusion.respace import SpacedDiffusion -class OpenAIVDenoiser(K.external.DiscreteVDDPMDenoiser): - """A wrapper for OpenAI v objective diffusion models.""" - - def __init__( - self, model, diffusion, quantize=False, has_learned_sigmas=True, device="cpu" - ): - alphas_cumprod = torch.tensor( - 
diffusion.alphas_cumprod, device=device, dtype=torch.float32 - ) - super().__init__(model, alphas_cumprod, quantize=quantize) - self.has_learned_sigmas = has_learned_sigmas - - def get_v(self, *args, **kwargs): - model_output = self.inner_model(*args, **kwargs) - if self.has_learned_sigmas: - return model_output.chunk(2, dim=1)[0] - return model_output - class ModelAndDiffusion(NamedTuple): model: UNetModel diffusion: SpacedDiffusion @@ -57,6 +39,6 @@ def wrap_diffusion_model( if model_type == "eps": return K.external.OpenAIDenoiser(model, diffusion, device=device) elif model_type == "v": - return OpenAIVDenoiser(model, diffusion, device=device) + return K.external.OpenAIVDenoiser(model, diffusion, device=device) else: raise ValueError(f"Unknown model type {model_type}") \ No newline at end of file From ddfa06087ca89642438d03f70768e58c1cab2df8 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Tue, 17 Oct 2023 23:23:35 +0100 Subject: [PATCH 5/7] change class-cond key to match underlying model's expectations --- train.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 0df784c..e9ac0c2 100755 --- a/train.py +++ b/train.py @@ -283,6 +283,7 @@ def main(): sample_density = K.config.make_sample_density(model_config) model = K.config.make_denoiser_wrapper(config)(inner_model) model_ema = K.config.make_denoiser_wrapper(config)(inner_model_ema) + class_cond_key = 'class_cond' else: from kdiff_trainer.load_diffusion_model import wrap_diffusion_model model_ema = wrap_diffusion_model(inner_model_ema, guided_diff, device=accelerator.device) @@ -293,6 +294,7 @@ def main(): model = wrap_diffusion_model(inner_model, guided_diff, device=accelerator.device) else: model_ema.requires_grad_(False).eval() + class_cond_key = 'y' state_path = Path(f'{args.name}_state.json') @@ -395,7 +397,7 @@ def demo(): if num_classes: class_cond = torch.randint(0, num_classes, [accelerator.num_processes, n_per_proc], generator=demo_gen).to(device) dist.broadcast(class_cond, 0) - extra_args['class_cond'] = class_cond[accelerator.process_index] + extra_args[class_cond_key] = class_cond[accelerator.process_index] model_fn = make_cfg_model_fn(model_ema) sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device) x_0 = K.sampling.sample_dpmpp_2m_sde(model_fn, x, sigmas, extra_args=extra_args, eta=0.0, solver_type='heun', disable=not accelerator.is_main_process) @@ -418,7 +420,7 @@ def sample_fn(n): x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max model_fn, extra_args = model_ema, {} if num_classes: - extra_args['class_cond'] = torch.randint(0, num_classes, [n], device=device) + extra_args[class_cond_key] = torch.randint(0, num_classes, [n], device=device) model_fn = make_cfg_model_fn(model_ema) x_0 = K.sampling.sample_dpmpp_2m_sde(model_fn, x, sigmas, extra_args=extra_args, eta=0.0, solver_type='heun', disable=True) return x_0 @@ -488,7 +490,7 @@ def save(): class_cond = batch[class_key] drop = torch.rand(class_cond.shape, device=class_cond.device) class_cond.masked_fill_(drop < cond_dropout_rate, num_classes) - extra_args['class_cond'] = class_cond + extra_args[class_cond_key] = class_cond noise = torch.randn_like(reals) with K.utils.enable_stratified_accelerate(accelerator, disable=args.gns): sigma = sample_density([reals.shape[0]], device=device) From a57b6fc65fcbf7919d0df33e77af840a0f027c2d Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Tue, 17 Oct 2023 23:26:59 +0100 Subject: [PATCH 6/7] delete launch 
config not relevant to PR --- .vscode/launch.json | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 3c09b57..b701617 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -1,26 +1,5 @@ { "configurations": [ - { - "name": "Python: Oxford Flowers (shifted window)", - "type": "python", - "request": "launch", - "program": "train.py", - "console": "integratedTerminal", - "justMyCode": true, - "subProcess": false, - "args": [ - "--config", "configs/config_oxford_flowers_shifted_window.json", - "--out-root", "out", - "--output-to-subdir", - "--name", "flowers_demo_001", - "--evaluate-n", "0", - "--batch-size", "32", - "--sample-n", "36", - "--mixed-precision", "bf16", - "--demo-img-compress", - "--font", "./kdiff_trainer/font/DejaVuSansMono.ttf", - ], - }, { "name": "Python: Compute FID OpenAI", "type": "python", From bfd13d82e0a5e5a28cd73589a85c29c898b72a4e Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Thu, 19 Oct 2023 00:58:46 +0100 Subject: [PATCH 7/7] support torch sdp attn for OpenAI guided diffusion --- configs/config_guided_diffusion_imagenet.json | 3 ++- requirements.oai.txt | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/configs/config_guided_diffusion_imagenet.json b/configs/config_guided_diffusion_imagenet.json index 5af7c37..768588c 100644 --- a/configs/config_guided_diffusion_imagenet.json +++ b/configs/config_guided_diffusion_imagenet.json @@ -13,7 +13,8 @@ "num_res_blocks": 2, "resblock_updown": true, "use_fp16": true, - "use_scale_shift_norm": true + "use_scale_shift_norm": true, + "use_torch_sdp_attention": true }, "input_channels": 3, "input_size": [256, 256] diff --git a/requirements.oai.txt b/requirements.oai.txt index 3d1156e..6ddc996 100644 --- a/requirements.oai.txt +++ b/requirements.oai.txt @@ -1 +1 @@ -guided-diffusion @ git+https://github.com/crowsonkb/guided-diffusion#egg=guided-diffusion \ No newline at end of file +guided-diffusion @ git+https://github.com/Birch-san/guided-diffusion.git@torch-sdp#egg=guided-diffusion \ No newline at end of file
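
For reference, a minimal usage sketch (not itself part of these patches) of how the helpers introduced in kdiff_trainer/load_diffusion_model.py compose for evaluation-only use, mirroring the path train.py takes under --evaluate-only. The checkpoint path, batch size, and class count below are illustrative placeholders.

import json
import torch
import k_diffusion as K
from kdiff_trainer.load_diffusion_model import (
    construct_diffusion_model,
    load_diffusion_model,
    wrap_diffusion_model,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
with open('configs/config_guided_diffusion_imagenet.json') as f:
    config = json.load(f)

# Build the guided-diffusion UNet and SpacedDiffusion from the JSON overrides.
model, diffusion = construct_diffusion_model(config['model']['config'])
# Load the OpenAI checkpoint (.pt or .safetensors); weights are cast to fp16
# when the config requested use_fp16. The path here is a placeholder.
load_diffusion_model('/path/to/256x256_diffusion.pt', model)
model.to(device)

# Wrap as a k-diffusion denoiser. The OpenAI ImageNet checkpoints are
# eps-objective models, so the default model_type='eps' applies.
denoiser = wrap_diffusion_model(model, diffusion, device=device)
denoiser.requires_grad_(False).eval()

# Sample the way train.py's evaluate() does, passing class labels under the
# 'y' key expected by the guided-diffusion UNet.
sigma_min, sigma_max = denoiser.sigma_min.item(), denoiser.sigma_max.item()
sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)
with torch.inference_mode():
    x = torch.randn([4, 3, 256, 256], device=device) * sigma_max
    y = torch.randint(0, 1000, [4], device=device)
    samples = K.sampling.sample_dpmpp_2m_sde(
        denoiser, x, sigmas, extra_args={'y': y}, eta=0.0, solver_type='heun')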