Commit 009df37: init

lukemelas committed Apr 7, 2021 (0 parents)

Showing 127 changed files with 19,530 additions and 0 deletions.
137 changes: 137 additions & 0 deletions .gitignore
# Custom
.vscode
wandb
outputs
tmp*
slurm-logs
inversion/gans/BigGAN/weights

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.github

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Lightning /research
test_tube_exp/
tests/tests_tt_dir/
tests/save_dir
default/
data/
test_tube_logs/
test_tube_data/
datasets/
model_weights/
tests/save_dir
tests/tests_tt_dir/
processed/
raw/

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# IDEs
.idea
.vscode

# seed project
lightning_logs/
MNIST
.DS_Store
145 changes: 145 additions & 0 deletions README.md
<div align="center">

## PyTorch Pretrained GANs
<!-- [![Paper](http://img.shields.io/badge/paper-arxiv.1001.2234-B31B1B.svg)](https://www.nature.com/articles/nature14539) -->
<!-- [![Conference](http://img.shields.io/badge/CVPR-2021-4b44ce.svg)](https://papers.nips.cc/book/advances-in-neural-information-processing-systems-31-2018) -->

</div>

<!-- TODO: Add video -->

### Quick Start
This repository provides a standardized interface for pretrained GANs in PyTorch. You can install it with:
```bash
pip install git+https://github.com/lukemelas/pytorch-pretrained-gans
```
It is then easy to generate an image with a GAN:
```python
import torch
from pytorch_pretrained_gans import make_gan

# Sample a class-conditional image from BigGAN with default resolution 256
G = make_gan(gan_type='biggan') # -> nn.Module
y = G.sample_class(batch_size=1) # -> torch.Size([1, 1000])
z = G.sample_latent(batch_size=1) # -> torch.Size([1, 128])
x = G(z=z, y=y) # -> torch.Size([1, 3, 256, 256])
```
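
To save or view the result, here is a minimal sketch (continuing from the snippet above, and assuming the generator outputs values in `[-1, 1]`, as is common for GANs):
```python
from torchvision.transforms.functional import to_pil_image

# Scale from [-1, 1] to [0, 1], then convert to a PIL image.
# (Assumption: the generator's output range is [-1, 1]; adjust if a GAN differs.)
img = to_pil_image((x[0].clamp(-1, 1) + 1) / 2)
img.save('sample.png')
```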

### Motivation
Over the past few years, great progress has been made in generative modeling using GANs. As a result, a large body of research has emerged that uses GANs and explores/interprets their latent spaces. I recently worked on a project in which I wanted to apply the same latent-space technique to a number of different GANs (here's the [paper](https://github.com/lukemelas/unsupervised-image-segmentation) if you're interested). This was quite a pain because the pretrained GANs out there all come in completely different formats. So I decided to standardize them, and here's the result. I hope you find it useful.

### Installation
Install with `pip` directly from GitHub:
```
pip install git+https://github.com/lukemelas/pytorch-pretrained-gans
```

### Available GANs

The following GANs are available. If you would like to add a new GAN to the repo, please submit a pull request -- I would love to add to this list:
- [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch)
- [BigBiGAN](https://arxiv.org/abs/1907.02544)
- [StyleGAN-2-ADA](https://arxiv.org/abs/1912.04958)
- [Self-Conditioned GANs](https://arxiv.org/abs/2006.10728)
- [Many GANs from StudioGAN](https://github.com/POSTECH-CVLab/PyTorch-StudioGAN):
  - [BigGAN](https://github.com/POSTECH-CVLab/PyTorch-StudioGAN) (a reimplementation)
  - [ContraGAN](https://github.com/POSTECH-CVLab/PyTorch-StudioGAN)
  - [SAGAN](https://arxiv.org/abs/1805.08318)
  - [SNGAN](https://arxiv.org/abs/1802.05957)



### Structure

This repo supports both conditional and unconditional GANs. The standard GAN interface is as follows:

```python
class GeneratorWrapper(torch.nn.Module):
    """ A wrapper to put the GAN in a standard format."""

    def __init__(self, G, num_classes=None):
        super().__init__()
        self.G: nn.Module = G         # GAN generator
        self.dim_z: int = ...         # dimensionality of latent space
        self.conditional: bool = ...  # True / False

    def forward(self, z, y=None):     # y is for conditional GANs only
        x = ...                       # generate an image from the latent with self.G
        return x                      # returns an image

    def sample_latent(self, batch_size, device='cpu'):
        z = ...                       # sample a latent vector of size self.dim_z
        return z

    def sample_class(self, batch_size=None, device='cpu'):
        y = ...                       # sample a class y (conditional GANs only)
        return y
```
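
For concreteness, here is a hedged sketch (not part of the library) of how a toy unconditional generator could be wrapped in this interface:
```python
import torch
from torch import nn

class ToyGenerator(nn.Module):
    """A stand-in generator mapping a 64-dim latent to a 3x32x32 image."""
    def __init__(self, dim_z=64):
        super().__init__()
        self.dim_z = dim_z
        self.net = nn.Sequential(nn.Linear(dim_z, 3 * 32 * 32), nn.Tanh())

    def forward(self, z):
        return self.net(z).view(-1, 3, 32, 32)

class ToyWrapper(nn.Module):
    """Implements the standard interface above for an unconditional GAN."""
    def __init__(self, G):
        super().__init__()
        self.G = G
        self.dim_z = G.dim_z
        self.conditional = False

    def forward(self, z, y=None):
        return self.G(z)  # y is ignored: the toy GAN is unconditional

    def sample_latent(self, batch_size, device='cpu'):
        return torch.randn(batch_size, self.dim_z, device=device)

G = ToyWrapper(ToyGenerator())
x = G(G.sample_latent(batch_size=2))  # -> torch.Size([2, 3, 32, 32])
```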

Each type of GAN is contained in its own folder and has a `make_GAN_TYPE` function. For example, `make_bigbigan` creates a BigBiGAN with the format of the `GeneratorWrapper` above.

The weights of all GANs except those in PyTorch-StudioGAN are downloaded automatically. To download the PyTorch-StudioGAN weights, use the `download.sh` scripts in the corresponding folders (see the file structure below).

#### Code Structure
The structure of the repo is shown below. Each type of GAN has an `__init__.py` file that defines its `GeneratorWrapper` and its `make_GAN_TYPE` function.

```bash
pytorch_pretrained_gans
├── __init__.py
├── BigBiGAN
│   ├── __init__.py
│   ├── ...
│   └── weights
│       └── download.sh            # (use this to download pretrained weights)
├── BigGAN
│   ├── __init__.py
│   └── ...
├── StudioGAN
│   ├── __init__.py
│   ├── ...
│   └── configs
│       ├── ImageNet
│       │   ├── BigGAN2048
│       │   │   └── ...
│       │   └── download.sh        # (use this to download pretrained weights)
│       └── TinyImageNet
│           ├── ACGAN
│           │   └── ACGAN.json
│           ├── ...
│           └── download.sh        # (use this to download pretrained weights)
├── self_conditioned
│   ├── __init__.py
│   └── ...
└── stylegan2_ada_pytorch
    ├── __init__.py
    └── ...
```
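
For illustration, here is a hedged sketch of the dispatch pattern this layout suggests; the actual `make_gan` implementation may differ, and `make_studiogan` is a hypothetical name following the `make_GAN_TYPE` convention:
```python
def make_gan(gan_type: str, **kwargs):
    # Route a gan_type string to the corresponding per-GAN factory function.
    if gan_type == 'bigbigan':
        from .BigBiGAN.gan_load import make_bigbigan
        return make_bigbigan(**kwargs)
    elif gan_type == 'studiogan':
        from .StudioGAN import make_studiogan  # hypothetical factory name
        return make_studiogan(**kwargs)
    else:
        raise ValueError(f'Unrecognized gan_type: {gan_type}')
```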

### GAN-Specific Details

Naturally, there are some details that are specific to certain GANs.

**BigGAN:** For BigGAN, you should specify a resolution with `model_name`. For example:
* `G = make_gan(gan_type='biggan', model_name='biggan-deep-512')`

**StudioGAN:** For StudioGAN, you should specify a model with `model_name`. For example:
* `G = make_gan(gan_type='studiogan', model_name='SAGAN')`
* `G = make_gan(gan_type='studiogan', model_name='ContraGAN256')`

**Self-Conditioned GAN:** For Self-Conditioned GAN, you should specify a model (either `self_conditioned` or `unconditional`) with `model_name`. For example:
* `G = make_gan(gan_type='selfconditionedgan', model_name='self_conditioned')`

**StyleGAN 2:**
* StyleGAN2's `sample_latent` method returns `w`, not `z`, because this is usually what is desired. `w` has shape `torch.Size([1, 18, 512])`.
* StyleGAN2 currently does not run on the CPU; a CUDA device is required (see the sketch below)
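
A hedged sketch of interpolating between two samples under these conventions (assuming a CUDA device, and that the wrapper's `forward` accepts the `w` returned by `sample_latent` through its `z` argument):
```python
import torch
from pytorch_pretrained_gans import make_gan

G = make_gan(gan_type='stylegan2')  # CUDA-only, per the note above
w0 = G.sample_latent(batch_size=1, device='cuda')  # -> torch.Size([1, 18, 512])
w1 = G.sample_latent(batch_size=1, device='cuda')
for t in [0.0, 0.25, 0.5, 0.75, 1.0]:
    x = G(z=torch.lerp(w0, w1, t))  # interpolate in w space
```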

### Citation
Please cite the following if you use this repo in a research paper:
```bibtex
@inproceedings{melaskyriazi2021finding,
  author    = {Melas-Kyriazi, Luke and Manrai, Arjun},
  title     = {Finding an Unsupervised Image Segmenter in each of your Deep Generative Models},
  booktitle = arxiv,
  year      = {2021}
}
```
34 changes: 34 additions & 0 deletions examples/generate.py
"""
Example usage:
python generate.py
"""
import torch
from pytorch_pretrained_gans import make_gan

# BigGAN (conditional)
G = make_gan(gan_type='biggan', model_name='biggan-deep-512')  # -> nn.Module
y = G.sample_class(batch_size=1)  # -> torch.Size([1, 1000])
z = G.sample_latent(batch_size=1)  # -> torch.Size([1, 128])
x = G(z=z, y=y)  # -> torch.Size([1, 3, 512, 512])

# BigBiGAN (unconditional)
G = make_gan(gan_type='bigbigan') # -> nn.Module
z = G.sample_latent(batch_size=1) # -> torch.Size([1, 120])
x = G(z=z) # -> torch.Size([1, 3, 128, 128])

# Self-Conditioned GAN (conditional, with self-derived clusters as classes)
G = make_gan(gan_type='selfconditionedgan', model_name='self_conditioned') # -> nn.Module
y = G.sample_class(batch_size=1) # -> torch.Size([1, 1000])
z = G.sample_latent(batch_size=1) # -> torch.Size([1, 128])
x = G(z=z, y=y) # -> torch.Size([1, 3, 128, 128])

# StudioGAN (conditional)
G = make_gan(gan_type='studiogan', model_name='SAGAN') # -> nn.Module
y = G.sample_class(batch_size=1) # -> torch.Size([1, 1000])
z = G.sample_latent(batch_size=1) # -> torch.Size([1, 128])
x = G(z=z, y=y) # -> torch.Size([1, 3, 128, 128])

# StyleGAN2 (unconditional)
G = make_gan(gan_type='stylegan2') # -> nn.Module
z = G.sample_latent(batch_size=1)  # -> torch.Size([1, 18, 512]) (w, not z; see README)
x = G(z=z) # -> torch.Size([1, 3, 128, 128])
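
# Because every wrapper exposes the same standardized interface, downstream
# code can loop over GAN types uniformly. A hedged sketch (assumes the
# corresponding weights are available; BigGAN weights download automatically):
for gan_type in ['biggan', 'bigbigan']:
    G = make_gan(gan_type=gan_type)
    z = G.sample_latent(batch_size=2, device='cpu')
    y = G.sample_class(batch_size=2) if G.conditional else None
    x = G(z=z, y=y) if G.conditional else G(z=z)
    print(gan_type, tuple(x.shape))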
4 changes: 4 additions & 0 deletions pytorch_pretrained_gans/.gitignore
*.pt
*.pth
*.pkl
*.npy
1 change: 1 addition & 0 deletions pytorch_pretrained_gans/BigBiGAN/__init__.py
from .gan_load import make_big_gan
65 changes: 65 additions & 0 deletions pytorch_pretrained_gans/BigBiGAN/gan_load.py
from pathlib import Path
import torch
from torch import nn
from .model import BigGAN
from .gan_with_shift import gan_with_shift

DEFAULT_WEIGHTS_ROOT = Path(__file__).parent / 'weights/BigBiGAN_x1.pth'


class GeneratorWrapper(nn.Module):
    """ A wrapper to put the GAN in a standard format -- here, a modified
    version of the old UnconditionalBigGAN class """

    def __init__(self, big_gan):
        super().__init__()
        self.big_gan = big_gan
        self.dim_z = self.big_gan.dim_z
        self.conditional = False

    def forward(self, z):
        # BigBiGAN is unconditional, so condition on a fixed class (class 0)
        classes = torch.zeros(z.shape[0], dtype=torch.int64, device=z.device)
        return self.big_gan(z, self.big_gan.shared(classes))

    def sample_latent(self, batch_size, device):
        z = torch.randn((batch_size, self.dim_z), device=device)
        return z


def make_biggan_config(resolution):
    attn_dict = {128: '64', 256: '128', 512: '64'}
    dim_z_dict = {128: 120, 256: 140, 512: 128}
    config = {
        'G_param': 'SN', 'D_param': 'SN',
        'G_ch': 96, 'D_ch': 96,
        'D_wide': True, 'G_shared': True,
        'shared_dim': 128, 'dim_z': dim_z_dict[resolution],
        'hier': True, 'cross_replica': False,
        'mybn': False, 'G_activation': nn.ReLU(inplace=True),
        'G_attn': attn_dict[resolution],
        'norm_style': 'bn',
        'G_init': 'ortho', 'skip_init': True, 'no_optim': True,
        'G_fp16': False, 'G_mixed_precision': False,
        'accumulate_stats': False, 'num_standing_accumulations': 16,
        'G_eval_mode': True,
        'BN_eps': 1e-04, 'SN_eps': 1e-04,
        'num_G_SVs': 1, 'num_G_SV_itrs': 1, 'resolution': resolution,
        'n_classes': 1000}
    return config


@gan_with_shift
def make_big_gan(weights_root, resolution):
    config = make_biggan_config(resolution)
    G = BigGAN.Generator(**config)
    G.load_state_dict(torch.load(weights_root, map_location=torch.device('cpu')), strict=False)
    return GeneratorWrapper(G)


def make_bigbigan(model_name='bigbigan-128', weights_root=DEFAULT_WEIGHTS_ROOT):
    assert model_name == 'bigbigan-128'
    config = make_biggan_config(resolution=128)
    G = BigGAN.Generator(**config)
    G.load_state_dict(torch.load(weights_root, map_location=torch.device('cpu')), strict=False)
    G = GeneratorWrapper(G)
    return G
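
# Hedged usage sketch for the loaders above. Assumes the BigBiGAN weights have
# already been downloaded to weights/BigBiGAN_x1.pth via weights/download.sh.
if __name__ == '__main__':
    G = make_bigbigan('bigbigan-128')                # GeneratorWrapper around a 128x128 BigBiGAN
    z = G.sample_latent(batch_size=4, device='cpu')  # -> torch.Size([4, 120])
    x = G(z)                                         # -> torch.Size([4, 3, 128, 128])
    print(x.shape)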