diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..b701617 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,19 @@ +{ + "configurations": [ + { + "name": "Python: Compute FID OpenAI", + "type": "python", + "request": "launch", + "program": "train.py", + "console": "integratedTerminal", + "justMyCode": true, + "args": [ + "--config", "configs/config_guided_diffusion_imagenet.json", + "--resume-inference", "/nvme1/ml-weights/256x256_diffusion.pt", + "--evaluate-only", + "--evaluate-n", "4", + "--start-method", "fork" + ] + } + ] +} \ No newline at end of file diff --git a/configs/config_guided_diffusion_imagenet.json b/configs/config_guided_diffusion_imagenet.json new file mode 100644 index 0000000..768588c --- /dev/null +++ b/configs/config_guided_diffusion_imagenet.json @@ -0,0 +1,27 @@ +{ + "model": { + "type": "guided_diffusion", + "config": { + "attention_resolutions": "32, 16, 8", + "class_cond": true, + "diffusion_steps": 1000, + "image_size": 256, + "learn_sigma": true, + "noise_schedule": "linear", + "num_channels": 256, + "num_head_channels": 64, + "num_res_blocks": 2, + "resblock_updown": true, + "use_fp16": true, + "use_scale_shift_norm": true, + "use_torch_sdp_attention": true + }, + "input_channels": 3, + "input_size": [256, 256] + }, + "dataset": { + "type": "imagefolder-class", + "location": "/nvme1/ml-data/ImageNet/train", + "num_classes": 1000 + } +} \ No newline at end of file diff --git a/k_diffusion/external.py b/k_diffusion/external.py index 79b51ce..dd0649c 100644 --- a/k_diffusion/external.py +++ b/k_diffusion/external.py @@ -167,6 +167,25 @@ def forward(self, input, sigma, **kwargs): return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip +class OpenAIVDenoiser(DiscreteVDDPMDenoiser): + """A wrapper for OpenAI v objective diffusion models.""" + + def __init__( + self, model, diffusion, quantize=False, has_learned_sigmas=True, device="cpu" + ): + alphas_cumprod = torch.tensor( + diffusion.alphas_cumprod, device=device, dtype=torch.float32 + ) + super().__init__(model, alphas_cumprod, quantize=quantize) + self.has_learned_sigmas = has_learned_sigmas + + def get_v(self, *args, **kwargs): + model_output = self.inner_model(*args, **kwargs) + if self.has_learned_sigmas: + return model_output.chunk(2, dim=1)[0] + return model_output + + class CompVisVDenoiser(DiscreteVDDPMDenoiser): """A wrapper for CompVis diffusion models that output v.""" diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py index 946f2da..6f3f7e5 100644 --- a/k_diffusion/utils.py +++ b/k_diffusion/utils.py @@ -336,6 +336,11 @@ def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('in return u.logit().mul(scale).add(loc).exp().to(dtype) +def rand_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32): + """Draws samples from a uniform distribution.""" + return (stratified_with_settings(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value) + + def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32): """Draws samples from an log-uniform distribution.""" min_value = math.log(min_value) diff --git a/kdiff_trainer/load_diffusion_model.py b/kdiff_trainer/load_diffusion_model.py new file mode 100644 index 0000000..0f4c6f5 --- /dev/null +++ b/kdiff_trainer/load_diffusion_model.py @@ -0,0 +1,44 @@ +from pathlib import Path +import torch +import safetensors +from typing import Literal, Dict, NamedTuple +import k_diffusion as K + +from 
guided_diffusion import script_util +from guided_diffusion.unet import UNetModel +from guided_diffusion.respace import SpacedDiffusion + +class ModelAndDiffusion(NamedTuple): + model: UNetModel + diffusion: SpacedDiffusion + +def construct_diffusion_model(config_overrides: Dict = {}) -> ModelAndDiffusion: + model_config = script_util.model_and_diffusion_defaults() + if config_overrides: + model_config.update(config_overrides) + model, diffusion = script_util.create_model_and_diffusion(**model_config) + return ModelAndDiffusion(model, diffusion) + +def load_diffusion_model( + model_path: str, + model: UNetModel, +): + if Path(model_path).suffix == ".safetensors": + safetensors.torch.load_model(model, model_path) + else: + model.load_state_dict(torch.load(model_path, map_location="cpu")) + if model.dtype is torch.float16: + model.convert_to_fp16() + +def wrap_diffusion_model( + model: UNetModel, + diffusion: SpacedDiffusion, + device="cpu", + model_type: Literal['eps', 'v'] = "eps" +): + if model_type == "eps": + return K.external.OpenAIDenoiser(model, diffusion, device=device) + elif model_type == "v": + return K.external.OpenAIVDenoiser(model, diffusion, device=device) + else: + raise ValueError(f"Unknown model type {model_type}") \ No newline at end of file diff --git a/requirements.oai.txt b/requirements.oai.txt new file mode 100644 index 0000000..6ddc996 --- /dev/null +++ b/requirements.oai.txt @@ -0,0 +1 @@ +guided-diffusion @ git+https://github.com/Birch-san/guided-diffusion.git@torch-sdp#egg=guided-diffusion \ No newline at end of file diff --git a/train.py b/train.py index 6d85de1..e9ac0c2 100755 --- a/train.py +++ b/train.py @@ -18,9 +18,12 @@ from torch import distributed as dist from torch import multiprocessing as mp from torch import optim +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler from torch.utils import data from torchvision import datasets, transforms, utils from tqdm.auto import tqdm +from typing import Optional import k_diffusion as K @@ -104,12 +107,17 @@ def main(): except AttributeError: pass + do_train = not args.evaluate_only + config = K.config.load_config(args.config) model_config = config['model'] dataset_config = config['dataset'] - opt_config = config['optimizer'] - sched_config = config['lr_sched'] - ema_sched_config = config['ema_sched'] + if do_train: + opt_config = config['optimizer'] + sched_config = config['lr_sched'] + ema_sched_config = config['ema_sched'] + else: + opt_config = sched_config = ema_sched_config = None # TODO: allow non-square input sizes assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1] @@ -131,15 +139,26 @@ def main(): demo_gen = torch.Generator().manual_seed(torch.randint(-2 ** 63, 2 ** 63 - 1, ()).item()) elapsed = 0.0 - inner_model = K.config.make_model(config) - inner_model_ema = deepcopy(inner_model) + if model_config['type'] == 'guided_diffusion': + from kdiff_trainer.load_diffusion_model import construct_diffusion_model + # can't easily put this into K.config.make_model; would change return type and introduce dependency + model_, guided_diff = construct_diffusion_model(model_config['config']) + else: + model_ = K.config.make_model(config) + guided_diff = None + + if do_train: + inner_model, inner_model_ema = model_, deepcopy(model_) + else: + inner_model, inner_model_ema = None, model_ + del model_ if args.compile: - inner_model.compile() + (inner_model or inner_model_ema).compile() # inner_model_ema.compile() if 
accelerator.is_main_process: - print(f'Parameters: {K.utils.n_params(inner_model):,}') + print(f'Parameters: {K.utils.n_params((inner_model or inner_model_ema)):,}') # If logging to wandb, initialize the run use_wandb = accelerator.is_main_process and args.wandb_project @@ -147,52 +166,57 @@ def main(): import wandb log_config = vars(args) log_config['config'] = config - log_config['parameters'] = K.utils.n_params(inner_model) + log_config['parameters'] = K.utils.n_params((inner_model or inner_model_ema)) wandb.init(project=args.wandb_project, entity=args.wandb_entity, group=args.wandb_group, config=log_config, save_code=True) - lr = opt_config['lr'] if args.lr is None else args.lr - groups = inner_model.param_groups(lr) - if opt_config['type'] == 'adamw': - opt = optim.AdamW(groups, - lr=lr, - betas=tuple(opt_config['betas']), - eps=opt_config['eps'], - weight_decay=opt_config['weight_decay']) - elif opt_config['type'] == 'adam8bit': - import bitsandbytes as bnb - opt = bnb.optim.Adam8bit(groups, - lr=lr, - betas=tuple(opt_config['betas']), - eps=opt_config['eps'], - weight_decay=opt_config['weight_decay']) - elif opt_config['type'] == 'sgd': - opt = optim.SGD(groups, - lr=lr, - momentum=opt_config.get('momentum', 0.), - nesterov=opt_config.get('nesterov', False), - weight_decay=opt_config.get('weight_decay', 0.)) - else: - raise ValueError('Invalid optimizer type') - - if sched_config['type'] == 'inverse': - sched = K.utils.InverseLR(opt, - inv_gamma=sched_config['inv_gamma'], - power=sched_config['power'], - warmup=sched_config['warmup']) - elif sched_config['type'] == 'exponential': - sched = K.utils.ExponentialLR(opt, - num_steps=sched_config['num_steps'], - decay=sched_config['decay'], - warmup=sched_config['warmup']) - elif sched_config['type'] == 'constant': - sched = K.utils.ConstantLRWithWarmup(opt, warmup=sched_config['warmup']) - else: - raise ValueError('Invalid schedule type') + if do_train: + lr = opt_config['lr'] if args.lr is None else args.lr + groups = inner_model.param_groups(lr) + if opt_config['type'] == 'adamw': + opt = optim.AdamW(groups, + lr=lr, + betas=tuple(opt_config['betas']), + eps=opt_config['eps'], + weight_decay=opt_config['weight_decay']) + elif opt_config['type'] == 'adam8bit': + import bitsandbytes as bnb + opt = bnb.optim.Adam8bit(groups, + lr=lr, + betas=tuple(opt_config['betas']), + eps=opt_config['eps'], + weight_decay=opt_config['weight_decay']) + elif opt_config['type'] == 'sgd': + opt = optim.SGD(groups, + lr=lr, + momentum=opt_config.get('momentum', 0.), + nesterov=opt_config.get('nesterov', False), + weight_decay=opt_config.get('weight_decay', 0.)) + else: + raise ValueError('Invalid optimizer type') + + if sched_config['type'] == 'inverse': + sched = K.utils.InverseLR(opt, + inv_gamma=sched_config['inv_gamma'], + power=sched_config['power'], + warmup=sched_config['warmup']) + elif sched_config['type'] == 'exponential': + sched = K.utils.ExponentialLR(opt, + num_steps=sched_config['num_steps'], + decay=sched_config['decay'], + warmup=sched_config['warmup']) + elif sched_config['type'] == 'constant': + sched = K.utils.ConstantLRWithWarmup(opt, warmup=sched_config['warmup']) + else: + raise ValueError('Invalid schedule type') - assert ema_sched_config['type'] == 'inverse' - ema_sched = K.utils.EMAWarmup(power=ema_sched_config['power'], - max_value=ema_sched_config['max_value']) - ema_stats = {} + assert ema_sched_config['type'] == 'inverse' + ema_sched = K.utils.EMAWarmup(power=ema_sched_config['power'], + max_value=ema_sched_config['max_value']) 
+ ema_stats = {} + else: + opt: Optional[Optimizer] = None + sched: Optional[LRScheduler] = None + ema_sched: Optional[K.utils.EMAWarmup] = None tf = transforms.Compose([ transforms.Resize(size[0], interpolation=transforms.InterpolationMode.BICUBIC), @@ -238,22 +262,39 @@ def main(): train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, drop_last=True, num_workers=args.num_workers, persistent_workers=True, pin_memory=True) - inner_model, inner_model_ema, opt, train_dl = accelerator.prepare(inner_model, inner_model_ema, opt, train_dl) - if use_wandb: - wandb.watch(inner_model) + if do_train: + inner_model, inner_model_ema, opt, train_dl = accelerator.prepare(inner_model, inner_model_ema, opt, train_dl) + if use_wandb: + wandb.watch(inner_model) + else: + inner_model_ema, train_dl = accelerator.prepare(inner_model_ema, train_dl) + if accelerator.num_processes == 1: args.gns = False - if args.gns: + if args.gns and do_train: gns_stats_hook = K.gns.DDPGradientStatsHook(inner_model) gns_stats = K.gns.GradientNoiseScale() else: gns_stats = None - sigma_min = model_config['sigma_min'] - sigma_max = model_config['sigma_max'] - sample_density = K.config.make_sample_density(model_config) - - model = K.config.make_denoiser_wrapper(config)(inner_model) - model_ema = K.config.make_denoiser_wrapper(config)(inner_model_ema) + if guided_diff is None: + sigma_min = model_config['sigma_min'] + sigma_max = model_config['sigma_max'] + if do_train: + sample_density = K.config.make_sample_density(model_config) + model = K.config.make_denoiser_wrapper(config)(inner_model) + model_ema = K.config.make_denoiser_wrapper(config)(inner_model_ema) + class_cond_key = 'class_cond' + else: + from kdiff_trainer.load_diffusion_model import wrap_diffusion_model + model_ema = wrap_diffusion_model(inner_model_ema, guided_diff, device=accelerator.device) + sigma_min = model_ema.sigma_min.item() + sigma_max = model_ema.sigma_max.item() + if do_train: + sample_density = partial(K.utils.rand_uniform, min_value=0, max_value=guided_diff.num_timesteps-1) + model = wrap_diffusion_model(inner_model, guided_diff, device=accelerator.device) + else: + model_ema.requires_grad_(False).eval() + class_cond_key = 'y' state_path = Path(f'{args.name}_state.json') @@ -266,18 +307,19 @@ def main(): if accelerator.is_main_process: print(f'Resuming from {ckpt_path}...') ckpt = torch.load(ckpt_path, map_location='cpu') - unwrap(model.inner_model).load_state_dict(ckpt['model']) unwrap(model_ema.inner_model).load_state_dict(ckpt['model_ema']) - opt.load_state_dict(ckpt['opt']) - sched.load_state_dict(ckpt['sched']) - ema_sched.load_state_dict(ckpt['ema_sched']) - ema_stats = ckpt.get('ema_stats', ema_stats) - epoch = ckpt['epoch'] + 1 - step = ckpt['step'] + 1 - if args.gns and ckpt.get('gns_stats', None) is not None: - gns_stats.load_state_dict(ckpt['gns_stats']) - demo_gen.set_state(ckpt['demo_gen']) - elapsed = ckpt.get('elapsed', 0.0) + if do_train: + unwrap(model.inner_model).load_state_dict(ckpt['model']) + opt.load_state_dict(ckpt['opt']) + sched.load_state_dict(ckpt['sched']) + ema_sched.load_state_dict(ckpt['ema_sched']) + ema_stats = ckpt.get('ema_stats', ema_stats) + epoch = ckpt['epoch'] + 1 + step = ckpt['step'] + 1 + if args.gns and ckpt.get('gns_stats', None) is not None: + gns_stats.load_state_dict(ckpt['gns_stats']) + demo_gen.set_state(ckpt['demo_gen']) + elapsed = ckpt.get('elapsed', 0.0) del ckpt else: @@ -285,6 +327,8 @@ def main(): step = 0 if args.reset_ema: + if not do_train: + raise ValueError("Training is 
disabled (for example because --evaluate-only was passed), so no trainable model was constructed and there is nothing to load the EMA weights onto. Remove --reset-ema, or enable training.")
         unwrap(model.inner_model).load_state_dict(unwrap(model_ema.inner_model).state_dict())
         ema_sched = K.utils.EMAWarmup(power=ema_sched_config['power'],
                                       max_value=ema_sched_config['max_value'])
@@ -293,12 +337,19 @@ def main():
     if args.resume_inference:
         if accelerator.is_main_process:
             print(f'Loading {args.resume_inference}...')
-        ckpt = safetorch.load_file(args.resume_inference)
-        unwrap(model.inner_model).load_state_dict(ckpt)
-        unwrap(model_ema.inner_model).load_state_dict(ckpt)
-        del ckpt
+        if guided_diff is None:
+            ckpt = safetorch.load_file(args.resume_inference)
+            if do_train:
+                unwrap(model.inner_model).load_state_dict(ckpt)
+            unwrap(model_ema.inner_model).load_state_dict(ckpt)
+            del ckpt
+        else:
+            from kdiff_trainer.load_diffusion_model import load_diffusion_model
+            if do_train:
+                load_diffusion_model(args.resume_inference, unwrap(model.inner_model))
+            load_diffusion_model(args.resume_inference, unwrap(model_ema.inner_model))
 
-    evaluate_enabled = args.evaluate_every > 0 and args.evaluate_n > 0
+    evaluate_enabled = (args.evaluate_every > 0 and args.evaluate_n > 0) or args.evaluate_only
     metrics_log = None
     if evaluate_enabled:
         if args.evaluate_with == 'inception':
@@ -332,7 +383,7 @@ def cfg_model_fn(x, sigma, class_cond):
             return cfg_model_fn
         return model
 
-    @torch.no_grad()
+    @torch.inference_mode() # note: inference_mode is lower-overhead than no_grad but disables forward-mode AD
     @K.utils.eval_mode(model_ema)
     def demo():
         if accelerator.is_main_process:
@@ -346,7 +397,7 @@ def demo():
         if num_classes:
             class_cond = torch.randint(0, num_classes, [accelerator.num_processes, n_per_proc], generator=demo_gen).to(device)
             dist.broadcast(class_cond, 0)
-            extra_args['class_cond'] = class_cond[accelerator.process_index]
+            extra_args[class_cond_key] = class_cond[accelerator.process_index]
             model_fn = make_cfg_model_fn(model_ema)
         sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)
         x_0 = K.sampling.sample_dpmpp_2m_sde(model_fn, x, sigmas, extra_args=extra_args, eta=0.0, solver_type='heun', disable=not accelerator.is_main_process)
@@ -357,7 +408,7 @@ def demo():
             if use_wandb:
                 wandb.log({'demo_grid': wandb.Image(filename)}, step=step)
 
-    @torch.no_grad()
+    @torch.inference_mode() # note: inference_mode is lower-overhead than no_grad but disables forward-mode AD
    @K.utils.eval_mode(model_ema)
     def evaluate():
         if not evaluate_enabled:
@@ -369,7 +420,7 @@ def sample_fn(n):
             x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
             model_fn, extra_args = model_ema, {}
             if num_classes:
-                extra_args['class_cond'] = torch.randint(0, num_classes, [n], device=device)
+                extra_args[class_cond_key] = torch.randint(0, num_classes, [n], device=device)
                 model_fn = make_cfg_model_fn(model_ema)
             x_0 = K.sampling.sample_dpmpp_2m_sde(model_fn, x, sigmas, extra_args=extra_args, eta=0.0, solver_type='heun', disable=True)
             return x_0
@@ -412,9 +463,11 @@ def save():
             wandb.save(filename)
 
     if args.evaluate_only:
-        if not evaluate_enabled:
-            raise ValueError('--evaluate-only requested but evaluation is disabled')
+        if args.evaluate_n < 1:
+            raise ValueError('--evaluate-only requested but --evaluate-n is less than 1')
         evaluate()
+        if accelerator.is_main_process:
+            tqdm.write('Finished evaluating!')
         return
losses_since_last_print = [] @@ -437,7 +490,7 @@ def save(): class_cond = batch[class_key] drop = torch.rand(class_cond.shape, device=class_cond.device) class_cond.masked_fill_(drop < cond_dropout_rate, num_classes) - extra_args['class_cond'] = class_cond + extra_args[class_cond_key] = class_cond noise = torch.randn_like(reals) with K.utils.enable_stratified_accelerate(accelerator, disable=args.gns): sigma = sample_density([reals.shape[0]], device=device)