This is the codebase for "Expanding Expressiveness of Diffusion Models with Limited Data via Self-Distillation based Fine-Tuning" (SDFT).
This repository is based on openai/guided-diffusion and P2-Weighting.
SDFT aims to enhance the expressiveness of diffusion models trained on limited datasets, whose attributes tend to be less diverse and biased. The limited expressiveness not only hampers the generation capability of the model but also leads to unsatisfactory outputs in various downstream tasks, such as domain translation and text-guided image manipulation.
All models are trained at 256x256 resolution.
We use the pre-trained FFHQ model from the P2-Weighting repository.
Here are the models trained on MetFaces with SDFT: link. We obtained the reported values after fine-tuning for 10k iterations. The domain-specific feature extractor for EGSDE is also attached; we trained it following the official EGSDE implementation.
We trained the models with PyTorch 1.7.1 on 8 RTX 2080 Ti GPUs.
First, set the `PYTHONPATH` variable to point to the root of the repository (do the same when training new models):

```bash
export PYTHONPATH=$PYTHONPATH:$(pwd)
```
Put model checkpoints into `./models/`. Samples will be saved in `./samples/`.
```bash
python scripts/image_sample.py --attention_resolutions 16 --class_cond False \
    --diffusion_steps 1000 --dropout 0.0 --image_size 256 --learn_sigma True \
    --noise_schedule linear --num_channels 128 --num_res_blocks 1 \
    --num_head_channels 64 --resblock_updown True --use_fp16 False \
    --use_scale_shift_norm True --timestep_respacing ddim40 --use_ddim True \
    --model_path models/metface_distill.pt --sample_dir samples
```
We adopt 40-step DDIM sampling as the default configuration for efficiency. You can change the sampling strategy by modifying `--timestep_respacing` and `--use_ddim`; for example, `--timestep_respacing 250 --use_ddim False` runs 250-step DDPM sampling.
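If you prefer calling the sampler from Python instead of the CLI, here is a minimal sketch mirroring the flags above. It assumes this repository keeps the `guided_diffusion.script_util` API from the base repositories; the batch size and `map_location` are illustrative.

```python
# Minimal programmatic sampling sketch (assumes the guided-diffusion API
# inherited from the base repositories; adjust paths and batch size).
import torch
from guided_diffusion.script_util import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
)

args = model_and_diffusion_defaults()
args.update(
    attention_resolutions="16", class_cond=False, diffusion_steps=1000,
    dropout=0.0, image_size=256, learn_sigma=True, noise_schedule="linear",
    num_channels=128, num_res_blocks=1, num_head_channels=64,
    resblock_updown=True, use_fp16=False, use_scale_shift_norm=True,
    timestep_respacing="ddim40",
)
model, diffusion = create_model_and_diffusion(**args)
model.load_state_dict(torch.load("models/metface_distill.pt", map_location="cpu"))
model.eval()

# ddim_sample_loop corresponds to --use_ddim True with the "ddim40" respacing.
samples = diffusion.ddim_sample_loop(model, (4, 3, 256, 256))  # values in [-1, 1]
```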
We implement SDEdit in `./notebooks/SDEdit.ipynb` and EGSDE in `./notebooks/EGSDE.ipynb` for diffusion-based domain translation. For EGSDE, please place `face2portrain.pt` in `./models/` as the domain-specific feature extractor.
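For reference, the core of SDEdit is simply "noise the source image to an intermediate timestep, then denoise with the fine-tuned model." Below is a minimal sketch, assuming `model` and `diffusion` are built as above with the full 1000-step schedule (no `timestep_respacing`); the notebook is the authoritative implementation, and `t0` and the helper name are illustrative.

```python
import torch

def sdedit(model, diffusion, x_source, t0=500):
    """Illustrative SDEdit: perturb `x_source` (a [-1, 1] image batch) to
    timestep `t0` of the diffusion's own schedule, then denoise back to t=0."""
    device = x_source.device
    t = torch.full((x_source.shape[0],), t0, dtype=torch.long, device=device)
    # Forward diffusion: add noise up to the intermediate timestep t0.
    x_t = diffusion.q_sample(x_source, t)
    # Reverse diffusion with the fine-tuned target-domain model.
    with torch.no_grad():
        for step in reversed(range(t0 + 1)):
            t = torch.full((x_source.shape[0],), step, dtype=torch.long,
                           device=device)
            x_t = diffusion.p_sample(model, x_t, t)["sample"]
    return x_t
```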
We follow Asyrp for the implementation of text-guided image manipulation.
For the MetFaces dataset:
- Set `--distill_lambda=0.1` and `--distill_p2_gamma=3` for the distillation loss in equation (3) of the paper.
- Set `--distill_agnostic=True`, `--distill_agnostic_lambda=0.1`, and `--distill_agnostic_gamma=50` for the auxiliary loss in equation (5) of the paper. (A rough illustration of how these weights enter the objective follows below.)
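Purely as an illustration of how these hyperparameters plausibly enter the training objective, assuming (as the flag names suggest) P2-style SNR weighting on both extra terms; the authoritative definitions are equations (3) and (5) of the paper and the training code.

```python
def sdft_objective(l_denoise, l_distill, l_agnostic, snr,
                   distill_lambda=0.1, distill_p2_gamma=3.0,
                   distill_agnostic_lambda=0.1, distill_agnostic_gamma=50.0):
    """Hypothetical combination of the loss terms; `snr` is the per-timestep
    signal-to-noise ratio alpha_bar_t / (1 - alpha_bar_t)."""
    # Assumed P2-style weight (Choi et al., 2022): 1 / (1 + SNR(t))^gamma
    # down-weights low-noise (high-SNR) timesteps; a large gamma concentrates
    # the corresponding term on the very noisy timesteps.
    w_distill = 1.0 / (1.0 + snr) ** distill_p2_gamma
    w_agnostic = 1.0 / (1.0 + snr) ** distill_agnostic_gamma
    return (l_denoise
            + distill_lambda * w_distill * l_distill
            + distill_agnostic_lambda * w_agnostic * l_agnostic)
```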
Logs and models will be saved in `logs/`. You should modify `--data_dir` to point to your training dataset.
```bash
bash train_ddp_distill.sh
```