Skip to content

Latest commit

 

History

History
56 lines (41 loc) · 3.63 KB

README.md

File metadata and controls

56 lines (41 loc) · 3.63 KB

Getting Started with DMD2 on ImageNet-64x64

We trained ImageNet using mixed-precision in BF16 format, adapting the EDM's code to accommodate BF16 training (see LINK). We noticed that the training diverges if we use FP16. FP16 might work with some fancy loss scaling; help is greatly appreciated.

Model Zoo

Config Name FID Link Iters Hours
imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch 1.51 link 200k 53
imagenet_lr2e-6_scratch 2.61 link 410k 70
imagenet_gan_classifier_genloss3e-3_diffusion1000_lr5e-7_resume* 1.28 link 140K 38

*The final model was resumed from the best checkpoint of the imagenet_lr2e-6_scratch run and trained for an additional 140,000 iterations.

For inference with our models, you only need to download the pytorch_model.bin file from the provided link. For fine-tuning, you will need to download the entire folder. You can use the following script for that:

export CHECKPOINT_NAME="imagenet/imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch_fid1.51_checkpoint_model_193500"  # note that the imagenet/ is necessary
export OUTPUT_PATH="path/to/your/output/folder"

bash scripts/download_hf_checkpoint.sh $CHECKPOINT_NAME $OUTPUT_PATH

Download Base Diffusion Models and Training Data

export CHECKPOINT_PATH="" # change this to your own checkpoint folder 
export WANDB_ENTITY="" # change this to your own wandb entity
export WANDB_PROJECT="" # change this to your own wandb project

mkdir $CHECKPOINT_PATH

bash scripts/download_imagenet.sh $CHECKPOINT_PATH

You can also add these few export to the bashrc file so that you don't need to run them every time you open a new terminal.

Sample Training/Testing Commands

# start a training with 7 gpu
bash experiments/imagenet/imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch.sh  $CHECKPOINT_PATH $WANDB_ENTITY $WANDB_PROJECT

# on the same node, start a testing process that continually reads from the checkpoint folder and evaluate the FID 
# Change TIMESTAMP_TBD to the real one
python main/edm/test_folder_edm.py \
    --folder $CHECKPOINT_PATH/imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch/TIMESTAMP_TBD \
    --wandb_name test_imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch \
    --wandb_entity $WANDB_ENTITY \
    --wandb_project $WANDB_PROJECT \
    --resolution 64 --label_dim 1000  \
    --ref_path $CHECKPOINT_PATH/imagenet_fid_refs_edm.npz \
    --detector_url $CHECKPOINT_PATH/inception-2015-12-05.pkl 

Please refer to train_edm.py for various training options. Notably, if the --delete_ckpts flag is set to True, all checkpoints except the latest one will be deleted during training. Additionally, you can use the --cache_dir flag to specify a location with larger storage capacity. The number of checkpoints stored in cache_dir is controlled by the max_checkpoint argument.