If you like our project, please give us a star ⭐ on GitHub to keep up with the latest updates.


# Class-Conditional Image Generation with MoH-DiT

## 💡 Download URL

## 🛠️ Requirements and Installation

### Requirements

We provide an `environment.yml` file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.

```bash
conda env create -f environment.yml
conda activate DiT
```
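
As an optional sanity check (a minimal sketch, not part of the official setup), you can confirm that PyTorch imports correctly inside the activated environment and whether a GPU is visible:

```bash
# Prints the installed PyTorch version and whether a CUDA device is visible
# (False is expected for a CPU-only environment).
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```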

## Sampling

If you've trained a new MoH-DiT model with `train.py` (see below), you can add the `--ckpt` argument to sample from your own checkpoint instead of the pre-trained weights. For example, to sample from the EMA weights of a custom 256x256 MoH-DiT-XL/2-90 model, run:

```bash
python sample.py --model MoH-DiT-XL/2-90 --image-size 256 --ckpt /path/to/model.pt
```
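
If sampling from your own checkpoint fails, it can help to first inspect what the file contains. The one-liner below is only a sketch: it assumes the DiT-style convention of saving `"model"` and `"ema"` state dicts inside a single checkpoint, so verify the keys against what your `train.py` actually writes.

```bash
# Minimal sketch: list the top-level keys of a training checkpoint (DiT-style checkpoints
# typically contain "model", "ema", "opt", and "args"). On newer PyTorch versions you may
# need to pass weights_only=False to torch.load for checkpoints that store non-tensor objects.
python -c "import torch; print(list(torch.load('/path/to/model.pt', map_location='cpu').keys()))"
```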

## Training MoH-DiT

We provide a training script for MoH-DiT in `train.py`. This script can be used to train class-conditional MoH-DiT models, but it can easily be modified to support other types of conditioning. To launch MoH-DiT-XL/2-90 (256x256) training with 8 GPUs on one node:

```bash
torchrun --nnodes=1 \
  --nproc_per_node=8 train.py \
  --model MoH-DiT-XL/2-90 \
  --data-path /path/to/imagenet/train \
  --results-dir results/MoH-DiT-XL-2-90
```
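
The launch above uses a single node. To scale out, `torchrun`'s standard multi-node flags can be combined with the same `train.py` arguments. The sketch below is not part of the official instructions; it assumes two 8-GPU nodes and hypothetical `MASTER_ADDR`/`NODE_RANK` environment variables that you set on each node.

```bash
# Hypothetical two-node launch (16 GPUs total). Run this command on every node, with
# NODE_RANK=0 on the first node and NODE_RANK=1 on the second, and MASTER_ADDR set to
# the first node's address. The train.py arguments themselves are unchanged.
torchrun --nnodes=2 --nproc_per_node=8 \
  --node_rank=$NODE_RANK \
  --master_addr=$MASTER_ADDR --master_port=29500 \
  train.py \
  --model MoH-DiT-XL/2-90 \
  --data-path /path/to/imagenet/train \
  --results-dir results/MoH-DiT-XL-2-90
```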

## Evaluation (FID, Inception Score, etc.)

We include a `sample_ddp.py` script which samples a large number of images from a MoH-DiT model in parallel. This script generates a folder of samples as well as a `.npz` file which can be directly used with ADM's TensorFlow evaluation suite to compute FID, Inception Score, and other metrics. For example, to sample 50K images from our pre-trained MoH-DiT-XL/2-90 model over 8 GPUs, run:

```bash
torchrun --nnodes=1 --nproc_per_node=8 sample_ddp.py --model MoH-DiT-XL/2-90 --num-fid-samples 50000
```

There are several additional options; see `sample_ddp.py` for details.
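
The resulting `.npz` file is meant to be fed to ADM's TensorFlow evaluation suite (the `evaluations` folder of the guided-diffusion repository). A hypothetical invocation is sketched below; the reference-batch file name, the sample file path, and the exact arguments of `evaluator.py` should be checked against the ADM repository.

```bash
# Hypothetical invocation of ADM's evaluator: it compares a precomputed ImageNet reference
# batch against the sample batch written by sample_ddp.py. File names are placeholders --
# download the appropriate reference batch and adjust the paths for your run.
python evaluator.py VIRTUAL_imagenet256_labeled.npz /path/to/MoH-DiT-XL-2-90-samples.npz
```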