Keras implementation of the GMCNN (Generative Multi-column Convolutional Neural Networks) inpainting model, originally proposed at NIPS 2018 in: Image Inpainting via Generative Multi-column Convolutional Neural Networks
- Code from this repository was tested on Python 3.6 and Ubuntu 14.04
- All required dependencies are listed in the requirements.txt, requirements-cpu.txt and requirements-gpu.txt files in the requirements directory.
Code download:
git clone https://github.com/tlatkowski/inpainting-gmcnn-keras.git
cd inpainting-gmcnn-keras
To install the requirements, create a Python virtual environment and install the dependencies from the requirements files:
virtualenv -p /usr/bin/python3.6 .venv
source .venv/bin/activate
pip install -r requirements/requirements.txt
For GPU support:
pip install -r requirements/requirements-gpu.txt
Otherwise (CPU only):
pip install -r requirements/requirements-cpu.txt
The model was trained on high-resolution images from the Places365-Standard dataset. It can be found here
The mask dataset used for model training comes from NVIDIA's paper: Image Inpainting for Irregular Holes Using Partial Convolutions
NVIDIA's mask dataset is available here
Please note that the model was trained on NVIDIA's irregular mask test set, which contains 12,000 masks.
The ./samples folder shows an example structure of the dataset directories:
samples
  |-masks
    |-nvidia_masks
  |-images
    |-places365
The nvidia_masks directory contains 5 sample masks from NVIDIA's test set.
The places365 directory contains 5 sample images from the Places365 validation set.
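The samples nest the files one level deep: a root directory, then a named subdirectory, then the image files. If you prepare your own dataset, a quick check that it mirrors this layout before passing the roots to --train_path and --mask_path might look like the following sketch. The helper name describe_layout and the set of accepted file extensions are assumptions for illustration, not part of the repository.

```python
from pathlib import Path

# File extensions counted as images; an assumption, adjust to your data.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def describe_layout(root):
    """Map each subdirectory of `root` to the number of image files it holds.

    Mirrors the ./samples layout: <root>/<subdirectory>/<image files>.
    Files placed directly in `root` are ignored.
    """
    root = Path(root)
    if not root.is_dir():
        raise NotADirectoryError(f"{root} is not a directory")
    counts = {}
    for subdir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[subdir.name] = sum(
            1
            for f in subdir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts
```

For example, describe_layout("./samples/images") should report the places365 subdirectory together with the number of image files found in it.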
The main configuration file is ./config/main_config.ini. It contains the training and model parameters, which you can tweak before running the model.
The default configuration looks as follows:
[TRAINING]
WGAN_TRAINING_RATIO = 5
NUM_EPOCHS = 5
BATCH_SIZE = 4
IMG_HEIGHT = 256
IMG_WIDTH = 256
NUM_CHANNELS = 3
LEARNING_RATE = 0.0001
SAVE_MODEL_STEPS_PERIOD = 1000
[MODEL]
ADD_MASK_AS_GENERATOR_INPUT = False
GRADIENT_PENALTY_LOSS_WEIGHT = 10
ID_MRF_LOSS_WEIGHT = 0.05
ADVERSARIAL_LOSS_WEIGHT = 0.001
NN_STRETCH_SIGMA = 0.5
VGG_16_LAYERS = 3,6,10
ID_MRF_STYLE_WEIGHT = 1.0
ID_MRF_CONTENT_WEIGHT = 1.0
NUM_GAUSSIAN_STEPS = 3
GAUSSIAN_KERNEL_SIZE = 32
GAUSSIAN_KERNEL_STD = 40.0
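To read these values from your own scripts (for example, to inspect an experiment's settings), Python's standard configparser module is enough. The following is a hypothetical helper, not the loader used by runner.py; the returned key names are made up for illustration, and only a subset of the parameters is parsed.

```python
import configparser


def parse_config(text):
    """Parse main_config.ini-style text into a dict of typed settings (subset)."""
    parser = configparser.ConfigParser()
    # ConfigParser lower-cases option names by default; keep BATCH_SIZE etc. as written.
    parser.optionxform = str
    parser.read_string(text)
    training = parser["TRAINING"]
    model = parser["MODEL"]
    return {
        "batch_size": training.getint("BATCH_SIZE"),
        "num_epochs": training.getint("NUM_EPOCHS"),
        "img_shape": (
            training.getint("IMG_HEIGHT"),
            training.getint("IMG_WIDTH"),
            training.getint("NUM_CHANNELS"),
        ),
        "learning_rate": training.getfloat("LEARNING_RATE"),
        # "False"/"True" strings are converted to booleans.
        "add_mask_as_input": model.getboolean("ADD_MASK_AS_GENERATOR_INPUT"),
        # Comma-separated VGG16 layer indices, e.g. "3,6,10" -> [3, 6, 10].
        "vgg_16_layers": [int(x) for x in model["VGG_16_LAYERS"].split(",")],
        "gaussian_kernel_size": model.getint("GAUSSIAN_KERNEL_SIZE"),
        "gaussian_kernel_std": model.getfloat("GAUSSIAN_KERNEL_STD"),
    }


def load_config(path="./config/main_config.ini"):
    """Read the configuration file from disk and parse it."""
    with open(path, encoding="utf-8") as f:
        return parse_config(f.read())
```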
After installing the dependencies, you can perform a training dry-run using the image and mask samples provided in the samples directory. To do so, execute the following command:
NOTE: Set BATCH_SIZE to 1 before executing the command below.
python runner.py --train_path ./samples/images --mask_path ./samples/masks --experiment_name "dry-run-test"
If everything works correctly, you should see a progress bar logging the basic training metrics.
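If you prefer to change BATCH_SIZE from a script rather than by hand, a small configparser-based sketch follows. set_batch_size is a hypothetical helper; note that ConfigParser.write() rewrites the whole file, so any comments are dropped and spacing is normalized.

```python
import configparser


def set_batch_size(path, batch_size=1):
    """Rewrite BATCH_SIZE in the [TRAINING] section of an INI file in place."""
    parser = configparser.ConfigParser()
    parser.optionxform = str  # keep the upper-case option names used by the project
    with open(path, encoding="utf-8") as f:
        parser.read_file(f)
    parser["TRAINING"]["BATCH_SIZE"] = str(batch_size)
    with open(path, "w", encoding="utf-8") as f:
        parser.write(f)
```

For example, call set_batch_size("./config/main_config.ini", 1) before the dry-run and set_batch_size("./config/main_config.ini", 4) to restore the default afterwards.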
To train the GMCNN model on your own data, provide the paths to your datasets:
python runner.py --train_path /path/to/training/images --mask_path /path/to/mask/images --experiment_name "experiment_name"
Following common GAN training practice, the generator should first be trained on its own for a while. To train only the generator in this first stage, run the following command (note the additional warm_up_generator flag):
python runner.py --train_path /path/to/training/images --mask_path /path/to/mask/images -warm_up_generator
In this mode the generator is trained with the confidence-driven reconstruction loss only.
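The GMCNN paper describes this loss as an L1 reconstruction error weighted per pixel by a confidence map that is, roughly, larger for hole pixels close to the known region. As an illustration only, here is a NumPy version that takes a precomputed confidence map as input. The function name and the choice of a mean rather than a sum are assumptions, and the repository's Keras implementation will differ in detail.

```python
import numpy as np


def confidence_reconstruction_loss(y_true, y_pred, confidence):
    """Confidence-weighted L1 reconstruction loss (illustrative NumPy sketch).

    y_true, y_pred: arrays of shape (H, W, C) with pixel values.
    confidence:     array of shape (H, W) with per-pixel weights
                    (assumed 0 in known regions, positive inside the hole).
    """
    weights = confidence[..., np.newaxis]  # broadcast the weights over colour channels
    # Mean of the weighted absolute error; the paper writes it as an L1 norm (a sum).
    return float(np.mean(np.abs(y_true - y_pred) * weights))
```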
The picture below presents GMCNN output after 5 epochs of training in warm-up generator mode.
To continue training with the full WGAN-GP framework (GMCNN generator plus local and global discriminators), execute:
python runner.py --train_path /path/to/training/images --mask_path /path/to/mask/images --experiment_name "experiment_name" -from_weights
The additional from_weights flag makes the pipeline load the latest model checkpoints from the ./outputs/weights/ directory.
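WGAN_TRAINING_RATIO = 5 in the configuration most likely follows the convention of the WGAN-GP Keras example credited in this document, where the ratio is the number of critic (discriminator) updates per generator update. Under that assumption, the alternation can be sketched as a simple schedule. The function name is hypothetical; in this pipeline each critic step would presumably update both the local and the global discriminator.

```python
WGAN_TRAINING_RATIO = 5  # value from the [TRAINING] section of main_config.ini


def training_schedule(num_steps, ratio=WGAN_TRAINING_RATIO):
    """Yield which model to update at each step: 'critic' or 'generator'.

    Assumes `ratio` critic updates are followed by one generator update.
    """
    for step in range(num_steps):
        if (step + 1) % (ratio + 1) == 0:
            yield "generator"
        else:
            yield "critic"
```

For example, list(training_schedule(12)) gives five "critic" entries, one "generator", then the same pattern again.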
If you don't have access to a workstation with a GPU, you can use the following example Google Colab notebook to train your GMCNN model on Places365 validation data and NVIDIA's test masks using the K80 GPU available in the Google Colab backend: GMCNN in Google Colab
During training, the pipeline logs additional results to the outputs directory:
- outputs/experiment_name/logs contains TensorBoard logs
- outputs/experiment_name/predicted_pics/warm_up_generator contains model predictions at specific steps in warm-up generator training mode
- outputs/experiment_name/predicted_pics/wgan contains model predictions at specific steps in WGAN-GP training mode
- outputs/experiment_name/weights contains the generator and critic model weights
- outputs/experiment_name/summaries contains the generator and critic model summaries
You can track the metrics during training with TensorBoard:
tensorboard --logdir=./outputs/experiment_name/logs
Differences from the original GMCNN model:
- This model is trained on NVIDIA's irregular mask test set, whereas the original model is trained on randomly generated rectangular masks.
- The current version of the pipeline uses higher-order features extracted from the VGG16 model, whereas the original model uses VGG19.
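For comparison with the original setup, a randomly placed rectangular hole can be generated in a few lines of NumPy. This helper is hypothetical and not part of this repository; the size bounds are arbitrary, and the convention that 1 marks the hole is an assumption, so check which convention your mask loader expects.

```python
import numpy as np


def random_rectangle_mask(height=256, width=256, min_size=32, max_size=128, rng=None):
    """Binary mask with one randomly placed axis-aligned rectangle (1 = hole, assumed)."""
    rng = np.random.default_rng() if rng is None else rng
    # Rectangle side lengths, inclusive of max_size.
    h = int(rng.integers(min_size, max_size + 1))
    w = int(rng.integers(min_size, max_size + 1))
    # Top-left corner chosen so the rectangle stays inside the image.
    top = int(rng.integers(0, height - h + 1))
    left = int(rng.integers(0, width - w + 1))
    mask = np.zeros((height, width), dtype=np.float32)
    mask[top:top + h, left:left + w] = 1.0
    return mask
```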
Below is a visualization of Gaussian blur applied to the training masks for different numbers of convolution steps (the number of iterations over the raw input mask).
Gaussian blur visualization (three sample masks): Original | 1 step | 2 steps | 3 steps | 4 steps | 5 steps | 10 steps
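The visualization corresponds to repeatedly Gaussian-filtering the mask. A minimal NumPy sketch of such an iteration, following the confidence-driven weighting described in the GMCNN paper as we understand it (each step blurs the known region plus the previous estimate and keeps the result only inside the hole), might look like this. The function names are hypothetical, the mask convention (1 = hole) is an assumption, and the repository's own implementation may normalize or combine the steps differently. The defaults mirror NUM_GAUSSIAN_STEPS, GAUSSIAN_KERNEL_SIZE and GAUSSIAN_KERNEL_STD.

```python
import numpy as np


def gaussian_kernel_1d(size, std):
    """Sampled 1-D Gaussian with `size` taps, normalized to sum to 1."""
    x = np.arange(size, dtype=np.float64) - (size - 1) / 2.0
    kernel = np.exp(-(x ** 2) / (2.0 * std ** 2))
    return kernel / kernel.sum()


def gaussian_blur(image, kernel_1d):
    """Separable 2-D blur: convolve every row, then every column (zero padding)."""
    rows = np.apply_along_axis(lambda r: np.convolve(r, kernel_1d, mode="same"), 1, image)
    return np.apply_along_axis(lambda c: np.convolve(c, kernel_1d, mode="same"), 0, rows)


def confidence_weights(mask, steps=3, kernel_size=32, kernel_std=40.0):
    """Iteratively blurred confidence map for a binary mask (1 = hole, 0 = known)."""
    mask = np.asarray(mask, dtype=np.float64)
    if min(mask.shape) < kernel_size:
        # np.convolve(mode="same") would change the output length otherwise.
        raise ValueError("mask sides must be at least kernel_size")
    kernel = gaussian_kernel_1d(kernel_size, kernel_std)
    weights = np.zeros_like(mask)
    for _ in range(steps):
        # Known pixels count as fully confident; hole pixels reuse the last estimate.
        prior = 1.0 - mask + weights
        # Blur, then keep the result only inside the hole.
        weights = gaussian_blur(prior, kernel) * mask
    return weights
```

With one step, hole pixels near the boundary see more known pixels inside the kernel window and therefore get larger weights than pixels in the middle of the hole; further steps raise the weights deeper inside the hole.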
After launching TensorBoard, you can monitor the following training metrics:
- For the generator: confidence reconstruction loss, global Wasserstein loss, local Wasserstein loss, ID-MRF loss and total loss
- For the local and global discriminators: fake loss, real loss, gradient penalty loss and total loss
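The logged totals are presumably weighted sums of the listed terms, using the weights from the [MODEL] section of the configuration. The following shows one plausible combination. The function names, and the assumption that the global and local Wasserstein terms share ADVERSARIAL_LOSS_WEIGHT, are ours and may not match the repository exactly. The critic total follows the usual WGAN-GP form, where the real term is the negated mean critic score on real images.

```python
# Weights from the [MODEL] section of main_config.ini.
ADVERSARIAL_LOSS_WEIGHT = 0.001
ID_MRF_LOSS_WEIGHT = 0.05
GRADIENT_PENALTY_LOSS_WEIGHT = 10.0


def generator_total_loss(reconstruction, global_wgan, local_wgan, id_mrf):
    """Weighted sum of the generator terms (assumed combination)."""
    adversarial = global_wgan + local_wgan
    return reconstruction + ADVERSARIAL_LOSS_WEIGHT * adversarial + ID_MRF_LOSS_WEIGHT * id_mrf


def critic_total_loss(fake, real, gradient_penalty):
    """WGAN-GP critic objective for one discriminator (local or global).

    fake: mean critic score on generated images.
    real: negated mean critic score on real images.
    """
    return fake + real + GRADIENT_PENALTY_LOSS_WEIGHT * gradient_penalty
```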
- The ID-MRF loss function was implemented based on the original TensorFlow implementation: GMCNN in TensorFlow
- The improved Wasserstein GAN was implemented based on: Wasserstein GAN with gradient penalty in Keras
- The model architecture diagram was created with PlotNeuralNet: PlotNeuralNet on GitHub