Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create LO-WGAN-gp.py #121

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 258 additions & 0 deletions implementations/LO-WGAN-gp/lo-wgan-gp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
import argparse
import os
import numpy as np
import math
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch

os.makedirs("images", exist_ok=True)


parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.00005, help="learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--latent_method", type=str, default="ngd", help="The latent optimization method")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--n_critic", type=int, default=5, help="number of training steps for discriminator per iter")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False


class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()

def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers

self.model = nn.Sequential(
*block(opt.latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh()
)

def forward(self, z):
img = self.model(z)
img = img.view(img.shape[0], *img_shape)
return img


class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()

self.model = nn.Sequential(
nn.Linear(int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
)

def forward(self, img):
img_flat = img.view(img.shape[0], -1)
validity = self.model(img_flat)
return validity


# Loss weight for gradient penalty
lambda_gp = 10

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
generator.cuda()
discriminator.cuda()

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
datasets.MNIST(
"../../data/mnist",
train=True,
download=True,
transform=transforms.Compose(
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
),
),
batch_size=opt.batch_size,
shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

#Latent optimization

def latent_opt(Gen, Dis, z, method , batch_size, alpha= 0.9, beta= 0.1):
method = method.lower()

#Using gradient descent
if method == "gd":

fake = Gen(z)
f_z = Dis(dake.view(batch_size, opt.channels, opt.img_size, opt.img_size))

d_fz = torch.autograd.grad(outputs=f_z,
inputs= z,
grad_outputs=torch.ones_like(f_z),
retain_graph=True,
create_graph= True
)[0]

delta_z = torch.ones_like(d_fz)
delta_z = alpha * d_fz

with torch.no_grad():
z_prime = torch.clamp(z + delta_z, min=-1, max=1)

return z_prime
#Using natural gradient descent
elif method == "ngd":
fake = Gen(z)
f_z = Dis(fake.view(batch_size, opt.channels, opt.img_size, opt.img_size))

d_fz = torch.autograd.grad(outputs=f_z,
inputs= z,
grad_outputs=torch.ones_like(f_z),
retain_graph=True,
create_graph= True
)[0]

delta_z = torch.ones_like(d_fz)
delta_z = (alpha * d_fz) / (beta + torch.norm(delta_z, p=2, dim=0))
with torch.no_grad():
z_prime = torch.clamp(z + delta_z, min=-1, max=1)

return z_prime

def compute_gradient_penalty(D, real_samples, fake_samples):
"""Calculates the gradient penalty loss for WGAN GP"""
# Random weight term for interpolation between real and fake samples
alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
# Get random interpolation between real and fake samples
interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
d_interpolates = D(interpolates)
fake = Variable(Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False)
# Get gradient w.r.t. interpolates
gradients = autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
grad_outputs=fake,
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
gradients = gradients.view(gradients.size(0), -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
return gradient_penalty



# ----------
# Training
# ----------

batches_done = 0
for epoch in range(opt.n_epochs):

for i, (imgs, _) in enumerate(dataloader):

# Configure input
real_imgs = Variable(imgs.type(Tensor))

# ---------------------
# latent optimization step
# ---------------------

optimizer_G.zero_grad(), optimizer_D.zero_grad()

# Sample optimized noise as an input for the generator
z = Variable(Tensor(np.random.uniform(-1, 1, (imgs.shape[0], opt.latent_dim))), requires_grad=True)
z_prime = latent_opt(generator, discriminator, opt.latent_method, z, imgs.shape[0])

# ---------------------
# Train Discriminator
# ---------------------

optimizer_D.zero_grad()


# Generate a batch of images
fake_imgs = generator(z_prime)

# Real images
real_validity = discriminator(real_imgs)
# Fake images
fake_validity = discriminator(fake_imgs)
# Gradient penalty
gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)
# Adversarial loss
d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty

d_loss.backward()
optimizer_D.step()

optimizer_G.zero_grad()

# Train the generator every n_critic steps
if i % opt.n_critic == 0:


# -----------------
# Train Generator
# -----------------

# Generate a batch of images
fake_imgs = generator(z_prime)
# Loss measures generator's ability to fool the discriminator
# Train on fake images
fake_validity = discriminator(fake_imgs)
g_loss = -torch.mean(fake_validity)

g_loss.backward()
optimizer_G.step()

print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
)

if batches_done % opt.sample_interval == 0:
save_image(fake_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

batches_done += opt.n_critic