
WGAN-GP-tensorflow

This repository is a TensorFlow implementation of WGAN-GP for MNIST, CIFAR-10, and ImageNet64.

  • All samples in this README are generated by the neural network, except for the first image in each row.

Install Prerequisites

  • python 3.5, 3.6 or 3.7
  • python3-tk

Ubuntu/Debian/etc.:

sudo apt install python3.5 python3.5-tk

Create Virtual Environment

python -m venv venv

Activate Virtual Environment

Windows:

venv\Scripts\activate

Bash:

source venv/bin/activate

Install Virtual Environment Requirements

pip install -r requirements.d/venv.txt

Create Execution Environments

tox --notest

This installs the CPU-only build of TensorFlow.

To use an Nvidia GPU:

.tox/py35/bin/python -m pip uninstall tensorflow
.tox/py35/bin/python -m pip install tensorflow-gpu==1.13.1
.tox/py36/bin/python -m pip uninstall tensorflow
.tox/py36/bin/python -m pip install tensorflow-gpu==1.13.1
.tox/py37/bin/python -m pip uninstall tensorflow
.tox/py37/bin/python -m pip install tensorflow-gpu==1.13.1

To use an AMD GPU:

.tox/py35/bin/python -m pip uninstall tensorflow
.tox/py35/bin/python -m pip install tensorflow-rocm==1.13.1
.tox/py36/bin/python -m pip uninstall tensorflow
.tox/py36/bin/python -m pip install tensorflow-rocm==1.13.1
.tox/py37/bin/python -m pip uninstall tensorflow
.tox/py37/bin/python -m pip install tensorflow-rocm==1.13.1

Generated Images

1. Toy Dataset

Results on the 2-dimensional 8-Gaussian mixture, 25-Gaussian mixture, and Swiss Roll datasets (IPython Notebook).

Note: For the following experiment, we held the generator distribution P_g fixed at the real distribution plus unit-variance Gaussian noise.
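The three toy distributions can be sampled roughly as follows. This is a minimal pure-Python sketch, not the notebook's code; the layouts are assumptions chosen for illustration: 8 centers evenly spaced on a circle of radius 2, 25 centers on the integer grid {-2..2} x {-2..2}, a component standard deviation of 0.02, and a Swiss Roll parameterized by t in [1.5π, 4.5π).

```python
import math
import random


def sample_8_gaussians(n, std=0.02, radius=2.0, rng=random):
    """Mixture of 8 Gaussians with centers evenly spaced on a circle (assumed layout)."""
    points = []
    for _ in range(n):
        k = rng.randrange(8)                      # pick one component uniformly
        angle = 2.0 * math.pi * k / 8
        cx, cy = radius * math.cos(angle), radius * math.sin(angle)
        points.append((rng.gauss(cx, std), rng.gauss(cy, std)))
    return points


def sample_25_gaussians(n, std=0.02, rng=random):
    """Mixture of 25 Gaussians centered on the integer grid {-2..2} x {-2..2} (assumed)."""
    centers = [(x, y) for x in range(-2, 3) for y in range(-2, 3)]
    points = []
    for _ in range(n):
        cx, cy = rng.choice(centers)
        points.append((rng.gauss(cx, std), rng.gauss(cy, std)))
    return points


def sample_swiss_roll(n, noise=0.25, rng=random):
    """2-D Swiss Roll: a spiral whose radius equals its angle t (assumed parameterization)."""
    points = []
    for _ in range(n):
        t = 1.5 * math.pi * (1.0 + 2.0 * rng.random())   # t in [1.5*pi, 4.5*pi)
        points.append((t * math.cos(t) + rng.gauss(0.0, noise),
                       t * math.sin(t) + rng.gauss(0.0, noise)))
    return points


rng = random.Random(0)   # seeded generator for reproducible samples
batch = sample_8_gaussians(256, rng=rng)
```

Passing a seeded `random.Random` keeps each plotted batch reproducible; any real experiment would typically rescale or batch these points before feeding them to the critic.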

  • Top: GAN discriminator
  • Middle: WGAN critic with weight clipping
  • Bottom: WGAN critic with gradient penalty
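The fixed generator used in this experiment can be sketched as below: every "fake" point is a real point with independent unit-variance Gaussian noise added to each coordinate, so nothing is learned. The function name `fixed_generator_samples` is hypothetical and not taken from the notebook.

```python
import random


def fixed_generator_samples(real_points, rng=random):
    """Draw from P_g = P_r + N(0, I): perturb each coordinate of every real point
    with independent unit-variance Gaussian noise. No parameters are trained."""
    return [tuple(coord + rng.gauss(0.0, 1.0) for coord in point)
            for point in real_points]


rng = random.Random(0)
real = [(2.0, 0.0), (0.0, 2.0), (-2.0, 0.0), (0.0, -2.0)]
fake = fixed_generator_samples(real, rng=rng)
```

Because P_g never moves, the figures isolate what each critic learns about a fixed pair of distributions.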

Note: For the next experiment, the generator was not fixed, and the figures show the points it generated.

  • Top: GAN discriminator
  • Middle: WGAN critic with weight clipping
  • Bottom: WGAN critic with gradient penalty
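The two WGAN critic constraints compared above can be illustrated without a framework. This sketch makes several assumptions: the critic is any Python function of a point, its gradient is estimated by central finite differences (a TensorFlow implementation would use automatic differentiation instead), the clip value c = 0.01 is the original WGAN paper's default rather than a setting confirmed for this repository, and λ = 10 mirrors the `lambda_` default listed below.

```python
import math
import random


def clip_weights(weights, c=0.01):
    """WGAN weight clipping: force every critic weight into [-c, c] after each update."""
    return [max(-c, min(c, w)) for w in weights]


def numerical_grad(critic, x, h=1e-5):
    """Central finite-difference estimate of the critic's gradient at point x."""
    grad = []
    for i in range(len(x)):
        up = list(x)
        down = list(x)
        up[i] += h
        down[i] -= h
        grad.append((critic(up) - critic(down)) / (2.0 * h))
    return grad


def gradient_penalty(critic, real, fake, lam=10.0, rng=random):
    """WGAN-GP penalty: lam * mean((||grad D(x_hat)||_2 - 1)^2), where x_hat is a
    random interpolation between paired real and fake samples."""
    total = 0.0
    for xr, xf in zip(real, fake):
        eps = rng.random()                                   # eps ~ U[0, 1)
        x_hat = [eps * a + (1.0 - eps) * b for a, b in zip(xr, xf)]
        norm = math.sqrt(sum(g * g for g in numerical_grad(critic, x_hat)))
        total += (norm - 1.0) ** 2
    return lam * total / len(real)
```

Clipping constrains the parameters directly, whereas the penalty constrains the critic's gradient with respect to its input, which is why a λ hyperparameter appears only in the WGAN-GP variant.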

2. MNIST Dataset

3. CIFAR-10

4. IMAGENET64

Documentation

Download Dataset

The MNIST and CIFAR-10 datasets are downloaded automatically by the code if they are not already present in the corresponding data folder. The ImageNet64 dataset can be downloaded from Downsampled ImageNet.

Directory Hierarchy

.
│   WGAN-GP
│   ├── src
│   │   ├── imagenet (folder storing the Inception network weights downloaded by inception_score.py)
│   │   ├── cache.py
│   │   ├── cifar10.py
│   │   ├── dataset.py
│   │   ├── dataset_.py
│   │   ├── download.py
│   │   ├── inception_score.py
│   │   ├── main.py
│   │   ├── plot.py
│   │   ├── solver.py
│   │   ├── tensorflow_utils.py
│   │   ├── utils.py
│   │   └── wgan_gp.py
│   Data
│   ├── mnist
│   ├── cifar10
│   └── imagenet64

src: source code of WGAN-GP

Training WGAN-GP

Use main.py to train a WGAN-GP network. Example usage:

python main.py
  • gpu_index: GPU index, default: 0

  • batch_size: batch size for one forward pass, default: 64

  • dataset: dataset name from [mnist, cifar10, imagenet64], default: mnist

  • is_train: training or inference mode, default: True

  • learning_rate: initial learning rate for Adam, default: 0.001

  • num_critic: the number of iterations of the critic per generator iteration, default: 5

  • z_dim: dimension of the latent z vector, default: 128

  • lambda_: gradient penalty lambda hyperparameter, default: 10.

  • beta1: Adam's exponential decay rate for the first-moment (momentum) estimates, default: 0.5

  • beta2: Adam's exponential decay rate for the second-moment estimates, default: 0.9

  • iters: number of iterations, default: 200000

  • print_freq: print frequency for loss, default: 100

  • save_freq: save frequency for model, default: 10000

  • sample_freq: frequency for saving sample images, default: 500

  • inception_freq: frequency for calculating the Inception score, default: 1000

  • sample_batch: number of images sampled to check generator quality, default: 64

  • load_model: folder of the saved model you wish to test (e.g. 20181120-1558), default: None
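The iteration-related flags above interact roughly as in the following schedule skeleton. It is a hypothetical sketch, not the logic of main.py or solver.py: the event and callback names are invented placeholders, and firing each event when (iteration + 1) is a multiple of its frequency is an illustrative assumption.

```python
def _noop(iteration):
    """Placeholder used when no callback is supplied for an event."""
    return None


def run_schedule(iters=200000, num_critic=5, print_freq=100, save_freq=10000,
                 sample_freq=500, inception_freq=1000, callbacks=None):
    """Hypothetical WGAN-GP schedule: num_critic critic updates per generator update,
    plus periodic logging, checkpointing, sampling, and Inception scoring."""
    cb = callbacks or {}
    counts = {"critic": 0, "generator": 0, "print": 0, "save": 0,
              "sample": 0, "inception": 0}
    for it in range(iters):
        for _ in range(num_critic):              # several critic steps ...
            cb.get("train_critic", _noop)(it)
            counts["critic"] += 1
        cb.get("train_generator", _noop)(it)     # ... then one generator step
        counts["generator"] += 1
        step = it + 1
        for name, freq in (("print", print_freq), ("save", save_freq),
                           ("sample", sample_freq), ("inception", inception_freq)):
            if step % freq == 0:
                cb.get(name, _noop)(it)
                counts[name] += 1
    return counts


counts = run_schedule(iters=1000)   # short run with the default frequencies
```

Under this schedule, the default 200000 iterations would produce 20 checkpoints, 400 sample batches, and 200 Inception-score evaluations, with 5 critic updates per generator update.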

WGAN-GP During Training

Note: In the following figures, the y-axis shows the negative critic loss of WGAN-GP.

  1. MNIST

  2. CIFAR10

  3. IMAGENET64
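For reference, and assuming the plotted curves are the negative of the standard WGAN-GP critic loss (with λ set by the lambda_ flag), the y-axis quantity corresponds to the following, where P_r is the real distribution, P_g the generator distribution, and D the critic:

```latex
L_D = \mathbb{E}_{\tilde{x}\sim P_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x\sim P_r}\big[D(x)\big]
      + \lambda\,\mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big],
\qquad \hat{x} = \epsilon x + (1-\epsilon)\tilde{x},\quad \epsilon \sim U[0,1]

-L_D = \mathbb{E}_{x\sim P_r}\big[D(x)\big] - \mathbb{E}_{\tilde{x}\sim P_g}\big[D(\tilde{x})\big]
       - \lambda\,\mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big]
```

The first two terms of -L_D form the critic's estimate of the Wasserstein distance between P_r and P_g, so a curve that decreases toward zero is usually read as the generator distribution approaching the real one.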

Inception Score on CIFAR10 During Training

Note: Inception score was calculated every 1000 iterations.
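The Inception score is IS = exp(E_x[KL(p(y|x) || p(y))]), where p(y|x) are the Inception network's class probabilities for a generated image and p(y) is their average over all images. The sketch below computes it from precomputed probability vectors in pure Python; it assumes the Inception forward pass (handled by inception_score.py in this repository) has already been run, and it omits the averaging over several splits that the original reference implementation performs.

```python
import math


def inception_score(probs, eps=1e-12):
    """IS = exp(mean over x of KL(p(y|x) || p(y))) for a list of class-probability vectors."""
    n, k = len(probs), len(probs[0])
    marginal = [sum(p[j] for p in probs) / n for j in range(k)]   # p(y)
    mean_kl = 0.0
    for p in probs:
        # KL(p(y|x) || p(y)); terms with p(y|x) = 0 contribute nothing
        mean_kl += sum(pj * (math.log(pj + eps) - math.log(mj + eps))
                       for pj, mj in zip(p, marginal) if pj > 0)
    return math.exp(mean_kl / n)
```

A high score requires both confident per-image predictions and diverse predictions across images: a generator that always produces the same class, or produces images the classifier finds ambiguous, scores close to 1, while confident predictions spread evenly over k classes score k.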

Test WGAN-GP

Use main.py to test a WGAN-GP network. Example usage:

python main.py --is_train=false --load_model=folder/you/wish/to/test/e.g./20181120-1558

Please refer to the above arguments.

Citation

  @misc{chengbinjin2018wgan-gp,
    author = {Cheng-Bin Jin},
    title = {WGAN-GP-tensorflow},
    year = {2018},
    howpublished = {\url{https://github.com/ChengBinJin/WGAN-GP-tensorflow}},
    note = {commit xxxxxxx}
  }

Attributions/Thanks

License

Copyright (c) 2018 Cheng-Bin Jin. Contact me for commercial use (or rather any use that is not academic research) (email: [email protected]). Free for research use, as long as proper attribution is given and this copyright notice is retained.

Related Projects