-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 009df37
Showing
127 changed files
with
19,530 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Custom | ||
.vscode | ||
wandb | ||
outputs | ||
tmp* | ||
slurm-logs | ||
inversion/gans/BigGAN/weights | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
.github | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# Lightning /research | ||
test_tube_exp/ | ||
tests/tests_tt_dir/ | ||
tests/save_dir | ||
default/ | ||
data/ | ||
test_tube_logs/ | ||
test_tube_data/ | ||
datasets/ | ||
model_weights/ | ||
tests/save_dir | ||
tests/tests_tt_dir/ | ||
processed/ | ||
raw/ | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
|
||
# IDEs | ||
.idea | ||
.vscode | ||
|
||
# seed project | ||
lightning_logs/ | ||
MNIST | ||
.DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
<div align="center"> | ||
|
||
## PyTorch Pretrained GANs | ||
<!-- [![Paper](http://img.shields.io/badge/paper-arxiv.1001.2234-B31B1B.svg)](https://www.nature.com/articles/nature14539) --> | ||
<!-- [![Conference](http://img.shields.io/badge/CVPR-2021-4b44ce.svg)](https://papers.nips.cc/book/advances-in-neural-information-processing-systems-31-2018) --> | ||
|
||
</div> | ||
|
||
<!-- TODO: Add video --> | ||
|
||
### Quick Start | ||
This repository provides a standardized interface for pretrained GANs in PyTorch. You can install it with: | ||
```bash | ||
pip install git+https://github.com/lukemelas/pytorch-pretrained-gans | ||
``` | ||
It is then easy to generate an image with a GAN: | ||
```python | ||
import torch | ||
from pytorch_pretrained_gans import make_gan | ||
|
||
# Sample a class-conditional image from BigGAN with default resolution 256 | ||
G = make_gan(gan_type='biggan') # -> nn.Module | ||
y = G.sample_class(batch_size=1) # -> torch.Size([1, 1000]) | ||
z = G.sample_latent(batch_size=1) # -> torch.Size([1, 128]) | ||
x = G(z=z, y=y) # -> torch.Size([1, 3, 256, 256]) | ||
``` | ||
|
||
### Motivation | ||
Over the past few years, great progress has been made in generative modeling using GANs. As a result, a large body of research has emerged that uses GANs and explores/interprets their latent spaces. I recently worked on a project in which I wanted to apply the same technique to a bunch of different GANs (here's the [paper](https://github.com/lukemelas/unsupervised-image-segmentation) if you're interested). This was quite a pain because all the pretrained GANs out there are in completely different formats. So I decided to standardize them, and here's the result. I hope you find it useful. | ||
|
||
### Installation | ||
Install with `pip` directly from GitHub: | ||
``` | ||
pip install git+https://github.com/lukemelas/pytorch-pretrained-gans | ||
``` | ||
|
||
### Available GANs | ||
|
||
The following GANs are available. If you would like to add a new GAN to the repo, please submit a pull request -- I would love to add to this list: | ||
- [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch) | ||
- [BigBiGAN](https://arxiv.org/abs/1907.02544) | ||
- [StyleGAN-2-ADA](https://arxiv.org/abs/1912.04958) | ||
- [Self-Conditioned GANs](https://arxiv.org/abs/2006.10728) | ||
- [Many GANs from StudioGAN](https://github.com/POSTECH-CVLab/PyTorch-StudioGAN): | ||
- [BigGAN](https://github.com/POSTECH-CVLab/PyTorch-StudioGAN) (a reimplementation) | ||
- [ContraGAN](https://github.com/POSTECH-CVLab/PyTorch-StudioGAN) | ||
- [SAGAN](https://arxiv.org/abs/1805.08318) | ||
- [SNGAN](https://arxiv.org/abs/1802.05957) | ||
|
||
|
||
|
||
### Structure | ||
|
||
This repo supports both conditional and unconditional GANs. The standard GAN interface is as follows: | ||
|
||
```python | ||
class GeneratorWrapper(torch.nn.Module): | ||
""" A wrapper to put the GAN in a standard format.""" | ||
|
||
def __init__(self, G, num_classes=None): | ||
super().__init__() | ||
self.G : nn.Module = # GAN generator | ||
self.dim_z : int = # dimensionality of latent space | ||
self.conditional = # True / False | ||
|
||
def forward(self, z, y=None): # y is for conditional GAN only | ||
x = # ... generate image from latent with self.G | ||
return x # returns image | ||
|
||
def sample_latent(self, batch_size, device='cpu'): | ||
z = # ... samples latent vector of size self.dim_z | ||
return z | ||
|
||
def sample_class(self, batch_size=None, device='cpu'): | ||
y = # ... samples class y (for conditional GAN only) | ||
return y | ||
``` | ||
|
||
Each type of GAN is contained in its own folder and has a `make_GAN_TYPE` function. For example, `make_bigbigan` creates a BigBiGAN with the format of the `GeneratorWrapper` above. | ||
|
||
The weights of all GANs except those in PyTorch-StudioGAN and are downloaded automatically. To download the PyTorch-StudioGAN weights, use the `download.sh` scripts in the corresponding folders (see the file structure below). | ||
|
||
#### Code Structure | ||
The structure of the repo is below. Each type of GAN has an `__init__.py` file that defines its `GeneratorWrapper` and its `make_GAN_TYPE` file. | ||
|
||
```bash | ||
pytorch_pretrained_gans | ||
├── __init__.py | ||
├── BigBiGAN | ||
│ ├── __init__.py | ||
│ ├── ... | ||
│ └── weights | ||
│ └── download.sh # (use this to download pretrained weights) | ||
├── BigGAN | ||
│ ├── __init__.py | ||
│ ├── ... | ||
├── StudioGAN | ||
│ ├── __init__.py | ||
│ ├── ... | ||
│ ├── configs | ||
│ │ ├── ImageNet | ||
│ │ │ ├── BigGAN2048 | ||
│ │ │ │ └── ... | ||
│ │ │ └── download.sh # (use this to download pretrained weights) | ||
│ │ └── TinyImageNet | ||
│ │ ├── ACGAN | ||
│ │ │ └── ACGAN.json | ||
│ │ ├── ... | ||
│ │ └── download.sh # (use this to download pretrained weights) | ||
├── self_conditioned | ||
│ ├── __init__.py | ||
│ └── ... | ||
└── stylegan2_ada_pytorch | ||
├── __init__.py | ||
└── ... | ||
``` | ||
|
||
### GAN-Specific Details | ||
|
||
Naturally, there are some details that are specific to certain GANs. | ||
|
||
**BigGAN:** For BigGAN, you should specify a resolution with `model_name`. For example: | ||
* `G = make_gan(gan_type='biggan', model_name='biggan-deep-512')` | ||
|
||
**StudioGAN:** For StudioGAN, you should specify a model with `model_name`. For example: | ||
* `G = make_gan(gan_type='studiogan', model_name='SAGAN')` | ||
* `G = make_gan(gan_type='studiogan', model_name='ContraGAN256')` | ||
|
||
**Self-Conditioned GAN:** For StudioGAN, you should specify a model (either `self_conditioned` or `unconditional`) with `model_name`. For example: | ||
* `G = make_gan(gan_type='selfconditionedgan', model_name='self_conditioned')` | ||
|
||
**StyleGAN 2:** | ||
* StyleGAN2's `sample_latent` method returns `w`, not `z`, because this is usually what is desired. `w` has shape `torch.Size([1, 18, 512])`. | ||
* StyleGAN2 is currently not implemented on `CPU` | ||
|
||
### Citation | ||
Please cite the following if you use this repo in a research paper: | ||
```bibtex | ||
@inproceedings{melaskyriazi2021finding, | ||
author = {Melas-Kyriazi, Luke and Manrai, Arjun}, | ||
title = {Finding an Unsupervised Image Segmenter in each of your Deep Generative Models}, | ||
booktitle = arxiv, | ||
year = {2021} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
""" | ||
Example usage: | ||
python generate.py | ||
""" | ||
import torch | ||
from pytorch_pretrained_gans import make_gan | ||
|
||
# BigGAN (unconditional) | ||
G = make_gan(gan_type='biggan', model_type='biggan-deep-512') # -> nn.Module | ||
y = G.sample_class(batch_size=1) # -> torch.Size([1, 1000]) | ||
z = G.sample_latent(batch_size=1) # -> torch.Size([1, 128]) | ||
x = G(z=z, y=y) # -> torch.Size([1, 3, 256, 256]) | ||
|
||
# BigBiGAN (unconditional) | ||
G = make_gan(gan_type='bigbigan') # -> nn.Module | ||
z = G.sample_latent(batch_size=1) # -> torch.Size([1, 120]) | ||
x = G(z=z) # -> torch.Size([1, 3, 128, 128]) | ||
|
||
# Self-Conditioned GAN (unconditional) | ||
G = make_gan(gan_type='selfconditionedgan', model_name='self_conditioned') # -> nn.Module | ||
y = G.sample_class(batch_size=1) # -> torch.Size([1, 1000]) | ||
z = G.sample_latent(batch_size=1) # -> torch.Size([1, 128]) | ||
x = G(z=z, y=y) # -> torch.Size([1, 3, 128, 128]) | ||
|
||
# StudioGAN (unconditional) | ||
G = make_gan(gan_type='studiogan', model_name='SAGAN') # -> nn.Module | ||
y = G.sample_class(batch_size=1) # -> torch.Size([1, 1000]) | ||
z = G.sample_latent(batch_size=1) # -> torch.Size([1, 128]) | ||
x = G(z=z, y=y) # -> torch.Size([1, 3, 128, 128]) | ||
|
||
# StyleGAN2 (unconditional) | ||
G = make_gan(gan_type='stylegan2') # -> nn.Module | ||
z = G.sample_latent(batch_size=1) # -> torch.Size([1, 120]) | ||
x = G(z=z) # -> torch.Size([1, 3, 128, 128]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
*.pt | ||
*.pth | ||
*.pkl | ||
*.npy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .gan_load import make_big_gan |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from pathlib import Path | ||
import torch | ||
from torch import nn | ||
from .model import BigGAN | ||
from .gan_with_shift import gan_with_shift | ||
|
||
DEFAULT_WEIGHTS_ROOT = Path(__file__).parent / 'weights/BigBiGAN_x1.pth' | ||
|
||
|
||
class GeneratorWrapper(nn.Module): | ||
""" A wrapper to put the GAN in a standard format -- here, a modified | ||
version of the old UnconditionalBigGAN class """ | ||
|
||
def __init__(self, big_gan): | ||
super().__init__() | ||
self.big_gan = big_gan | ||
self.dim_z = self.big_gan.dim_z | ||
self.conditional = False | ||
|
||
def forward(self, z): | ||
classes = torch.zeros(z.shape[0], dtype=torch.int64, device=z.device) | ||
return self.big_gan(z, self.big_gan.shared(classes)) | ||
|
||
def sample_latent(self, batch_size, device): | ||
z = torch.randn((batch_size, self.dim_z), device=device) | ||
return z | ||
|
||
|
||
def make_biggan_config(resolution): | ||
attn_dict = {128: '64', 256: '128', 512: '64'} | ||
dim_z_dict = {128: 120, 256: 140, 512: 128} | ||
config = { | ||
'G_param': 'SN', 'D_param': 'SN', | ||
'G_ch': 96, 'D_ch': 96, | ||
'D_wide': True, 'G_shared': True, | ||
'shared_dim': 128, 'dim_z': dim_z_dict[resolution], | ||
'hier': True, 'cross_replica': False, | ||
'mybn': False, 'G_activation': nn.ReLU(inplace=True), | ||
'G_attn': attn_dict[resolution], | ||
'norm_style': 'bn', | ||
'G_init': 'ortho', 'skip_init': True, 'no_optim': True, | ||
'G_fp16': False, 'G_mixed_precision': False, | ||
'accumulate_stats': False, 'num_standing_accumulations': 16, | ||
'G_eval_mode': True, | ||
'BN_eps': 1e-04, 'SN_eps': 1e-04, | ||
'num_G_SVs': 1, 'num_G_SV_itrs': 1, 'resolution': resolution, | ||
'n_classes': 1000} | ||
return config | ||
|
||
|
||
@gan_with_shift | ||
def make_big_gan(weights_root, resolution): | ||
config = make_biggan_config(resolution) | ||
G = BigGAN.Generator(**config) | ||
G.load_state_dict(torch.load(weights_root, map_location=torch.device('cpu')), strict=False) | ||
return GeneratorWrapper(G) | ||
|
||
|
||
def make_bigbigan(model_name='bigbigan-128', weights_root=DEFAULT_WEIGHTS_ROOT): | ||
assert model_name == 'bigbigan-128' | ||
config = make_biggan_config(resolution=128) | ||
G = BigGAN.Generator(**config) | ||
G.load_state_dict(torch.load(weights_root, map_location=torch.device('cpu')), strict=False) | ||
G = GeneratorWrapper(G) | ||
return G |
Oops, something went wrong.