This repo is built on the official StyleGAN2 PyTorch repo. For details, please refer to stylegan2-ada-pytorch.
The main dependencies are:
- 64-bit Python 3.8 and PyTorch 1.8.0 (or later). See https://pytorch.org/ for PyTorch install instructions.
- CUDA toolkit 11.0 or later.
- Python libraries: `pip install -r docs/requirements.txt`
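To sanity-check the PyTorch/CUDA setup, a quick check (not part of the repo):

```python
import torch
print(torch.__version__)          # expect 1.8.0 or later
print(torch.cuda.is_available())  # expect True with CUDA toolkit 11.0+
```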
Pre-trained models on standard datasets (full training data):

Dataset | Resolution | Download | FID (Inception) | FID (SwAV) |
---|---|---|---|---|
LSUN Horse | 256 | model | 2.11 | 0.71 |
LSUN Cat | 256 | model | 3.98 | 1.03 |
LSUN Church | 256 | model | 1.72 | 0.58 |
FFHQ | 1024 | model | 3.01 | 0.38 |
Pre-trained models in limited-data settings:

Dataset | Resolution | Download | FID (Inception) | FID (SwAV) |
---|---|---|---|---|
AFHQ Dog | 512 | model | 4.73 | 1.04 |
AFHQ Cat | 512 | model | 2.53 | 0.62 |
AFHQ Wild | 512 | model | 2.36 | 1.10 |
AnimalFace-Dog | 256 | model | 32.56 | 6.47 |
AnimalFace-Cat | 256 | model | 27.35 | 5.18 |
100-shot Bridge-of-Sighs | 256 | model | 34.35 | 3.46 |
Other pre-trained models, including those from experiments with varying numbers of training samples, can be downloaded at this link.
To generate images:
```bash
# random image generation from LSUN Church model
python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 --network=https://www.cs.cmu.edu/~vision-aided-gan/models/main_paper_table2_fulldataset/vision-aided-gan-lsunchurch-ada-3.pkl
```
The above command generates 4 images using the given seed values and saves them in the `out` directory (controlled by `--outdir`). Our generator architecture is the same as StyleGAN2's and can be used in Python code in the same way as described in stylegan2-ada-pytorch.
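For example, a minimal sketch of loading the generator in Python, following the stylegan2-ada-pytorch convention (this assumes the checkpoint `.pkl` has been downloaded locally and a CUDA device is available):

```python
import pickle
import torch

# load the EMA generator weights from a downloaded checkpoint
with open('vision-aided-gan-lsunchurch-ada-3.pkl', 'rb') as f:
    G = pickle.load(f)['G_ema'].cuda()  # torch.nn.Module

z = torch.randn([1, G.z_dim]).cuda()  # random latent code
c = None                              # class labels (unconditional model)
img = G(z, c)                         # NCHW, float32, dynamic range [-1, +1]
```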
Model evaluation:
```bash
python calc_metrics.py --network https://www.cs.cmu.edu/~vision-aided-gan/models/main_paper_table2_fulldataset/vision-aided-gan-lsunchurch-ada-3.pkl --metrics fid50k_full --data lsunchurch --clean 1
```
We use the clean-fid library to calculate the FID metric, computing statistics over the full real distribution. For details on calculating the statistics, please refer to clean-fid. For the default StyleGAN2-ADA FID evaluation, use `--clean 0`. The above command should return an FID of ~1.72.
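As a rough sketch of pre-computing the real-distribution statistics with clean-fid (the statistics name and image-folder path below are placeholders):

```python
from cleanfid import fid

# pre-compute clean-fid statistics over the full set of real images
fid.make_custom_stats('lsunchurch', 'datasets/lsunchurch_images', mode='clean')

# later, compute FID of generated images against the cached statistics
score = fid.compute_fid('out', dataset_name='lsunchurch',
                        mode='clean', dataset_split='custom')
print(score)
```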
Worst sample analysis:
```bash
python calc_metrics.py --metrics sort_likelihood --name afhq_dog --split train --network https://www.cs.cmu.edu/~vision-aided-gan/models/main_paper_table3_afhq/vision-aided-gan-afhqdog-ada-3.pkl --data afhqdog
```
Example command to create a visualization similar to the one shown here. The output image for the above command is saved in the `out` directory.
Dataset preparation is the same as in stylegan2-ada-pytorch.

Example: 100-shot AnimalFace Dog dataset
```bash
mkdir datasets
wget https://data-efficient-gans.mit.edu/datasets/AnimalFace-dog.zip -P datasets
```
All datasets can be downloaded from their respective websites:
FFHQ, LSUN Categories, AFHQ, AnimalFace Dog, AnimalFace Cat, 100-shot Bridge-of-Sighs
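For a custom image folder, a dataset zip can be created with `dataset_tool.py`, following the stylegan2-ada-pytorch workflow; the source path and resolution below are placeholders:

```bash
python dataset_tool.py --source=path/to/images --dest=datasets/mydataset.zip \
    --width=256 --height=256
```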
Vision-aided GAN training with multiple off-the-shelf models:
```bash
python vision-aided-gan.py --outdir models/ --data datasets/AnimalFace-dog.zip --cfg paper256_2fmap --mirror 1 \
  --aug ada --augpipe bgc --augcv ada --batch 16 --gpus 2 --kimgs-list '1000,1000,1000' --num 3
```
The networks, sample generated images, and logs are saved at regular intervals (controlled by the `--snap` flag) in the `<outdir>/<exp-folder>` directory, where the `<exp-folder>` name is based on the input args. The network with each progressive addition of a pretrained model is saved in a separate directory. Logs are saved as TFevents by default. Wandb logging can be enabled via the `--wandb-log` flag after setting the wandb entity in `training.training_loop`. If fine-tuning a baseline model with the vision-aided adversarial loss, include `--resume <network.pkl>` in the above command.
`--kimgs-list` controls the number of training iterations (in thousands of images) after which the next off-the-shelf model is added. It is a comma-separated list. For datasets with 1k training samples we initialize `--kimgs-list` to '4000,1000,1000', and for datasets with >1k training samples to '8000,2000,2000'.
Vision-aided GAN training with a single off-the-shelf model:
```bash
python train.py --outdir models/ --data datasets/AnimalFace-dog.zip --kimg 10000 --cfg paper256_2fmap --gpus 2 \
  --cv input-clip-output-conv_multi_level --cv-loss multilevel_sigmoid_s --augcv ada --mirror 1 --aug ada --warmup 5e5
```
Model selection: returns the computer vision model with the highest linear-probe accuracy for the best-FID model in a folder, or for the given network file.
```bash
python model_selection.py --data mydataset.zip --network <mynetworkfolder or mynetworkpklfile>
```
To add your own pretrained model:

1. Create a class file that extracts pretrained features as `vision_module/<custom_model>.py`.
2. Add the class path to the `class_name_dict` in the `vision_module.cvmodel.CVBackbone` class.
3. Update the architecture of the trainable classifier head over the pretrained features in `vision_module.cv_discriminator`.
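A minimal sketch of such a feature-extractor class (step 1), assuming the backbone only needs to map images to feature tensors; the file name, class name, and interface here are illustrative, not the repo's actual API:

```python
# vision_module/custom_model.py (illustrative; check CVBackbone for the exact interface)
import torch
import torch.nn as nn
import torchvision.models as models

class CustomFeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # keep all layers up to (but excluding) global pooling and the fc head,
        # so the output is a spatial feature map a classifier head can consume
        self.features = nn.Sequential(*list(backbone.children())[:-2])
        self.features.eval().requires_grad_(False)  # frozen, off-the-shelf

    @torch.no_grad()
    def forward(self, x):
        # x: NCHW images; returns N x 2048 x H/32 x W/32 feature maps
        return self.features(x)
```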
Training configuration details
Training configuration corresponding to training with our loss:
- `--cv=input-<cv_type>-output-<output_type>`: the pretrained network and its configuration.
- `--warmup=0`: number of training images after which the vision-aided loss is added (~5e5 when training from scratch), i.e., our loss is introduced after training with `warmup` images.
- `--cv-loss=multilevel_sigmoid_s`: the loss to use on the pretrained-model-based discriminator, as described here.
- `--augcv=ada`: performs ADA augmentation on the pretrained-model-based discriminator.
- `--augcv=diffaugment-<policy>`: performs DiffAugment on the pretrained-model-based discriminator with the given policy, e.g. `color,translation,cutout`.
- `--augpipecv=bgc`: ADA augmentation strategy. Note: cutout is always enabled.
- `--ada-target-cv=0.3`: adjusts the ADA target value for the pretrained-model-based discriminator.
- `--exact-resume=1`: enables resume along with the optimizer and augmentation state; default is 0.
StyleGAN2 configurations:
- `--outdir='models/'`: directory to save training runs.
- `--data`: data directory created after running `dataset_tool.py`.
- `--metrics=fid50k_full`: evaluates FID during training at every `snap` interval.
- `--cfg=paper256`: architecture and hyperparameter configuration for G and D.
- `--mirror=1`: enables horizontal flipping.
- `--aug=ada`: enables ADA augmentation in the trainable D.
- `--diffaugment=color,translation,cutout`: enables DiffAugment in the trainable D.
- `--augpipe=bgc`: ADA augmentation strategy in the trainable D.
- `--snap=25`: evaluation and model-saving interval.
Miscellaneous configurations:
- `--wandb-log=1`: enables wandb logging.
- `--clean=1`: enables FID calculation using clean-fid if the real-distribution statistics are pre-calculated; default is False.
Run `python train.py --help` for more details and the full list of args.