From 009df3721cfa91edd3ac60c9c8335ea480e5fda4 Mon Sep 17 00:00:00 2001 From: Luke Date: Tue, 6 Apr 2021 20:39:27 -0400 Subject: [PATCH] init --- .gitignore | 137 +++ README.md | 145 +++ examples/generate.py | 34 + pytorch_pretrained_gans/.gitignore | 4 + pytorch_pretrained_gans/BigBiGAN/__init__.py | 1 + pytorch_pretrained_gans/BigBiGAN/gan_load.py | 65 ++ .../BigBiGAN/gan_with_shift.py | 20 + .../BigBiGAN/model/BigGAN.py | 454 ++++++++ .../BigBiGAN/model/__init__.py | 0 .../BigBiGAN/model/layers.py | 461 ++++++++ .../BigBiGAN/model/sync_batchnorm/__init__.py | 12 + .../model/sync_batchnorm/batchnorm.py | 349 ++++++ .../model/sync_batchnorm/batchnorm_reimpl.py | 74 ++ .../BigBiGAN/model/sync_batchnorm/comm.py | 137 +++ .../model/sync_batchnorm/replicate.py | 94 ++ .../BigBiGAN/model/sync_batchnorm/unittest.py | 29 + .../BigBiGAN/weights/download.sh | 1 + pytorch_pretrained_gans/BigGAN/__init__.py | 78 ++ pytorch_pretrained_gans/BigGAN/config.py | 70 ++ .../BigGAN/convert_tf_to_pytorch.py | 312 ++++++ pytorch_pretrained_gans/BigGAN/file_utils.py | 249 +++++ pytorch_pretrained_gans/BigGAN/model.py | 330 ++++++ pytorch_pretrained_gans/BigGAN/utils.py | 216 ++++ .../CIPS/GeneratorsCIPS.py | 206 ++++ pytorch_pretrained_gans/CIPS/__init__.py | 121 +++ pytorch_pretrained_gans/CIPS/blocks.py | 587 +++++++++++ pytorch_pretrained_gans/CIPS/op/__init__.py | 2 + pytorch_pretrained_gans/CIPS/op/fused_act.py | 86 ++ .../CIPS/op/fused_bias_act.cpp | 21 + .../CIPS/op/fused_bias_act_kernel.cu | 99 ++ pytorch_pretrained_gans/CIPS/op/upfirdn2d.cpp | 23 + pytorch_pretrained_gans/CIPS/op/upfirdn2d.py | 187 ++++ .../CIPS/op/upfirdn2d_kernel.cu | 272 +++++ pytorch_pretrained_gans/StudioGAN/__init__.py | 138 +++ .../ImageNet/BigGAN2048/BigGAN2048.json | 112 ++ .../configs/ImageNet/BigGAN256/BigGAN256.json | 112 ++ .../ImageNet/ContraGAN2048/ContraGAN2048.json | 111 ++ .../ImageNet/ContraGAN256/ContraGAN256.json | 112 ++ .../configs/ImageNet/SAGAN/SAGAN.json | 111 ++ .../configs/ImageNet/SNGAN/SNGAN.json | 111 ++ .../StudioGAN/configs/ImageNet/download.sh | 23 + .../configs/TinyImageNet/ACGAN/ACGAN.json | 112 ++ .../configs/TinyImageNet/BigGAN/BigGAN.json | 111 ++ .../TinyImageNet/ContraGAN/DiffAugGAN(C).json | 111 ++ .../configs/TinyImageNet/GGAN/GGAN.json | 111 ++ .../configs/TinyImageNet/LSGAN/LSGAN.json | 111 ++ .../configs/TinyImageNet/ProjGAN/ProjGAN.json | 111 ++ .../configs/TinyImageNet/SAGAN/SAGAN.json | 111 ++ .../configs/TinyImageNet/SNGAN/SNGAN.json | 111 ++ .../configs/TinyImageNet/WGAN-WC/WGAN-WC.json | 111 ++ .../configs/TinyImageNet/download.sh | 72 ++ pytorch_pretrained_gans/StudioGAN/loader.py | 298 ++++++ pytorch_pretrained_gans/StudioGAN/main.py | 141 +++ .../StudioGAN/models/big_resnet.py | 441 ++++++++ .../StudioGAN/models/big_resnet_deep.py | 382 +++++++ .../StudioGAN/models/resnet.py | 422 ++++++++ .../StudioGAN/sync_batchnorm/batchnorm.py | 421 ++++++++ .../sync_batchnorm/batchnorm_reimpl.py | 99 ++ .../StudioGAN/sync_batchnorm/comm.py | 162 +++ .../StudioGAN/sync_batchnorm/replicate.py | 119 +++ .../StudioGAN/sync_batchnorm/unittest.py | 54 + .../StudioGAN/utils/ada.py | 415 ++++++++ .../StudioGAN/utils/ada_op/__init__.py | 2 + .../StudioGAN/utils/ada_op/fused_act.py | 122 +++ .../StudioGAN/utils/ada_op/fused_bias_act.cpp | 46 + .../utils/ada_op/fused_bias_act_kernel.cu | 99 ++ .../StudioGAN/utils/ada_op/upfirdn2d.cpp | 48 + .../StudioGAN/utils/ada_op/upfirdn2d.py | 225 ++++ .../utils/ada_op/upfirdn2d_kernel.cu | 369 +++++++ .../StudioGAN/utils/biggan_utils.py | 105 ++ 
.../StudioGAN/utils/cr_diff_aug.py | 50 + .../StudioGAN/utils/diff_aug.py | 105 ++ .../StudioGAN/utils/load_checkpoint.py | 38 + .../StudioGAN/utils/log.py | 54 + .../StudioGAN/utils/losses.py | 316 ++++++ .../StudioGAN/utils/make_hdf5.py | 93 ++ .../StudioGAN/utils/misc.py | 601 +++++++++++ .../StudioGAN/utils/model_ops.py | 170 +++ .../StudioGAN/utils/sample.py | 114 ++ pytorch_pretrained_gans/StudioGAN/worker.py | 995 ++++++++++++++++++ pytorch_pretrained_gans/__init__.py | 25 + .../self_conditioned/__init__.py | 123 +++ .../self_conditioned/gan_training/__init__.py | 0 .../gan_training/checkpoints.py | 163 +++ .../self_conditioned/gan_training/config.py | 116 ++ .../gan_training/distributions.py | 43 + .../self_conditioned/gan_training/eval.py | 80 ++ .../self_conditioned/gan_training/inputs.py | 217 ++++ .../self_conditioned/gan_training/logger.py | 96 ++ .../gan_training/metrics/__init__.py | 5 + .../metrics/clustering_metrics.py | 41 + .../gan_training/metrics/fid.py | 304 ++++++ .../gan_training/metrics/inception_score.py | 66 ++ .../gan_training/metrics/tf_is/LICENSE | 201 ++++ .../gan_training/metrics/tf_is/README.md | 23 + .../metrics/tf_is/inception_score.py | 116 ++ .../gan_training/models/__init__.py | 13 + .../gan_training/models/blocks.py | 205 ++++ .../gan_training/models/dcgan_deep.py | 139 +++ .../gan_training/models/dcgan_shallow.py | 134 +++ .../gan_training/models/resnet2.py | 187 ++++ .../gan_training/models/resnet2s.py | 186 ++++ .../gan_training/models/resnet3.py | 161 +++ .../self_conditioned/gan_training/train.py | 152 +++ .../self_conditioned/gan_training/utils.py | 52 + .../stylegan2_ada_pytorch/__init__.py | 86 ++ .../stylegan2_ada_pytorch/dnnlib/__init__.py | 9 + .../stylegan2_ada_pytorch/dnnlib/util.py | 477 +++++++++ .../torch_utils/__init__.py | 9 + .../torch_utils/custom_ops.py | 126 +++ .../stylegan2_ada_pytorch/torch_utils/misc.py | 262 +++++ .../torch_utils/ops/__init__.py | 9 + .../torch_utils/ops/bias_act.cpp | 99 ++ .../torch_utils/ops/bias_act.cu | 173 +++ .../torch_utils/ops/bias_act.h | 38 + .../torch_utils/ops/bias_act.py | 212 ++++ .../torch_utils/ops/conv2d_gradfix.py | 170 +++ .../torch_utils/ops/conv2d_resample.py | 156 +++ .../torch_utils/ops/fma.py | 60 ++ .../torch_utils/ops/grid_sample_gradfix.py | 83 ++ .../torch_utils/ops/upfirdn2d.cpp | 103 ++ .../torch_utils/ops/upfirdn2d.cu | 350 ++++++ .../torch_utils/ops/upfirdn2d.h | 59 ++ .../torch_utils/ops/upfirdn2d.py | 384 +++++++ .../torch_utils/persistence.py | 251 +++++ .../torch_utils/training_stats.py | 268 +++++ setup.py | 14 + 127 files changed, 19530 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 examples/generate.py create mode 100644 pytorch_pretrained_gans/.gitignore create mode 100644 pytorch_pretrained_gans/BigBiGAN/__init__.py create mode 100644 pytorch_pretrained_gans/BigBiGAN/gan_load.py create mode 100644 pytorch_pretrained_gans/BigBiGAN/gan_with_shift.py create mode 100644 pytorch_pretrained_gans/BigBiGAN/model/BigGAN.py create mode 100644 pytorch_pretrained_gans/BigBiGAN/model/__init__.py create mode 100644 pytorch_pretrained_gans/BigBiGAN/model/layers.py create mode 100644 pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/__init__.py create mode 100644 pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/batchnorm.py create mode 100644 pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/batchnorm_reimpl.py create mode 100644 pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/comm.py create mode 100644 
pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/replicate.py create mode 100644 pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/unittest.py create mode 100644 pytorch_pretrained_gans/BigBiGAN/weights/download.sh create mode 100644 pytorch_pretrained_gans/BigGAN/__init__.py create mode 100644 pytorch_pretrained_gans/BigGAN/config.py create mode 100644 pytorch_pretrained_gans/BigGAN/convert_tf_to_pytorch.py create mode 100644 pytorch_pretrained_gans/BigGAN/file_utils.py create mode 100644 pytorch_pretrained_gans/BigGAN/model.py create mode 100644 pytorch_pretrained_gans/BigGAN/utils.py create mode 100644 pytorch_pretrained_gans/CIPS/GeneratorsCIPS.py create mode 100644 pytorch_pretrained_gans/CIPS/__init__.py create mode 100644 pytorch_pretrained_gans/CIPS/blocks.py create mode 100644 pytorch_pretrained_gans/CIPS/op/__init__.py create mode 100644 pytorch_pretrained_gans/CIPS/op/fused_act.py create mode 100644 pytorch_pretrained_gans/CIPS/op/fused_bias_act.cpp create mode 100644 pytorch_pretrained_gans/CIPS/op/fused_bias_act_kernel.cu create mode 100644 pytorch_pretrained_gans/CIPS/op/upfirdn2d.cpp create mode 100644 pytorch_pretrained_gans/CIPS/op/upfirdn2d.py create mode 100644 pytorch_pretrained_gans/CIPS/op/upfirdn2d_kernel.cu create mode 100644 pytorch_pretrained_gans/StudioGAN/__init__.py create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/ImageNet/BigGAN2048/BigGAN2048.json create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/ImageNet/BigGAN256/BigGAN256.json create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/ImageNet/ContraGAN2048/ContraGAN2048.json create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/ImageNet/ContraGAN256/ContraGAN256.json create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/ImageNet/SAGAN/SAGAN.json create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/ImageNet/SNGAN/SNGAN.json create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/ImageNet/download.sh create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/ACGAN/ACGAN.json create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/BigGAN/BigGAN.json create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/ContraGAN/DiffAugGAN(C).json create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/GGAN/GGAN.json create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/LSGAN/LSGAN.json create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/ProjGAN/ProjGAN.json create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/SAGAN/SAGAN.json create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/SNGAN/SNGAN.json create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/WGAN-WC/WGAN-WC.json create mode 100644 pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/download.sh create mode 100644 pytorch_pretrained_gans/StudioGAN/loader.py create mode 100644 pytorch_pretrained_gans/StudioGAN/main.py create mode 100644 pytorch_pretrained_gans/StudioGAN/models/big_resnet.py create mode 100644 pytorch_pretrained_gans/StudioGAN/models/big_resnet_deep.py create mode 100644 pytorch_pretrained_gans/StudioGAN/models/resnet.py create mode 100644 pytorch_pretrained_gans/StudioGAN/sync_batchnorm/batchnorm.py create mode 100644 pytorch_pretrained_gans/StudioGAN/sync_batchnorm/batchnorm_reimpl.py create mode 100644 pytorch_pretrained_gans/StudioGAN/sync_batchnorm/comm.py create mode 100644 
pytorch_pretrained_gans/StudioGAN/sync_batchnorm/replicate.py create mode 100644 pytorch_pretrained_gans/StudioGAN/sync_batchnorm/unittest.py create mode 100644 pytorch_pretrained_gans/StudioGAN/utils/ada.py create mode 100755 pytorch_pretrained_gans/StudioGAN/utils/ada_op/__init__.py create mode 100755 pytorch_pretrained_gans/StudioGAN/utils/ada_op/fused_act.py create mode 100755 pytorch_pretrained_gans/StudioGAN/utils/ada_op/fused_bias_act.cpp create mode 100755 pytorch_pretrained_gans/StudioGAN/utils/ada_op/fused_bias_act_kernel.cu create mode 100755 pytorch_pretrained_gans/StudioGAN/utils/ada_op/upfirdn2d.cpp create mode 100755 pytorch_pretrained_gans/StudioGAN/utils/ada_op/upfirdn2d.py create mode 100755 pytorch_pretrained_gans/StudioGAN/utils/ada_op/upfirdn2d_kernel.cu create mode 100644 pytorch_pretrained_gans/StudioGAN/utils/biggan_utils.py create mode 100644 pytorch_pretrained_gans/StudioGAN/utils/cr_diff_aug.py create mode 100644 pytorch_pretrained_gans/StudioGAN/utils/diff_aug.py create mode 100644 pytorch_pretrained_gans/StudioGAN/utils/load_checkpoint.py create mode 100644 pytorch_pretrained_gans/StudioGAN/utils/log.py create mode 100644 pytorch_pretrained_gans/StudioGAN/utils/losses.py create mode 100644 pytorch_pretrained_gans/StudioGAN/utils/make_hdf5.py create mode 100644 pytorch_pretrained_gans/StudioGAN/utils/misc.py create mode 100644 pytorch_pretrained_gans/StudioGAN/utils/model_ops.py create mode 100644 pytorch_pretrained_gans/StudioGAN/utils/sample.py create mode 100644 pytorch_pretrained_gans/StudioGAN/worker.py create mode 100644 pytorch_pretrained_gans/__init__.py create mode 100644 pytorch_pretrained_gans/self_conditioned/__init__.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/__init__.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/checkpoints.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/config.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/distributions.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/eval.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/inputs.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/logger.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/metrics/__init__.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/metrics/clustering_metrics.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/metrics/fid.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/metrics/inception_score.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/metrics/tf_is/LICENSE create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/metrics/tf_is/README.md create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/metrics/tf_is/inception_score.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/models/__init__.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/models/blocks.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/models/dcgan_deep.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/models/dcgan_shallow.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/models/resnet2.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/models/resnet2s.py create mode 100644 
pytorch_pretrained_gans/self_conditioned/gan_training/models/resnet3.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/train.py create mode 100644 pytorch_pretrained_gans/self_conditioned/gan_training/utils.py create mode 100644 pytorch_pretrained_gans/stylegan2_ada_pytorch/__init__.py create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/dnnlib/__init__.py create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/dnnlib/util.py create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/__init__.py create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/custom_ops.py create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/misc.py create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/__init__.py create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.cpp create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.cu create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.h create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.py create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/conv2d_gradfix.py create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/conv2d_resample.py create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/fma.py create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/grid_sample_gradfix.py create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.cpp create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.cu create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.h create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.py create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/persistence.py create mode 100755 pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/training_stats.py create mode 100644 setup.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..601cb44 --- /dev/null +++ b/.gitignore @@ -0,0 +1,137 @@ +# Custom +.vscode +wandb +outputs +tmp* +slurm-logs +inversion/gans/BigGAN/weights + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class +.github + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Lightning /research +test_tube_exp/ +tests/tests_tt_dir/ +tests/save_dir +default/ +data/ +test_tube_logs/ +test_tube_data/ +datasets/ +model_weights/ +tests/save_dir +tests/tests_tt_dir/ +processed/ +raw/ + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# IDEs +.idea +.vscode + +# seed project +lightning_logs/ +MNIST +.DS_Store diff --git a/README.md b/README.md new file mode 100644 index 0000000..ef84cf8 --- /dev/null +++ b/README.md @@ -0,0 +1,145 @@ +
+ +## PyTorch Pretrained GANs + + + +
+
+
+### Quick Start
+This repository provides a standardized interface for pretrained GANs in PyTorch. You can install it with:
+```bash
+pip install git+https://github.com/lukemelas/pytorch-pretrained-gans
+```
+It is then easy to generate an image with a GAN:
+```python
+import torch
+from pytorch_pretrained_gans import make_gan
+
+# Sample a class-conditional image from BigGAN with default resolution 256
+G = make_gan(gan_type='biggan') # -> nn.Module
+y = G.sample_class(batch_size=1) # -> torch.Size([1, 1000])
+z = G.sample_latent(batch_size=1) # -> torch.Size([1, 128])
+x = G(z=z, y=y) # -> torch.Size([1, 3, 256, 256])
+```
+
+### Motivation
+Over the past few years, great progress has been made in generative modeling using GANs. As a result, a large body of research has emerged that uses GANs and explores/interprets their latent spaces. I recently worked on a project in which I wanted to apply the same technique to a bunch of different GANs (here's the [paper](https://github.com/lukemelas/unsupervised-image-segmentation) if you're interested). This was quite a pain because all the pretrained GANs out there are in completely different formats. So I decided to standardize them, and here's the result. I hope you find it useful.
+
+### Installation
+Install with `pip` directly from GitHub:
+```
+pip install git+https://github.com/lukemelas/pytorch-pretrained-gans
+```
+
+### Available GANs
+
+The following GANs are available. If you would like to add a new GAN to the repo, please submit a pull request -- I would love to add to this list:
+ - [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch)
+ - [BigBiGAN](https://arxiv.org/abs/1907.02544)
+ - [StyleGAN-2-ADA](https://arxiv.org/abs/1912.04958)
+ - [Self-Conditioned GANs](https://arxiv.org/abs/2006.10728)
+ - [Many GANs from StudioGAN](https://github.com/POSTECH-CVLab/PyTorch-StudioGAN):
+   - [BigGAN](https://github.com/POSTECH-CVLab/PyTorch-StudioGAN) (a reimplementation)
+   - [ContraGAN](https://github.com/POSTECH-CVLab/PyTorch-StudioGAN)
+   - [SAGAN](https://arxiv.org/abs/1805.08318)
+   - [SNGAN](https://arxiv.org/abs/1802.05957)
+
+
+
+### Structure
+
+This repo supports both conditional and unconditional GANs. The standard GAN interface is as follows:
+
+```python
+class GeneratorWrapper(torch.nn.Module):
+    """ A wrapper to put the GAN in a standard format."""
+
+    def __init__(self, G, num_classes=None):
+        super().__init__()
+        self.G : nn.Module = # GAN generator
+        self.dim_z : int = # dimensionality of latent space
+        self.conditional = # True / False
+
+    def forward(self, z, y=None):  # y is for conditional GANs only
+        x = # ... generate image from latent with self.G
+        return x  # returns image
+
+    def sample_latent(self, batch_size, device='cpu'):
+        z = # ... samples latent vector of size self.dim_z
+        return z
+
+    def sample_class(self, batch_size=None, device='cpu'):
+        y = # ... samples class y (for conditional GANs only)
+        return y
+```
+
+Each type of GAN is contained in its own folder and has a `make_GAN_TYPE` function. For example, `make_bigbigan` creates a BigBiGAN with the format of the `GeneratorWrapper` above (a short end-to-end usage sketch is included below).
+
+The weights of all GANs except those in PyTorch-StudioGAN are downloaded automatically. To download the PyTorch-StudioGAN weights, use the `download.sh` scripts in the corresponding folders (see the file structure below).
+
+#### Code Structure
+The structure of the repo is below. Each type of GAN has an `__init__.py` file that defines its `GeneratorWrapper` and its `make_GAN_TYPE` function.
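+
+Putting these pieces together, here is a rough end-to-end sketch of the shared interface (illustrative only; it assumes the relevant pretrained weights have already been downloaded, and `SAGAN` is just one example StudioGAN model name):
+
+```python
+from pytorch_pretrained_gans import make_gan
+
+# The same few calls work for every wrapped GAN, conditional or not.
+for gan_type, kwargs in [('biggan', {}), ('bigbigan', {}), ('studiogan', {'model_name': 'SAGAN'})]:
+    G = make_gan(gan_type=gan_type, **kwargs)  # -> nn.Module in the GeneratorWrapper format
+    z = G.sample_latent(batch_size=2, device='cpu')
+    if G.conditional:
+        y = G.sample_class(batch_size=2)
+        x = G(z=z, y=y)
+    else:
+        x = G(z=z)
+    print(gan_type, tuple(x.shape))
+```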
+
+```bash
+pytorch_pretrained_gans
+├── __init__.py
+├── BigBiGAN
+│   ├── __init__.py
+│   ├── ...
+│   └── weights
+│       └── download.sh  # (use this to download pretrained weights)
+├── BigGAN
+│   ├── __init__.py
+│   ├── ...
+├── StudioGAN
+│   ├── __init__.py
+│   ├── ...
+│   ├── configs
+│   │   ├── ImageNet
+│   │   │   ├── BigGAN2048
+│   │   │   │   └── ...
+│   │   │   └── download.sh  # (use this to download pretrained weights)
+│   │   └── TinyImageNet
+│   │       ├── ACGAN
+│   │       │   └── ACGAN.json
+│   │       ├── ...
+│   │       └── download.sh  # (use this to download pretrained weights)
+├── self_conditioned
+│   ├── __init__.py
+│   └── ...
+└── stylegan2_ada_pytorch
+    ├── __init__.py
+    └── ...
+```
+
+### GAN-Specific Details
+
+Naturally, there are some details that are specific to certain GANs.
+
+**BigGAN:** For BigGAN, you should specify a resolution with `model_name`. For example:
+ * `G = make_gan(gan_type='biggan', model_name='biggan-deep-512')`
+
+**StudioGAN:** For StudioGAN, you should specify a model with `model_name`. For example:
+ * `G = make_gan(gan_type='studiogan', model_name='SAGAN')`
+ * `G = make_gan(gan_type='studiogan', model_name='ContraGAN256')`
+
+**Self-Conditioned GAN:** For the Self-Conditioned GAN, you should specify a model (either `self_conditioned` or `unconditional`) with `model_name`. For example:
+ * `G = make_gan(gan_type='selfconditionedgan', model_name='self_conditioned')`
+
+**StyleGAN 2:**
+ * StyleGAN2's `sample_latent` method returns `w`, not `z`, because this is usually what is desired. `w` has shape `torch.Size([1, 18, 512])`.
+ * StyleGAN2 is currently not implemented on `CPU`.
+
+### Citation
+Please cite the following if you use this repo in a research paper:
+```bibtex
+@inproceedings{melaskyriazi2021finding,
+  author = {Melas-Kyriazi, Luke and Manrai, Arjun},
+  title = {Finding an Unsupervised Image Segmenter in each of your Deep Generative Models},
+  booktitle = arxiv,
+  year = {2021}
+}
+```
diff --git a/examples/generate.py b/examples/generate.py
new file mode 100644
index 0000000..cb154d9
--- /dev/null
+++ b/examples/generate.py
@@ -0,0 +1,34 @@
+"""
+Example usage:
+python generate.py
+"""
+import torch
+from pytorch_pretrained_gans import make_gan
+
+# BigGAN (class-conditional)
+G = make_gan(gan_type='biggan', model_name='biggan-deep-512') # -> nn.Module
+y = G.sample_class(batch_size=1) # -> torch.Size([1, 1000])
+z = G.sample_latent(batch_size=1) # -> torch.Size([1, 128])
+x = G(z=z, y=y) # -> torch.Size([1, 3, 512, 512])
+
+# BigBiGAN (unconditional)
+G = make_gan(gan_type='bigbigan') # -> nn.Module
+z = G.sample_latent(batch_size=1) # -> torch.Size([1, 120])
+x = G(z=z) # -> torch.Size([1, 3, 128, 128])
+
+# Self-Conditioned GAN (conditional)
+G = make_gan(gan_type='selfconditionedgan', model_name='self_conditioned') # -> nn.Module
+y = G.sample_class(batch_size=1) # -> torch.Size([1, 1000])
+z = G.sample_latent(batch_size=1) # -> torch.Size([1, 128])
+x = G(z=z, y=y) # -> torch.Size([1, 3, 128, 128])
+
+# StudioGAN (class-conditional)
+G = make_gan(gan_type='studiogan', model_name='SAGAN') # -> nn.Module
+y = G.sample_class(batch_size=1) # -> torch.Size([1, 1000])
+z = G.sample_latent(batch_size=1) # -> torch.Size([1, 128])
+x = G(z=z, y=y) # -> torch.Size([1, 3, 128, 128])
+
+# StyleGAN2 (unconditional)
+G = make_gan(gan_type='stylegan2') # -> nn.Module
+z = G.sample_latent(batch_size=1) # -> torch.Size([1, 18, 512])
+x = G(z=z) # -> torch.Size([1, 3, 128, 128])
diff --git a/pytorch_pretrained_gans/.gitignore
b/pytorch_pretrained_gans/.gitignore new file mode 100644 index 0000000..148c45e --- /dev/null +++ b/pytorch_pretrained_gans/.gitignore @@ -0,0 +1,4 @@ +*.pt +*.pth +*.pkl +*.npy diff --git a/pytorch_pretrained_gans/BigBiGAN/__init__.py b/pytorch_pretrained_gans/BigBiGAN/__init__.py new file mode 100644 index 0000000..aa92c1e --- /dev/null +++ b/pytorch_pretrained_gans/BigBiGAN/__init__.py @@ -0,0 +1 @@ +from .gan_load import make_big_gan diff --git a/pytorch_pretrained_gans/BigBiGAN/gan_load.py b/pytorch_pretrained_gans/BigBiGAN/gan_load.py new file mode 100644 index 0000000..096a1a0 --- /dev/null +++ b/pytorch_pretrained_gans/BigBiGAN/gan_load.py @@ -0,0 +1,65 @@ +from pathlib import Path +import torch +from torch import nn +from .model import BigGAN +from .gan_with_shift import gan_with_shift + +DEFAULT_WEIGHTS_ROOT = Path(__file__).parent / 'weights/BigBiGAN_x1.pth' + + +class GeneratorWrapper(nn.Module): + """ A wrapper to put the GAN in a standard format -- here, a modified + version of the old UnconditionalBigGAN class """ + + def __init__(self, big_gan): + super().__init__() + self.big_gan = big_gan + self.dim_z = self.big_gan.dim_z + self.conditional = False + + def forward(self, z): + classes = torch.zeros(z.shape[0], dtype=torch.int64, device=z.device) + return self.big_gan(z, self.big_gan.shared(classes)) + + def sample_latent(self, batch_size, device): + z = torch.randn((batch_size, self.dim_z), device=device) + return z + + +def make_biggan_config(resolution): + attn_dict = {128: '64', 256: '128', 512: '64'} + dim_z_dict = {128: 120, 256: 140, 512: 128} + config = { + 'G_param': 'SN', 'D_param': 'SN', + 'G_ch': 96, 'D_ch': 96, + 'D_wide': True, 'G_shared': True, + 'shared_dim': 128, 'dim_z': dim_z_dict[resolution], + 'hier': True, 'cross_replica': False, + 'mybn': False, 'G_activation': nn.ReLU(inplace=True), + 'G_attn': attn_dict[resolution], + 'norm_style': 'bn', + 'G_init': 'ortho', 'skip_init': True, 'no_optim': True, + 'G_fp16': False, 'G_mixed_precision': False, + 'accumulate_stats': False, 'num_standing_accumulations': 16, + 'G_eval_mode': True, + 'BN_eps': 1e-04, 'SN_eps': 1e-04, + 'num_G_SVs': 1, 'num_G_SV_itrs': 1, 'resolution': resolution, + 'n_classes': 1000} + return config + + +@gan_with_shift +def make_big_gan(weights_root, resolution): + config = make_biggan_config(resolution) + G = BigGAN.Generator(**config) + G.load_state_dict(torch.load(weights_root, map_location=torch.device('cpu')), strict=False) + return GeneratorWrapper(G) + + +def make_bigbigan(model_name='bigbigan-128', weights_root=DEFAULT_WEIGHTS_ROOT): + assert model_name == 'bigbigan-128' + config = make_biggan_config(resolution=128) + G = BigGAN.Generator(**config) + G.load_state_dict(torch.load(weights_root, map_location=torch.device('cpu')), strict=False) + G = GeneratorWrapper(G) + return G diff --git a/pytorch_pretrained_gans/BigBiGAN/gan_with_shift.py b/pytorch_pretrained_gans/BigBiGAN/gan_with_shift.py new file mode 100644 index 0000000..1a89bb2 --- /dev/null +++ b/pytorch_pretrained_gans/BigBiGAN/gan_with_shift.py @@ -0,0 +1,20 @@ +import types +from functools import wraps + + +def add_forward_with_shift(generator): + def gen_shifted(self, z, shift, *args, **kwargs): + return self.forward(z + shift, *args, **kwargs) + + generator.gen_shifted = types.MethodType(gen_shifted, generator) + generator.dim_shift = generator.dim_z + + +def gan_with_shift(gan_factory): + @wraps(gan_factory) + def wrapper(*args, **kwargs): + gan = gan_factory(*args, **kwargs) + add_forward_with_shift(gan) + return 
gan + + return wrapper diff --git a/pytorch_pretrained_gans/BigBiGAN/model/BigGAN.py b/pytorch_pretrained_gans/BigBiGAN/model/BigGAN.py new file mode 100644 index 0000000..0fa4b3d --- /dev/null +++ b/pytorch_pretrained_gans/BigBiGAN/model/BigGAN.py @@ -0,0 +1,454 @@ +import functools + +import torch +import torch.nn as nn +from torch.nn import init +import torch.optim as optim +import torch.nn.functional as F + +from . import layers + + +# Architectures for G +# Attention is passed in in the format '32_64' to mean applying an attention +# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64. +def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'): + arch = {} + arch[512] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2, 1]], + 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1, 1]], + 'upsample': [True] * 7, + 'resolution': [8, 16, 32, 64, 128, 256, 512], + 'attention': {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3, 10)}} + arch[256] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2]], + 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1]], + 'upsample': [True] * 6, + 'resolution': [8, 16, 32, 64, 128, 256], + 'attention': {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3, 9)}} + arch[128] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]], + 'out_channels': [ch * item for item in [16, 8, 4, 2, 1]], + 'upsample': [True] * 5, + 'resolution': [8, 16, 32, 64, 128], + 'attention': {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3, 8)}} + arch[64] = {'in_channels': [ch * item for item in [16, 16, 8, 4]], + 'out_channels': [ch * item for item in [16, 8, 4, 2]], + 'upsample': [True] * 4, + 'resolution': [8, 16, 32, 64], + 'attention': {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3, 7)}} + arch[32] = {'in_channels': [ch * item for item in [4, 4, 4]], + 'out_channels': [ch * item for item in [4, 4, 4]], + 'upsample': [True] * 3, + 'resolution': [8, 16, 32], + 'attention': {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3, 6)}} + + return arch + + +class Generator(nn.Module): + def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=128, + G_kernel_size=3, G_attn='64', n_classes=1000, + num_G_SVs=1, num_G_SV_itrs=1, + G_shared=True, shared_dim=0, hier=False, + cross_replica=False, mybn=False, + G_activation=nn.ReLU(inplace=False), + G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8, + BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False, + G_init='ortho', skip_init=False, no_optim=False, + G_param='SN', norm_style='bn', + **kwargs): + super(Generator, self).__init__() + # Channel width mulitplier + self.ch = G_ch + # Dimensionality of the latent space + self.dim_z = dim_z + # The initial spatial dimensions + self.bottom_width = bottom_width + # Resolution of the output + self.resolution = resolution + # Kernel size? + self.kernel_size = G_kernel_size + # Attention? + self.attention = G_attn + # number of classes, for use in categorical conditional generation + self.n_classes = n_classes + # Use shared embeddings? + self.G_shared = G_shared + # Dimensionality of the shared embedding? Unused if not using G_shared + self.shared_dim = shared_dim if shared_dim > 0 else dim_z + # Hierarchical latent space? + self.hier = hier + # Cross replica batchnorm? + self.cross_replica = cross_replica + # Use my batchnorm? 
+ self.mybn = mybn + # nonlinearity for residual blocks + self.activation = G_activation + # Initialization style + self.init = G_init + # Parameterization style + self.G_param = G_param + # Normalization style + self.norm_style = norm_style + # Epsilon for BatchNorm? + self.BN_eps = BN_eps + # Epsilon for Spectral Norm? + self.SN_eps = SN_eps + # fp16? + self.fp16 = G_fp16 + # Architecture dict + self.arch = G_arch(self.ch, self.attention)[resolution] + + # If using hierarchical latents, adjust z + if self.hier: + # Number of places z slots into + self.num_slots = len(self.arch['in_channels']) + 1 + self.z_chunk_size = (self.dim_z // self.num_slots) + # Recalculate latent dimensionality for even splitting into chunks + self.dim_z = self.z_chunk_size * self.num_slots + else: + self.num_slots = 1 + self.z_chunk_size = 0 + + # Which convs, batchnorms, and linear layers to use + if self.G_param == 'SN': + self.which_conv = functools.partial(layers.SNConv2d, + kernel_size=3, padding=1, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + self.which_linear = functools.partial(layers.SNLinear, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + else: + self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1) + self.which_linear = nn.Linear + + # We use a non-spectral-normed embedding here regardless; + # For some reason applying SN to G's embedding seems to randomly cripple G + self.which_embedding = nn.Embedding + bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared + else self.which_embedding) + self.which_bn = functools.partial(layers.ccbn, + which_linear=bn_linear, + cross_replica=self.cross_replica, + mybn=self.mybn, + input_size=(self.shared_dim + self.z_chunk_size if self.G_shared + else self.n_classes), + norm_style=self.norm_style, + eps=self.BN_eps) + + # Prepare model + # If not using shared embeddings, self.shared is just a passthrough + self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared + else layers.identity()) + # First linear layer + self.linear = self.which_linear(self.dim_z // self.num_slots, + self.arch['in_channels'][0] * (self.bottom_width ** 2)) + + # self.blocks is a doubly-nested list of modules, the outer loop intended + # to be over blocks at a given resolution (resblocks and/or self-attention) + # while the inner loop is over a given block + self.blocks = [] + for index in range(len(self.arch['out_channels'])): + self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index], + out_channels=self.arch['out_channels'][index], + which_conv=self.which_conv, + which_bn=self.which_bn, + activation=self.activation, + upsample=(functools.partial(F.interpolate, scale_factor=2) + if self.arch['upsample'][index] else None))]] + + # If attention on this block, attach it to the end + if self.arch['attention'][self.arch['resolution'][index]]: + # print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index]) + self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)] + + # Turn self.blocks into a ModuleList so that it's all properly registered. + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + + # output layer: batchnorm-relu-conv. 
+ # Consider using a non-spectral conv here + self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1], + cross_replica=self.cross_replica, + mybn=self.mybn), + self.activation, + self.which_conv(self.arch['out_channels'][-1], 3)) + + # Initialize weights. Optionally skip init for testing. + if not skip_init: + self.init_weights() + + # Set up optimizer + # If this is an EMA copy, no need for an optim, so just return now + if no_optim: + return + self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps + if G_mixed_precision: + print('Using fp16 adam in G...') + import utils + self.optim = utils.Adam16(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, + eps=self.adam_eps) + else: + self.optim = optim.Adam(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, + eps=self.adam_eps) + + # LR scheduling, left here for forward compatibility + # self.lr_sched = {'itr' : 0}# if self.progressive else {} + # self.j = 0 + + # Initialize + def init_weights(self): + self.param_count = 0 + for module in self.modules(): + if (isinstance(module, nn.Conv2d) + or isinstance(module, nn.Linear) + or isinstance(module, nn.Embedding)): + if self.init == 'ortho': + init.orthogonal_(module.weight) + elif self.init == 'N02': + init.normal_(module.weight, 0, 0.02) + elif self.init in ['glorot', 'xavier']: + init.xavier_uniform_(module.weight) + else: + print('Init style not recognized...') + self.param_count += sum([p.data.nelement() for p in module.parameters()]) + print('Param count for G''s initialized parameters: %d' % self.param_count) + + # Note on this forward function: we pass in a y vector which has + # already been passed through G.shared to enable easy class-wise + # interpolation later. If we passed in the one-hot and then ran it through + # G.shared in this forward function, it would be harder to handle. 
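+    # For example, the wrappers in this repo call it as `G(z, G.shared(y))`, where
+    # `y` is a LongTensor of class indices (see GeneratorWrapper in gan_load.py and
+    # the G_D module at the bottom of this file).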
+ def forward(self, z, y, h_shift=None, h_replace=False): + # If hierarchical, concatenate zs and ys + if self.hier: + zs = torch.split(z, self.z_chunk_size, 1) + z = zs[0] + ys = [torch.cat([y, item], 1) for item in zs[1:]] + else: + ys = [y] * len(self.blocks) + + # First linear layer + h = self.linear(z) + if h_shift is not None: + if h_replace: + h = h_shift + else: + h = h + h_shift + # Reshape + h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width) + + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + # Second inner loop in case block has multiple layers + for block in blocklist: + h = block(h, ys[index]) + + # Apply batchnorm-relu-conv-tanh at output + return torch.tanh(self.output_layer(h)) + + +# Discriminator architecture, same paradigm as G's above +def D_arch(ch=64, attention='64', ksize='333333', dilation='111111'): + arch = {} + arch[256] = {'in_channels': [3] + [ch * item for item in [1, 2, 4, 8, 8, 16]], + 'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]], + 'downsample': [True] * 6 + [False], + 'resolution': [128, 64, 32, 16, 8, 4, 4], + 'attention': {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2, 8)}} + arch[128] = {'in_channels': [3] + [ch * item for item in [1, 2, 4, 8, 16]], + 'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]], + 'downsample': [True] * 5 + [False], + 'resolution': [64, 32, 16, 8, 4, 4], + 'attention': {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2, 8)}} + arch[64] = {'in_channels': [3] + [ch * item for item in [1, 2, 4, 8]], + 'out_channels': [item * ch for item in [1, 2, 4, 8, 16]], + 'downsample': [True] * 4 + [False], + 'resolution': [32, 16, 8, 4, 4], + 'attention': {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2, 7)}} + arch[32] = {'in_channels': [3] + [item * ch for item in [4, 4, 4]], + 'out_channels': [item * ch for item in [4, 4, 4, 4]], + 'downsample': [True, True, False, False], + 'resolution': [16, 16, 16, 16], + 'attention': {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2, 6)}} + return arch + + +class Discriminator(nn.Module): + + def __init__(self, D_ch=64, D_wide=True, resolution=128, + D_kernel_size=3, D_attn='64', n_classes=1000, + num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False), + D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8, + SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False, + D_init='ortho', skip_init=False, D_param='SN', **kwargs): + super(Discriminator, self).__init__() + # Width multiplier + self.ch = D_ch + # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN? + self.D_wide = D_wide + # Resolution + self.resolution = resolution + # Kernel size + self.kernel_size = D_kernel_size + # Attention? + self.attention = D_attn + # Number of classes + self.n_classes = n_classes + # Activation + self.activation = D_activation + # Initialization style + self.init = D_init + # Parameterization style + self.D_param = D_param + # Epsilon for Spectral Norm? + self.SN_eps = SN_eps + # Fp16? 
+ self.fp16 = D_fp16 + # Architecture + self.arch = D_arch(self.ch, self.attention)[resolution] + + # Which convs, batchnorms, and linear layers to use + # No option to turn off SN in D right now + if self.D_param == 'SN': + self.which_conv = functools.partial(layers.SNConv2d, + kernel_size=3, padding=1, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + self.which_linear = functools.partial(layers.SNLinear, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + self.which_embedding = functools.partial(layers.SNEmbedding, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + # Prepare model + # self.blocks is a doubly-nested list of modules, the outer loop intended + # to be over blocks at a given resolution (resblocks and/or self-attention) + self.blocks = [] + for index in range(len(self.arch['out_channels'])): + self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index], + out_channels=self.arch['out_channels'][index], + which_conv=self.which_conv, + wide=self.D_wide, + activation=self.activation, + preactivation=(index > 0), + downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]] + # If attention on this block, attach it to the end + if self.arch['attention'][self.arch['resolution'][index]]: + print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index]) + self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], + self.which_conv)] + # Turn self.blocks into a ModuleList so that it's all properly registered. + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + # Linear output layer. The output dimension is typically 1, but may be + # larger if we're e.g. turning this into a VAE with an inference output + self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim) + # Embedding for projection discrimination + self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1]) + + # Initialize weights + if not skip_init: + self.init_weights() + + # Set up optimizer + self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps + if D_mixed_precision: + print('Using fp16 adam in D...') + import utils + self.optim = utils.Adam16(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps) + else: + self.optim = optim.Adam(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps) + # LR scheduling, left here for forward compatibility + # self.lr_sched = {'itr' : 0}# if self.progressive else {} + # self.j = 0 + + # Initialize + def init_weights(self): + self.param_count = 0 + for module in self.modules(): + if (isinstance(module, nn.Conv2d) + or isinstance(module, nn.Linear) + or isinstance(module, nn.Embedding)): + if self.init == 'ortho': + init.orthogonal_(module.weight) + elif self.init == 'N02': + init.normal_(module.weight, 0, 0.02) + elif self.init in ['glorot', 'xavier']: + init.xavier_uniform_(module.weight) + else: + print('Init style not recognized...') + self.param_count += sum([p.data.nelement() for p in module.parameters()]) + print('Param count for D''s initialized parameters: %d' % self.param_count) + + def forward(self, x, y=None): + # Stick x into h for cleaner for loops without flow control + h = x + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + h = block(h) + # Apply global sum pooling as in SN-GAN + h = torch.sum(self.activation(h), [2, 3]) + # Get initial 
class-unconditional output + out = self.linear(h) + # Get projection of final featureset onto class vectors and add to evidence + out = out + torch.sum(self.embed(y) * h, 1, keepdim=True) + return out + +# Parallelized G_D to minimize cross-gpu communication +# Without this, Generator outputs would get all-gathered and then rebroadcast. + + +class G_D(nn.Module): + def __init__(self, G, D): + super(G_D, self).__init__() + self.G = G + self.D = D + + def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False, + split_D=False): + # If training G, enable grad tape + with torch.set_grad_enabled(train_G): + # Get Generator output given noise + G_z = self.G(z, self.G.shared(gy)) + # Cast as necessary + if self.G.fp16 and not self.D.fp16: + G_z = G_z.float() + if self.D.fp16 and not self.G.fp16: + G_z = G_z.half() + # Split_D means to run D once with real data and once with fake, + # rather than concatenating along the batch dimension. + if split_D: + D_fake = self.D(G_z, gy) + if x is not None: + D_real = self.D(x, dy) + return D_fake, D_real + else: + if return_G_z: + return D_fake, G_z + else: + return D_fake + # If real data is provided, concatenate it with the Generator's output + # along the batch dimension for improved efficiency. + else: + D_input = torch.cat([G_z, x], 0) if x is not None else G_z + D_class = torch.cat([gy, dy], 0) if dy is not None else gy + # Get Discriminator output + D_out = self.D(D_input, D_class) + if x is not None: + return torch.split(D_out, [G_z.shape[0], x.shape[0]]) # D_fake, D_real + else: + if return_G_z: + return D_out, G_z + else: + return D_out diff --git a/pytorch_pretrained_gans/BigBiGAN/model/__init__.py b/pytorch_pretrained_gans/BigBiGAN/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pytorch_pretrained_gans/BigBiGAN/model/layers.py b/pytorch_pretrained_gans/BigBiGAN/model/layers.py new file mode 100644 index 0000000..7632bf1 --- /dev/null +++ b/pytorch_pretrained_gans/BigBiGAN/model/layers.py @@ -0,0 +1,461 @@ +''' Layers + This file contains various layers for the BigGAN models. +''' +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter as P + +from .sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d + + +# Projection of x onto y +def proj(x, y): + return torch.mm(y, x.t()) * y / torch.mm(y, y.t()) + + +# Orthogonalize x wrt list of vectors ys +def gram_schmidt(x, ys): + for y in ys: + x = x - proj(x, y) + return x + + +# Apply num_itrs steps of the power method to estimate top N singular values. 
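+# Each step computes v = normalize(u @ W) and u = normalize(v @ W^T), applying
+# Gram-Schmidt against previously found vectors so each (u, v) pair tracks a
+# distinct singular direction, and estimates the singular value as v @ W^T @ u^T.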
+def power_iteration(W, u_, update=True, eps=1e-12): + # Lists holding singular vectors and values + us, vs, svs = [], [], [] + for i, u in enumerate(u_): + # Run one step of the power iteration + with torch.no_grad(): + v = torch.matmul(u, W) + # Run Gram-Schmidt to subtract components of all other singular vectors + v = F.normalize(gram_schmidt(v, vs), eps=eps) + # Add to the list + vs += [v] + # Update the other singular vector + u = torch.matmul(v, W.t()) + # Run Gram-Schmidt to subtract components of all other singular vectors + u = F.normalize(gram_schmidt(u, us), eps=eps) + # Add to the list + us += [u] + if update: + u_[i][:] = u + # Compute this singular value and add it to the list + svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))] + #svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)] + return svs, us, vs + + +# Convenience passthrough function +class identity(nn.Module): + def forward(self, input): + return input + + +# Spectral normalization base class +class SN(object): + def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12): + # Number of power iterations per step + self.num_itrs = num_itrs + # Number of singular values + self.num_svs = num_svs + # Transposed? + self.transpose = transpose + # Epsilon value for avoiding divide-by-0 + self.eps = eps + # Register a singular vector for each sv + for i in range(self.num_svs): + self.register_buffer('u%d' % i, torch.randn(1, num_outputs)) + self.register_buffer('sv%d' % i, torch.ones(1)) + + # Singular vectors (u side) + @property + def u(self): + return [getattr(self, 'u%d' % i) for i in range(self.num_svs)] + + # Singular values; + # note that these buffers are just for logging and are not used in training. + @property + def sv(self): + return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)] + + # Compute the spectrally-normalized weight + def W_(self): + W_mat = self.weight.view(self.weight.size(0), -1) + if self.transpose: + W_mat = W_mat.t() + # Apply num_itrs power iterations + for _ in range(self.num_itrs): + svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps) + # Update the svs + if self.training: + with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks! 
+ for i, sv in enumerate(svs): + self.sv[i][:] = sv + return self.weight / svs[0] + + +# 2D Conv layer with spectral norm +class SNConv2d(nn.Conv2d, SN): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps) + + def forward(self, x): + return F.conv2d(x, self.W_(), self.bias, self.stride, + self.padding, self.dilation, self.groups) + + +# Linear layer with spectral norm +class SNLinear(nn.Linear, SN): + def __init__(self, in_features, out_features, bias=True, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Linear.__init__(self, in_features, out_features, bias) + SN.__init__(self, num_svs, num_itrs, out_features, eps=eps) + + def forward(self, x): + return F.linear(x, self.W_(), self.bias) + + +# Embedding layer with spectral norm +# We use num_embeddings as the dim instead of embedding_dim here +# for convenience sake +class SNEmbedding(nn.Embedding, SN): + def __init__(self, num_embeddings, embedding_dim, padding_idx=None, + max_norm=None, norm_type=2, scale_grad_by_freq=False, + sparse=False, _weight=None, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx, + max_norm, norm_type, scale_grad_by_freq, + sparse, _weight) + SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps) + + def forward(self, x): + return F.embedding(x, self.W_()) + + +# A non-local block as used in SA-GAN +# Note that the implementation as described in the paper is largely incorrect; +# refer to the released code for the actual implementation. +class Attention(nn.Module): + def __init__(self, ch, which_conv=SNConv2d, name='attention'): + super(Attention, self).__init__() + # Channel multiplier + self.ch = ch + self.which_conv = which_conv + self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) + self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) + self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False) + self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False) + # Learnable gain parameter + self.gamma = P(torch.tensor(0.), requires_grad=True) + + def forward(self, x, y=None): + # Apply convs + theta = self.theta(x) + phi = F.max_pool2d(self.phi(x), [2, 2]) + g = F.max_pool2d(self.g(x), [2, 2]) + # Perform reshapes + theta = theta.view(-1, self. ch // 8, x.shape[2] * x.shape[3]) + phi = phi.view(-1, self. ch // 8, x.shape[2] * x.shape[3] // 4) + g = g.view(-1, self. 
ch // 2, x.shape[2] * x.shape[3] // 4) + # Matmul and softmax to get attention maps + beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1) + # Attention map times g path + o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3])) + return self.gamma * o + x + + +# Fused batchnorm op +def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5): + # Apply scale and shift--if gain and bias are provided, fuse them here + # Prepare scale + scale = torch.rsqrt(var + eps) + # If a gain is provided, use it + if gain is not None: + scale = scale * gain + # Prepare shift + shift = mean * scale + # If bias is provided, use it + if bias is not None: + shift = shift - bias + return x * scale - shift + # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way. + + +# Manual BN +# Calculate means and variances using mean-of-squares minus mean-squared +def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5): + # Cast x to float32 if necessary + float_x = x.float() + # Calculate expected value of x (m) and expected value of x**2 (m2) + # Mean of x + m = torch.mean(float_x, [0, 2, 3], keepdim=True) + # Mean of x squared + m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True) + # Calculate variance as mean of squared minus mean squared. + var = (m2 - m ** 2) + # Cast back to float 16 if necessary + var = var.type(x.type()) + m = m.type(x.type()) + # Return mean and variance for updating stored mean/var if requested + if return_mean_var: + return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze() + else: + return fused_bn(x, m, var, gain, bias, eps) + + +# My batchnorm, supports standing stats +class myBN(nn.Module): + def __init__(self, num_channels, eps=1e-5, momentum=0.1): + super(myBN, self).__init__() + # momentum for updating running stats + self.momentum = momentum + # epsilon to avoid dividing by 0 + self.eps = eps + # Momentum + self.momentum = momentum + # Register buffers + self.register_buffer('stored_mean', torch.zeros(num_channels)) + self.register_buffer('stored_var', torch.ones(num_channels)) + self.register_buffer('accumulation_counter', torch.zeros(1)) + # Accumulate running means and vars + self.accumulate_standing = False + + # reset standing stats + def reset_stats(self): + self.stored_mean[:] = 0 + self.stored_var[:] = 0 + self.accumulation_counter[:] = 0 + + def forward(self, x, gain, bias): + if self.training: + out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps) + # If accumulating standing stats, increment them + if self.accumulate_standing: + self.stored_mean[:] = self.stored_mean + mean.data + self.stored_var[:] = self.stored_var + var.data + self.accumulation_counter += 1.0 + # If not accumulating standing stats, take running averages + else: + self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum + self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum + return out + # If not in training mode, use the stored statistics + else: + mean = self.stored_mean.view(1, -1, 1, 1) + var = self.stored_var.view(1, -1, 1, 1) + # If using standing stats, divide them by the accumulation counter + if self.accumulate_standing: + mean = mean / self.accumulation_counter + var = var / self.accumulation_counter + return fused_bn(x, mean, var, gain, bias, self.eps) + + +# Simple function to handle groupnorm norm stylization +def groupnorm(x, norm_style): + # If number of channels specified in norm_style: + if 'ch' in norm_style: + 
ch = int(norm_style.split('_')[-1]) + groups = max(int(x.shape[1]) // ch, 1) + # If number of groups specified in norm style + elif 'grp' in norm_style: + groups = int(norm_style.split('_')[-1]) + # If neither, default to groups = 16 + else: + groups = 16 + return F.group_norm(x, groups) + + +# Class-conditional bn +# output size is the number of channels, input size is for the linear layers +# Andy's Note: this class feels messy but I'm not really sure how to clean it up +# Suggestions welcome! (By which I mean, refactor this and make a pull request +# if you want to make this more readable/usable). +class ccbn(nn.Module): + def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1, + cross_replica=False, mybn=False, norm_style='bn',): + super(ccbn, self).__init__() + self.output_size, self.input_size = output_size, input_size + # Prepare gain and bias layers + self.gain = which_linear(input_size, output_size) + self.bias = which_linear(input_size, output_size) + # epsilon to avoid dividing by 0 + self.eps = eps + # Momentum + self.momentum = momentum + # Use cross-replica batchnorm? + self.cross_replica = cross_replica + # Use my batchnorm? + self.mybn = mybn + # Norm style? + self.norm_style = norm_style + + if self.cross_replica: + self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False) + elif self.mybn: + self.bn = myBN(output_size, self.eps, self.momentum) + elif self.norm_style in ['bn', 'in']: + self.register_buffer('stored_mean', torch.zeros(output_size)) + self.register_buffer('stored_var', torch.ones(output_size)) + + def forward(self, x, y): + # Calculate class-conditional gains and biases + gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1) + bias = self.bias(y).view(y.size(0), -1, 1, 1) + # If using my batchnorm + if self.mybn or self.cross_replica: + return self.bn(x, gain=gain, bias=bias) + # else: + else: + if self.norm_style == 'bn': + out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None, + self.training, 0.1, self.eps) + elif self.norm_style == 'in': + out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None, + self.training, 0.1, self.eps) + elif self.norm_style == 'gn': + out = groupnorm(x, self.normstyle) + elif self.norm_style == 'nonorm': + out = x + return out * gain + bias + + def extra_repr(self): + s = 'out: {output_size}, in: {input_size},' + s += ' cross_replica={cross_replica}' + return s.format(**self.__dict__) + + +# Normal, non-class-conditional BN +class bn(nn.Module): + def __init__(self, output_size, eps=1e-5, momentum=0.1, + cross_replica=False, mybn=False): + super(bn, self).__init__() + self.output_size = output_size + # Prepare gain and bias layers + self.gain = P(torch.ones(output_size), requires_grad=True) + self.bias = P(torch.zeros(output_size), requires_grad=True) + # epsilon to avoid dividing by 0 + self.eps = eps + # Momentum + self.momentum = momentum + # Use cross-replica batchnorm? + self.cross_replica = cross_replica + # Use my batchnorm? 
+ self.mybn = mybn + + if self.cross_replica: + self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False) + elif mybn: + self.bn = myBN(output_size, self.eps, self.momentum) + # Register buffers if neither of the above + else: + self.register_buffer('stored_mean', torch.zeros(output_size)) + self.register_buffer('stored_var', torch.ones(output_size)) + + def forward(self, x, y=None): + if self.cross_replica or self.mybn: + gain = self.gain.view(1, -1, 1, 1) + bias = self.bias.view(1, -1, 1, 1) + return self.bn(x, gain=gain, bias=bias) + else: + return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain, + self.bias, self.training, self.momentum, self.eps) + + +# Generator blocks +# Note that this class assumes the kernel size and padding (and any other +# settings) have been selected in the main generator module and passed in +# through the which_conv arg. Similar rules apply with which_bn (the input +# size [which is actually the number of channels of the conditional info] must +# be preselected) +class GBlock(nn.Module): + def __init__(self, in_channels, out_channels, + which_conv=nn.Conv2d, which_bn=bn, activation=None, + upsample=None): + super(GBlock, self).__init__() + + self.in_channels, self.out_channels = in_channels, out_channels + self.which_conv, self.which_bn = which_conv, which_bn + self.activation = activation + self.upsample = upsample + # Conv layers + self.conv1 = self.which_conv(self.in_channels, self.out_channels) + self.conv2 = self.which_conv(self.out_channels, self.out_channels) + self.learnable_sc = in_channels != out_channels or upsample + if self.learnable_sc: + self.conv_sc = self.which_conv(in_channels, out_channels, + kernel_size=1, padding=0) + # Batchnorm layers + self.bn1 = self.which_bn(in_channels) + self.bn2 = self.which_bn(out_channels) + # upsample layers + self.upsample = upsample + + def forward(self, x, y): + h = self.activation(self.bn1(x, y)) + if self.upsample: + h = self.upsample(h) + x = self.upsample(x) + h = self.conv1(h) + h = self.activation(self.bn2(h, y)) + h = self.conv2(h) + if self.learnable_sc: + x = self.conv_sc(x) + return h + x + + +# Residual block for the discriminator +class DBlock(nn.Module): + def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True, + preactivation=False, activation=None, downsample=None,): + super(DBlock, self).__init__() + self.in_channels, self.out_channels = in_channels, out_channels + # If using wide D (as in SA-GAN and BigGAN), change the channel pattern + self.hidden_channels = self.out_channels if wide else self.in_channels + self.which_conv = which_conv + self.preactivation = preactivation + self.activation = activation + self.downsample = downsample + + # Conv layers + self.conv1 = self.which_conv(self.in_channels, self.hidden_channels) + self.conv2 = self.which_conv(self.hidden_channels, self.out_channels) + self.learnable_sc = True if (in_channels != out_channels) or downsample else False + if self.learnable_sc: + self.conv_sc = self.which_conv(in_channels, out_channels, + kernel_size=1, padding=0) + + def shortcut(self, x): + if self.preactivation: + if self.learnable_sc: + x = self.conv_sc(x) + if self.downsample: + x = self.downsample(x) + else: + if self.downsample: + x = self.downsample(x) + if self.learnable_sc: + x = self.conv_sc(x) + return x + + def forward(self, x): + if self.preactivation: + # h = self.activation(x) # NOT TODAY SATAN + # Andy's note: This line *must* be an out-of-place ReLU or it + # will negatively affect the 
shortcut connection. + h = F.relu(x) + else: + h = x + h = self.conv1(h) + h = self.conv2(self.activation(h)) + if self.downsample: + h = self.downsample(h) + + return h + self.shortcut(x) + +# dogball diff --git a/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/__init__.py b/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/__init__.py new file mode 100644 index 0000000..bc8709d --- /dev/null +++ b/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/batchnorm.py b/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/batchnorm.py new file mode 100644 index 0000000..5453729 --- /dev/null +++ b/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/batchnorm.py @@ -0,0 +1,349 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +from .comm import SyncMaster + +__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dementions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) +# _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'ssum', 'sum_size']) + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input, gain=None, bias=None): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + out = F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + if gain is not None: + out = out + gain + if bias is not None: + out = out + bias + return out + + # Resize the input to (B, C, -1). + input_shape = input.size() + # print(input_shape) + input = input.view(input.size(0), input.size(1), -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + # Reduce-and-broadcast the statistics. 
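+        # The copy with _parallel_id == 0 acts as the master: it gathers every
+        # device's (sum, ssum, count) message, computes the global mean and
+        # inv_std, and broadcasts them back; the other copies block on their
+        # SlavePipe until the reduced statistics arrive.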
+ # print('it begins') + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + # if self._parallel_id == 0: + # # print('here') + # sum, ssum, num = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + # else: + # # print('there') + # sum, ssum, num = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # print('how2') + # num = sum_size + # print('Sum: %f, ssum: %f, sumsize: %f, insum: %f' %(float(sum.sum().cpu()), float(ssum.sum().cpu()), float(sum_size), float(input_sum.sum().cpu()))) + # Fix the graph + # sum = (sum.detach() - input_sum.detach()) + input_sum + # ssum = (ssum.detach() - input_ssum.detach()) + input_ssum + + # mean = sum / num + # var = ssum / num - mean ** 2 + # # var = (ssum - mean * sum) / num + # inv_std = torch.rsqrt(var + self.eps) + + # Compute the output. + if gain is not None: + # print('gaining') + # scale = _unsqueeze_ft(inv_std) * gain.squeeze(-1) + # shift = _unsqueeze_ft(mean) * scale - bias.squeeze(-1) + # output = input * scale - shift + output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1) + elif self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + # print('a') + # print(type(sum_), type(ssum), type(sum_size), sum_.shape, ssum.shape, sum_size) + # broadcasted = Broadcast.apply(target_gpus, sum_, ssum, torch.tensor(sum_size).float().to(sum_.device)) + # print('b') + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + # outputs.append((rec[0], _MasterMessage(*broadcasted[i*3:i*3+3]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 
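+        # mean = E[x] and sumvar = ssum - sum * mean = sum((x - mean)^2).  The biased
+        # variance (sumvar / size) is used for the returned inv_std, while the
+        # unbiased variance (sumvar / (size - 1)) updates running_var, matching
+        # torch.nn.BatchNorm's convention.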
+ mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + return mean, torch.rsqrt(bias_var + self.eps) + # return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. 
+ + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. 
+ + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm3d, self)._check_input_dim(input) \ No newline at end of file diff --git a/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/batchnorm_reimpl.py b/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/batchnorm_reimpl.py new file mode 100644 index 0000000..7afcdaf --- /dev/null +++ b/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/batchnorm_reimpl.py @@ -0,0 +1,74 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : batchnorm_reimpl.py +# Author : acgtyrant +# Date : 11/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import torch +import torch.nn as nn +import torch.nn.init as init + +__all__ = ['BatchNormReimpl'] + + +class BatchNorm2dReimpl(nn.Module): + """ + A re-implementation of batch normalization, used for testing the numerical + stability. 
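+    It flattens the input to (channels, batch * height * width), computes the
+    per-channel statistics by hand, and applies the learned affine transform,
+    so its output can be compared against nn.BatchNorm2d.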
+ + Author: acgtyrant + See also: + https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 + """ + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super().__init__() + + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.weight = nn.Parameter(torch.empty(num_features)) + self.bias = nn.Parameter(torch.empty(num_features)) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_running_stats(self): + self.running_mean.zero_() + self.running_var.fill_(1) + + def reset_parameters(self): + self.reset_running_stats() + init.uniform_(self.weight) + init.zeros_(self.bias) + + def forward(self, input_): + batchsize, channels, height, width = input_.size() + numel = batchsize * height * width + input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) + sum_ = input_.sum(1) + sum_of_square = input_.pow(2).sum(1) + mean = sum_ / numel + sumvar = sum_of_square - sum_ * mean + + self.running_mean = ( + (1 - self.momentum) * self.running_mean + + self.momentum * mean.detach() + ) + unbias_var = sumvar / (numel - 1) + self.running_var = ( + (1 - self.momentum) * self.running_var + + self.momentum * unbias_var.detach() + ) + + bias_var = sumvar / numel + inv_std = 1 / (bias_var + self.eps).pow(0.5) + output = ( + (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * + self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) + + return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() + diff --git a/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/comm.py b/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/comm.py new file mode 100644 index 0000000..922f8c4 --- /dev/null +++ b/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. 
+ - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' + + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/replicate.py b/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/replicate.py new file mode 100644 index 0000000..b71c7b8 --- /dev/null +++ b/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 
+ + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/unittest.py b/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/unittest.py new file mode 100644 index 0000000..bed56f1 --- /dev/null +++ b/pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
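+#
+# Defines TorchTestCase, whose assertTensorClose compares two tensors with
+# torch.allclose and reports the absolute/relative difference on failure.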
+ +import unittest +import torch + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, x, y): + adiff = float((x - y).abs().max()) + if (y == 0).all(): + rdiff = 'NaN' + else: + rdiff = float((adiff / y).abs().max()) + + message = ( + 'Tensor close check failed\n' + 'adiff={}\n' + 'rdiff={}\n' + ).format(adiff, rdiff) + self.assertTrue(torch.allclose(x, y), message) + diff --git a/pytorch_pretrained_gans/BigBiGAN/weights/download.sh b/pytorch_pretrained_gans/BigBiGAN/weights/download.sh new file mode 100644 index 0000000..63e7c59 --- /dev/null +++ b/pytorch_pretrained_gans/BigBiGAN/weights/download.sh @@ -0,0 +1 @@ +wget https://www.dropbox.com/s/9w2i45h455k3b4p/BigBiGAN_x1.pth diff --git a/pytorch_pretrained_gans/BigGAN/__init__.py b/pytorch_pretrained_gans/BigGAN/__init__.py new file mode 100644 index 0000000..8c6c398 --- /dev/null +++ b/pytorch_pretrained_gans/BigGAN/__init__.py @@ -0,0 +1,78 @@ +import contextlib +import os +import torch +import numpy as np + +from .model import BigGAN + + +def truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None): + """ Create a truncated noise vector. + Params: + batch_size: batch size. + dim_z: dimension of z + truncation: truncation value to use + seed: seed for the random generator + Output: + array of shape (batch_size, dim_z) + """ + from scipy.stats import truncnorm + state = None if seed is None else np.random.RandomState(seed) + values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32) + return truncation * values + + +class GeneratorWrapper(torch.nn.Module): + """ A wrapper to put the GAN in a standard format """ + + def __init__(self, G, truncation=0.4): + super().__init__() + self.G = G + self.dim_z = G.config.z_dim + self.conditional = True + self.num_classes = 1000 + + self.truncation = truncation + + def forward(self, z, y=None, return_y=False): + """ In original code, z -> noise_vector, y -> class_vector """ + if y is None: + y = self.sample_class(batch_size=z.shape[0], device=z.device) + elif y.dtype == torch.long: + y = torch.eye(self.num_classes, dtype=torch.float, device=y.device)[y] + else: + y = y.to(z.device) + x = self.G(z, y, truncation=self.truncation) + x = torch.clamp(x, min=-1, max=1) # this shouldn't really be necessary + return (x, y) if return_y else x + + def sample_latent(self, batch_size=None, device='cpu'): + z = truncated_noise_sample(truncation=self.truncation, batch_size=batch_size) + z = torch.from_numpy(z).to(device) + return z + + def sample_class(self, batch_size=None, device='cpu'): + y = torch.randint(low=0, high=self.num_classes, size=(batch_size,), device=device) + y = torch.eye(self.num_classes, dtype=torch.float, device=device)[y] + return y + + +def make_biggan(model_name='biggan-deep-256') -> torch.nn.Module: + G = BigGAN.from_pretrained(model_name).eval() + G = GeneratorWrapper(G) + return G.eval() + + +if __name__ == '__main__': + # Testing + device = torch.device('cuda') + G = make_pretrained_biggan('biggan-deep-512') + G.to(device).eval() + print('Created G') + print(f'Params: {sum(p.numel() for p in G.parameters()):_}') + z = torch.randn([1, G.dim_z]).to(device) + print(f'z.shape: {z.shape}') + x = G(z) + print(f'x.shape: {x.shape}') + print(f'x.max(): {x.max()}') + print(f'x.min(): {x.min()}') diff --git a/pytorch_pretrained_gans/BigGAN/config.py b/pytorch_pretrained_gans/BigGAN/config.py new file mode 100644 index 0000000..454236a --- /dev/null +++ b/pytorch_pretrained_gans/BigGAN/config.py @@ -0,0 +1,70 @@ +# coding: 
utf-8 +""" +BigGAN config. +""" +from __future__ import (absolute_import, division, print_function, unicode_literals) + +import copy +import json + +class BigGANConfig(object): + """ Configuration class to store the configuration of a `BigGAN`. + Defaults are for the 128x128 model. + layers tuple are (up-sample in the layer ?, input channels, output channels) + """ + def __init__(self, + output_dim=128, + z_dim=128, + class_embed_dim=128, + channel_width=128, + num_classes=1000, + layers=[(False, 16, 16), + (True, 16, 16), + (False, 16, 16), + (True, 16, 8), + (False, 8, 8), + (True, 8, 4), + (False, 4, 4), + (True, 4, 2), + (False, 2, 2), + (True, 2, 1)], + attention_layer_position=8, + eps=1e-4, + n_stats=51): + """Constructs BigGANConfig. """ + self.output_dim = output_dim + self.z_dim = z_dim + self.class_embed_dim = class_embed_dim + self.channel_width = channel_width + self.num_classes = num_classes + self.layers = layers + self.attention_layer_position = attention_layer_position + self.eps = eps + self.n_stats = n_stats + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BigGANConfig` from a Python dictionary of parameters.""" + config = BigGANConfig() + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BigGANConfig` from a json file of parameters.""" + with open(json_file, "r", encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" diff --git a/pytorch_pretrained_gans/BigGAN/convert_tf_to_pytorch.py b/pytorch_pretrained_gans/BigGAN/convert_tf_to_pytorch.py new file mode 100644 index 0000000..7ccb787 --- /dev/null +++ b/pytorch_pretrained_gans/BigGAN/convert_tf_to_pytorch.py @@ -0,0 +1,312 @@ +# coding: utf-8 +""" +Convert a TF Hub model for BigGAN in a PT one. +""" +from __future__ import (absolute_import, division, print_function, unicode_literals) + +from itertools import chain + +import os +import argparse +import logging +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.functional import normalize + +from .model import BigGAN, WEIGHTS_NAME, CONFIG_NAME +from .config import BigGANConfig + +logger = logging.getLogger(__name__) + + +def extract_batch_norm_stats(tf_model_path, batch_norm_stats_path=None): + try: + import numpy as np + import tensorflow as tf + import tensorflow_hub as hub + except ImportError: + raise ImportError("Loading a TensorFlow models in PyTorch, requires TensorFlow and TF Hub to be installed. " + "Please see https://www.tensorflow.org/install/ for installation instructions for TensorFlow. " + "And see https://github.com/tensorflow/hub for installing Hub. 
" + "Probably pip install tensorflow tensorflow-hub") + tf.reset_default_graph() + logger.info('Loading BigGAN module from: {}'.format(tf_model_path)) + module = hub.Module(tf_model_path) + inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k) + for k, v in module.get_input_info_dict().items()} + output = module(inputs) + + initializer = tf.global_variables_initializer() + sess = tf.Session() + stacks = sum(((i*10 + 1, i*10 + 3, i*10 + 6, i*10 + 8) for i in range(50)), ()) + numpy_stacks = [] + for i in stacks: + logger.info("Retrieving module_apply_default/stack_{}".format(i)) + try: + stack_var = tf.get_default_graph().get_tensor_by_name("module_apply_default/stack_%d:0" % i) + except KeyError: + break # We have all the stats + numpy_stacks.append(sess.run(stack_var)) + + if batch_norm_stats_path is not None: + torch.save(numpy_stacks, batch_norm_stats_path) + else: + return numpy_stacks + + +def build_tf_to_pytorch_map(model, config): + """ Build a map from TF variables to PyTorch modules. """ + tf_to_pt_map = {} + + # Embeddings and GenZ + tf_to_pt_map.update({'linear/w/ema_0.9999': model.embeddings.weight, + 'Generator/GenZ/G_linear/b/ema_0.9999': model.generator.gen_z.bias, + 'Generator/GenZ/G_linear/w/ema_0.9999': model.generator.gen_z.weight_orig, + 'Generator/GenZ/G_linear/u0': model.generator.gen_z.weight_u}) + + # GBlock blocks + model_layer_idx = 0 + for i, (up, in_channels, out_channels) in enumerate(config.layers): + if i == config.attention_layer_position: + model_layer_idx += 1 + layer_str = "Generator/GBlock_%d/" % i if i > 0 else "Generator/GBlock/" + layer_pnt = model.generator.layers[model_layer_idx] + for i in range(4): # Batchnorms + batch_str = layer_str + ("BatchNorm_%d/" % i if i > 0 else "BatchNorm/") + batch_pnt = getattr(layer_pnt, 'bn_%d' % i) + for name in ('offset', 'scale'): + sub_module_str = batch_str + name + "/" + sub_module_pnt = getattr(batch_pnt, name) + tf_to_pt_map.update({sub_module_str + "w/ema_0.9999": sub_module_pnt.weight_orig, + sub_module_str + "u0": sub_module_pnt.weight_u}) + for i in range(4): # Convolutions + conv_str = layer_str + "conv%d/" % i + conv_pnt = getattr(layer_pnt, 'conv_%d' % i) + tf_to_pt_map.update({conv_str + "b/ema_0.9999": conv_pnt.bias, + conv_str + "w/ema_0.9999": conv_pnt.weight_orig, + conv_str + "u0": conv_pnt.weight_u}) + model_layer_idx += 1 + + # Attention block + layer_str = "Generator/attention/" + layer_pnt = model.generator.layers[config.attention_layer_position] + tf_to_pt_map.update({layer_str + "gamma/ema_0.9999": layer_pnt.gamma}) + for pt_name, tf_name in zip(['snconv1x1_g', 'snconv1x1_o_conv', 'snconv1x1_phi', 'snconv1x1_theta'], + ['g/', 'o_conv/', 'phi/', 'theta/']): + sub_module_str = layer_str + tf_name + sub_module_pnt = getattr(layer_pnt, pt_name) + tf_to_pt_map.update({sub_module_str + "w/ema_0.9999": sub_module_pnt.weight_orig, + sub_module_str + "u0": sub_module_pnt.weight_u}) + + # final batch norm and conv to rgb + layer_str = "Generator/BatchNorm/" + layer_pnt = model.generator.bn + tf_to_pt_map.update({layer_str + "offset/ema_0.9999": layer_pnt.bias, + layer_str + "scale/ema_0.9999": layer_pnt.weight}) + layer_str = "Generator/conv_to_rgb/" + layer_pnt = model.generator.conv_to_rgb + tf_to_pt_map.update({layer_str + "b/ema_0.9999": layer_pnt.bias, + layer_str + "w/ema_0.9999": layer_pnt.weight_orig, + layer_str + "u0": layer_pnt.weight_u}) + return tf_to_pt_map + + +def load_tf_weights_in_biggan(model, config, tf_model_path, batch_norm_stats_path=None): + """ Load tf checkpoints 
and standing statistics in a pytorch model + """ + try: + import numpy as np + import tensorflow as tf + except ImportError: + raise ImportError("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions.") + # Load weights from TF model + checkpoint_path = tf_model_path + "/variables/variables" + init_vars = tf.train.list_variables(checkpoint_path) + from pprint import pprint + pprint(init_vars) + + # Extract batch norm statistics from model if needed + if batch_norm_stats_path: + stats = torch.load(batch_norm_stats_path) + else: + logger.info("Extracting batch norm stats") + stats = extract_batch_norm_stats(tf_model_path) + + # Build TF to PyTorch weights loading map + tf_to_pt_map = build_tf_to_pytorch_map(model, config) + + tf_weights = {} + for name in tf_to_pt_map.keys(): + array = tf.train.load_variable(checkpoint_path, name) + tf_weights[name] = array + # logger.info("Loading TF weight {} with shape {}".format(name, array.shape)) + + # Load parameters + with torch.no_grad(): + pt_params_pnt = set() + for name, pointer in tf_to_pt_map.items(): + array = tf_weights[name] + if pointer.dim() == 1: + if pointer.dim() < array.ndim: + array = np.squeeze(array) + elif pointer.dim() == 2: # Weights + array = np.transpose(array) + elif pointer.dim() == 4: # Convolutions + array = np.transpose(array, (3, 2, 0, 1)) + else: + raise "Wrong dimensions to adjust: " + str((pointer.shape, array.shape)) + if pointer.shape != array.shape: + raise ValueError("Wrong dimensions: " + str((pointer.shape, array.shape))) + logger.info("Initialize PyTorch weight {} with shape {}".format(name, pointer.shape)) + pointer.data = torch.from_numpy(array) if isinstance(array, np.ndarray) else torch.tensor(array) + tf_weights.pop(name, None) + pt_params_pnt.add(pointer.data_ptr()) + + # Prepare SpectralNorm buffers by running one step of Spectral Norm (no need to train the model): + for module in model.modules(): + for n, buffer in module.named_buffers(): + if n == 'weight_v': + weight_mat = module.weight_orig + weight_mat = weight_mat.reshape(weight_mat.size(0), -1) + u = module.weight_u + + v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=config.eps) + buffer.data = v + pt_params_pnt.add(buffer.data_ptr()) + + u = normalize(torch.mv(weight_mat, v), dim=0, eps=config.eps) + module.weight_u.data = u + pt_params_pnt.add(module.weight_u.data_ptr()) + + # Load batch norm statistics + index = 0 + for layer in model.generator.layers: + if not hasattr(layer, 'bn_0'): + continue + for i in range(4): # Batchnorms + bn_pointer = getattr(layer, 'bn_%d' % i) + pointer = bn_pointer.running_means + if pointer.shape != stats[index].shape: + raise "Wrong dimensions: " + str((pointer.shape, stats[index].shape)) + pointer.data = torch.from_numpy(stats[index]) + pt_params_pnt.add(pointer.data_ptr()) + + pointer = bn_pointer.running_vars + if pointer.shape != stats[index+1].shape: + raise "Wrong dimensions: " + str((pointer.shape, stats[index].shape)) + pointer.data = torch.from_numpy(stats[index+1]) + pt_params_pnt.add(pointer.data_ptr()) + + index += 2 + + bn_pointer = model.generator.bn + pointer = bn_pointer.running_means + if pointer.shape != stats[index].shape: + raise "Wrong dimensions: " + str((pointer.shape, stats[index].shape)) + pointer.data = torch.from_numpy(stats[index]) + pt_params_pnt.add(pointer.data_ptr()) + + pointer = bn_pointer.running_vars + if pointer.shape != stats[index+1].shape: + raise "Wrong dimensions: 
" + str((pointer.shape, stats[index].shape)) + pointer.data = torch.from_numpy(stats[index+1]) + pt_params_pnt.add(pointer.data_ptr()) + + remaining_params = list(n for n, t in chain(model.named_parameters(), model.named_buffers()) \ + if t.data_ptr() not in pt_params_pnt) + + logger.info("TF Weights not copied to PyTorch model: {} -".format(', '.join(tf_weights.keys()))) + logger.info("Remanining parameters/buffers from PyTorch model: {} -".format(', '.join(remaining_params))) + + return model + + +BigGAN128 = BigGANConfig(output_dim=128, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000, + layers=[(False, 16, 16), + (True, 16, 16), + (False, 16, 16), + (True, 16, 8), + (False, 8, 8), + (True, 8, 4), + (False, 4, 4), + (True, 4, 2), + (False, 2, 2), + (True, 2, 1)], + attention_layer_position=8, eps=1e-4, n_stats=51) + +BigGAN256 = BigGANConfig(output_dim=256, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000, + layers=[(False, 16, 16), + (True, 16, 16), + (False, 16, 16), + (True, 16, 8), + (False, 8, 8), + (True, 8, 8), + (False, 8, 8), + (True, 8, 4), + (False, 4, 4), + (True, 4, 2), + (False, 2, 2), + (True, 2, 1)], + attention_layer_position=8, eps=1e-4, n_stats=51) + +BigGAN512 = BigGANConfig(output_dim=512, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000, + layers=[(False, 16, 16), + (True, 16, 16), + (False, 16, 16), + (True, 16, 8), + (False, 8, 8), + (True, 8, 8), + (False, 8, 8), + (True, 8, 4), + (False, 4, 4), + (True, 4, 2), + (False, 2, 2), + (True, 2, 1), + (False, 1, 1), + (True, 1, 1)], + attention_layer_position=8, eps=1e-4, n_stats=51) + + +def main(): + parser = argparse.ArgumentParser(description="Convert a BigGAN TF Hub model in a PyTorch model") + parser.add_argument("--model_type", type=str, default="", required=True, + help="BigGAN model type (128, 256, 512)") + parser.add_argument("--tf_model_path", type=str, default="", required=True, + help="Path of the downloaded TF Hub model") + parser.add_argument("--pt_save_path", type=str, default="", + help="Folder to save the PyTorch model (default: Folder of the TF Hub model)") + parser.add_argument("--batch_norm_stats_path", type=str, default="", + help="Path of previously extracted batch norm statistics") + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + + if not args.pt_save_path: + args.pt_save_path = args.tf_model_path + + if args.model_type == "128": + config = BigGAN128 + elif args.model_type == "256": + config = BigGAN256 + elif args.model_type == "512": + config = BigGAN512 + else: + raise ValueError("model_type should be one of 128, 256 or 512") + + model = BigGAN(config) + model = load_tf_weights_in_biggan(model, config, args.tf_model_path, args.batch_norm_stats_path) + + model_save_path = os.path.join(args.pt_save_path, WEIGHTS_NAME) + config_save_path = os.path.join(args.pt_save_path, CONFIG_NAME) + + logger.info("Save model dump to {}".format(model_save_path)) + torch.save(model.state_dict(), model_save_path) + logger.info("Save configuration file to {}".format(config_save_path)) + with open(config_save_path, "w", encoding="utf-8") as f: + f.write(config.to_json_string()) + +if __name__ == "__main__": + main() diff --git a/pytorch_pretrained_gans/BigGAN/file_utils.py b/pytorch_pretrained_gans/BigGAN/file_utils.py new file mode 100644 index 0000000..41624ca --- /dev/null +++ b/pytorch_pretrained_gans/BigGAN/file_utils.py @@ -0,0 +1,249 @@ +""" +Utilities for working with the local dataset cache. 
+This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +Copyright by the AllenNLP authors. +""" +from __future__ import (absolute_import, division, print_function, unicode_literals) + +import json +import logging +import os +import shutil +import tempfile +from functools import wraps +from hashlib import sha256 +import sys +from io import open + +import boto3 +import requests +from botocore.exceptions import ClientError +from tqdm import tqdm + +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + +try: + from pathlib import Path + PYTORCH_PRETRAINED_BIGGAN_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE', + Path.home() / '.pytorch_pretrained_biggan')) +except (AttributeError, ImportError): + PYTORCH_PRETRAINED_BIGGAN_CACHE = os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE', + os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_biggan')) + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def url_to_filename(url, etag=None): + """ + Convert `url` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the url's, delimited + by a period. + """ + url_bytes = url.encode('utf-8') + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode('utf-8') + etag_hash = sha256(etag_bytes) + filename += '.' + etag_hash.hexdigest() + + return filename + + +def filename_to_url(filename, cache_dir=None): + """ + Return the url and etag (which may be ``None``) stored for `filename`. + Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + cache_path = os.path.join(cache_dir, filename) + if not os.path.exists(cache_path): + raise EnvironmentError("file {} not found".format(cache_path)) + + meta_path = cache_path + '.json' + if not os.path.exists(meta_path): + raise EnvironmentError("file {} not found".format(meta_path)) + + with open(meta_path, encoding="utf-8") as meta_file: + metadata = json.load(meta_file) + url = metadata['url'] + etag = metadata['etag'] + + return url, etag + + +def cached_path(url_or_filename, cache_dir=None): + """ + Given something that might be a URL (or might be a local path), + determine which. If it's a URL, download the file and cache it, and + return the path to the cached file. If it's already a local path, + make sure the file exists and then return the path. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE + if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ('http', 'https', 's3'): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, cache_dir) + elif os.path.exists(url_or_filename): + # File, and it exists. + return url_or_filename + elif parsed.scheme == '': + # File, but it doesn't exist. 
+ raise EnvironmentError("file {} not found".format(url_or_filename)) + else: + # Something unknown + raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) + + +def split_s3_path(url): + """Split a full s3 path into the bucket name and path.""" + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError("bad s3 path {}".format(url)) + bucket_name = parsed.netloc + s3_path = parsed.path + # Remove '/' at beginning of path. + if s3_path.startswith("/"): + s3_path = s3_path[1:] + return bucket_name, s3_path + + +def s3_request(func): + """ + Wrapper function for s3 requests in order to create more helpful error + messages. + """ + + @wraps(func) + def wrapper(url, *args, **kwargs): + try: + return func(url, *args, **kwargs) + except ClientError as exc: + if int(exc.response["Error"]["Code"]) == 404: + raise EnvironmentError("file {} not found".format(url)) + else: + raise + + return wrapper + + +@s3_request +def s3_etag(url): + """Check ETag on S3 object.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_object = s3_resource.Object(bucket_name, s3_path) + return s3_object.e_tag + + +@s3_request +def s3_get(url, temp_file): + """Pull a file directly from S3.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) + + +def http_get(url, temp_file): + req = requests.get(url, stream=True) + content_length = req.headers.get('Content-Length') + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total) + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache(url, cache_dir=None): + """ + Given a URL, look for the corresponding dataset in the local cache. + If it's not there, download it. Then return the path to the cached file. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + # Get eTag to add to filename, if it exists. + if url.startswith("s3://"): + etag = s3_etag(url) + else: + response = requests.head(url, allow_redirects=True) + if response.status_code != 200: + raise IOError("HEAD request failed for url {} with status code {}" + .format(url, response.status_code)) + etag = response.headers.get("ETag") + + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + if not os.path.exists(cache_path): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. 
+ with tempfile.NamedTemporaryFile() as temp_file: + logger.info("%s not found in cache, downloading to %s", url, temp_file.name) + + # GET file object + if url.startswith("s3://"): + s3_get(url, temp_file) + else: + http_get(url, temp_file) + + # we are copying the file before closing it, so flush to avoid truncation + temp_file.flush() + # shutil.copyfileobj() starts at the current position, so go to the start + temp_file.seek(0) + + logger.info("copying %s to cache at %s", temp_file.name, cache_path) + with open(cache_path, 'wb') as cache_file: + shutil.copyfileobj(temp_file, cache_file) + + logger.info("creating metadata file for %s", cache_path) + meta = {'url': url, 'etag': etag} + meta_path = cache_path + '.json' + with open(meta_path, 'w', encoding="utf-8") as meta_file: + json.dump(meta, meta_file) + + logger.info("removing temp file %s", temp_file.name) + + return cache_path + + +def read_set_from_file(filename): + ''' + Extract a de-duped collection (set) of text from a file. + Expected file format is one item per line. + ''' + collection = set() + with open(filename, 'r', encoding='utf-8') as file_: + for line in file_: + collection.add(line.rstrip()) + return collection + + +def get_file_extension(path, dot=True, lower=True): + ext = os.path.splitext(path)[1] + ext = ext if dot else ext[1:] + return ext.lower() if lower else ext diff --git a/pytorch_pretrained_gans/BigGAN/model.py b/pytorch_pretrained_gans/BigGAN/model.py new file mode 100644 index 0000000..ca0f757 --- /dev/null +++ b/pytorch_pretrained_gans/BigGAN/model.py @@ -0,0 +1,330 @@ +# coding: utf-8 +""" BigGAN PyTorch model. + From "Large Scale GAN Training for High Fidelity Natural Image Synthesis" + By Andrew Brocky, Jeff Donahuey and Karen Simonyan. + https://openreview.net/forum?id=B1xsqj09Fm + + PyTorch version implemented from the computational graph of the TF Hub module for BigGAN. + Some part of the code are adapted from https://github.com/brain-research/self-attention-gan + + This version only comprises the generator (since the discriminator's weights are not released). + This version only comprises the "deep" version of BigGAN (see publication). 
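+
+    Example (mirroring how GeneratorWrapper in this package drives the model;
+    noise_vector and class_vector are placeholder tensors of shape
+    [batch, 128] and [batch, 1000]):
+        model = BigGAN.from_pretrained('biggan-deep-256').eval()
+        images = model(noise_vector, class_vector, truncation=0.4)  # values in [-1, 1]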
+""" +from __future__ import (absolute_import, division, print_function, unicode_literals) + +import os +import logging +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .config import BigGANConfig +from .file_utils import cached_path + +logger = logging.getLogger(__name__) + +PRETRAINED_MODEL_ARCHIVE_MAP = { + 'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-pytorch_model.bin", + 'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-pytorch_model.bin", + 'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-pytorch_model.bin", +} + +PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-config.json", + 'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-config.json", + 'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-config.json", +} + +WEIGHTS_NAME = 'pytorch_model.bin' +CONFIG_NAME = 'config.json' + + +def snconv2d(eps=1e-12, **kwargs): + return nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps) + +def snlinear(eps=1e-12, **kwargs): + return nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps) + +def sn_embedding(eps=1e-12, **kwargs): + return nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps) + +class SelfAttn(nn.Module): + """ Self attention Layer""" + def __init__(self, in_channels, eps=1e-12): + super(SelfAttn, self).__init__() + self.in_channels = in_channels + self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, + kernel_size=1, bias=False, eps=eps) + self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, + kernel_size=1, bias=False, eps=eps) + self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, + kernel_size=1, bias=False, eps=eps) + self.snconv1x1_o_conv = snconv2d(in_channels=in_channels//2, out_channels=in_channels, + kernel_size=1, bias=False, eps=eps) + self.maxpool = nn.MaxPool2d(2, stride=2, padding=0) + self.softmax = nn.Softmax(dim=-1) + self.gamma = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + _, ch, h, w = x.size() + # Theta path + theta = self.snconv1x1_theta(x) + theta = theta.view(-1, ch//8, h*w) + # Phi path + phi = self.snconv1x1_phi(x) + phi = self.maxpool(phi) + phi = phi.view(-1, ch//8, h*w//4) + # Attn map + attn = torch.bmm(theta.permute(0, 2, 1), phi) + attn = self.softmax(attn) + # g path + g = self.snconv1x1_g(x) + g = self.maxpool(g) + g = g.view(-1, ch//2, h*w//4) + # Attn_g - o_conv + attn_g = torch.bmm(g, attn.permute(0, 2, 1)) + attn_g = attn_g.view(-1, ch//2, h, w) + attn_g = self.snconv1x1_o_conv(attn_g) + # Out + out = x + self.gamma*attn_g + return out + + +class BigGANBatchNorm(nn.Module): + """ This is a batch norm module that can handle conditional input and can be provided with pre-computed + activation means and variances for various truncation parameters. + + We cannot just rely on torch.batch_norm since it cannot handle + batched weights (pytorch 1.0.1). We computate batch_norm our-self without updating running means and variances. + If you want to train this model you should add running means and variance computation logic. 
+ """ + def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4, conditional=True): + super(BigGANBatchNorm, self).__init__() + self.num_features = num_features + self.eps = eps + self.conditional = conditional + + # We use pre-computed statistics for n_stats values of truncation between 0 and 1 + self.register_buffer('running_means', torch.zeros(n_stats, num_features)) + self.register_buffer('running_vars', torch.ones(n_stats, num_features)) + self.step_size = 1.0 / (n_stats - 1) + + if conditional: + assert condition_vector_dim is not None + self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps) + self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps) + else: + self.weight = torch.nn.Parameter(torch.Tensor(num_features)) + self.bias = torch.nn.Parameter(torch.Tensor(num_features)) + + def forward(self, x, truncation, condition_vector=None): + # Retreive pre-computed statistics associated to this truncation + coef, start_idx = math.modf(truncation / self.step_size) + start_idx = int(start_idx) + if coef != 0.0: # Interpolate + running_mean = self.running_means[start_idx] * coef + self.running_means[start_idx + 1] * (1 - coef) + running_var = self.running_vars[start_idx] * coef + self.running_vars[start_idx + 1] * (1 - coef) + else: + running_mean = self.running_means[start_idx] + running_var = self.running_vars[start_idx] + + if self.conditional: + running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + + weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1) + bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1) + + out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias + else: + out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias, + training=False, momentum=0.0, eps=self.eps) + + return out + + +class GenBlock(nn.Module): + def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4, up_sample=False, + n_stats=51, eps=1e-12): + super(GenBlock, self).__init__() + self.up_sample = up_sample + self.drop_channels = (in_size != out_size) + middle_size = in_size // reduction_factor + + self.bn_0 = BigGANBatchNorm(in_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True) + self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1, eps=eps) + + self.bn_1 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True) + self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps) + + self.bn_2 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True) + self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps) + + self.bn_3 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True) + self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1, eps=eps) + + self.relu = nn.ReLU() + + def forward(self, x, cond_vector, truncation): + x0 = x + + x = self.bn_0(x, truncation, cond_vector) + x = self.relu(x) + x = self.conv_0(x) + + x = self.bn_1(x, truncation, cond_vector) + x = self.relu(x) + if self.up_sample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv_1(x) + + x = self.bn_2(x, truncation, cond_vector) + x = 
self.relu(x) + x = self.conv_2(x) + + x = self.bn_3(x, truncation, cond_vector) + x = self.relu(x) + x = self.conv_3(x) + + if self.drop_channels: + new_channels = x0.shape[1] // 2 + x0 = x0[:, :new_channels, ...] + if self.up_sample: + x0 = F.interpolate(x0, scale_factor=2, mode='nearest') + + out = x + x0 + return out + +class Generator(nn.Module): + def __init__(self, config): + super(Generator, self).__init__() + self.config = config + ch = config.channel_width + condition_vector_dim = config.z_dim * 2 + + self.gen_z = snlinear(in_features=condition_vector_dim, + out_features=4 * 4 * 16 * ch, eps=config.eps) + + layers = [] + for i, layer in enumerate(config.layers): + if i == config.attention_layer_position: + layers.append(SelfAttn(ch*layer[1], eps=config.eps)) + layers.append(GenBlock(ch*layer[1], + ch*layer[2], + condition_vector_dim, + up_sample=layer[0], + n_stats=config.n_stats, + eps=config.eps)) + self.layers = nn.ModuleList(layers) + + self.bn = BigGANBatchNorm(ch, n_stats=config.n_stats, eps=config.eps, conditional=False) + self.relu = nn.ReLU() + self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1, eps=config.eps) + self.tanh = nn.Tanh() + + def forward(self, cond_vector, truncation): + z = self.gen_z(cond_vector) + + # We use this conversion step to be able to use TF weights: + # TF convention on shape is [batch, height, width, channels] + # PT convention on shape is [batch, channels, height, width] + z = z.view(-1, 4, 4, 16 * self.config.channel_width) + z = z.permute(0, 3, 1, 2).contiguous() + + for i, layer in enumerate(self.layers): + if isinstance(layer, GenBlock): + z = layer(z, cond_vector, truncation) + else: + z = layer(z) + + z = self.bn(z, truncation) + z = self.relu(z) + z = self.conv_to_rgb(z) + z = z[:, :3, ...] + z = self.tanh(z) + return z + +class BigGAN(nn.Module): + """BigGAN Generator.""" + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: + model_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] + config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path] + else: + model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) + + try: + resolved_model_file = cached_path(model_file, cache_dir=cache_dir) + resolved_config_file = cached_path(config_file, cache_dir=cache_dir) + except EnvironmentError: + logger.error("Wrong model name, should be a valid path to a folder containing " + "a {} file and a {} file or a model name in {}".format( + WEIGHTS_NAME, CONFIG_NAME, PRETRAINED_MODEL_ARCHIVE_MAP.keys())) + raise + + print("Loading BigGAN model {} from cache at {}".format(pretrained_model_name_or_path, resolved_model_file)) + + # Load config + config = BigGANConfig.from_json_file(resolved_config_file) + # print("Model config {}".format(config)) + + # Instantiate model. 
+ model = cls(config, *inputs, **kwargs) + state_dict = torch.load(resolved_model_file, map_location='cpu' if not torch.cuda.is_available() else None) + model.load_state_dict(state_dict, strict=False) + return model + + def __init__(self, config): + super(BigGAN, self).__init__() + self.config = config + self.embeddings = nn.Linear(config.num_classes, config.z_dim, bias=False) + self.generator = Generator(config) + + def forward(self, z, class_label, truncation): + assert 0 < truncation <= 1 + + embed = self.embeddings(class_label) + cond_vector = torch.cat((z, embed), dim=1) + + z = self.generator(cond_vector, truncation) + return z + + +if __name__ == "__main__": + import PIL + from .utils import truncated_noise_sample, save_as_images, one_hot_from_names + from .convert_tf_to_pytorch import load_tf_weights_in_biggan + + load_cache = False + cache_path = './saved_model.pt' + config = BigGANConfig() + model = BigGAN(config) + if not load_cache: + model = load_tf_weights_in_biggan(model, config, './models/model_128/', './models/model_128/batchnorms_stats.bin') + torch.save(model.state_dict(), cache_path) + else: + model.load_state_dict(torch.load(cache_path)) + + model.eval() + + truncation = 0.4 + noise = truncated_noise_sample(batch_size=2, truncation=truncation) + label = one_hot_from_names('diver', batch_size=2) + + # Tests + # noise = np.zeros((1, 128)) + # label = [983] + + noise = torch.tensor(noise, dtype=torch.float) + label = torch.tensor(label, dtype=torch.float) + with torch.no_grad(): + outputs = model(noise, label, truncation) + print(outputs.shape) + + save_as_images(outputs) diff --git a/pytorch_pretrained_gans/BigGAN/utils.py b/pytorch_pretrained_gans/BigGAN/utils.py new file mode 100644 index 0000000..3b9edbe --- /dev/null +++ b/pytorch_pretrained_gans/BigGAN/utils.py @@ -0,0 +1,216 @@ +# coding: utf-8 +""" BigGAN utilities to prepare truncated noise samples and convert/save/display output images. + Also comprise ImageNet utilities to prepare one hot input vectors for ImageNet classes. + We use Wordnet so you can just input a name in a string and automatically get a corresponding + imagenet class if it exists (or a hypo/hypernym exists in imagenet). +""" +from __future__ import absolute_import, division, print_function, unicode_literals + +import json +import logging +from io import BytesIO + +import numpy as np +from scipy.stats import truncnorm + +logger = logging.getLogger(__name__) + +NUM_CLASSES = 1000 + + +def truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None): + """ Create a truncated noise vector. + Params: + batch_size: batch size. + dim_z: dimension of z + truncation: truncation value to use + seed: seed for the random generator + Output: + array of shape (batch_size, dim_z) + """ + state = None if seed is None else np.random.RandomState(seed) + values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32) + return truncation * values + + +def convert_to_images(obj): + """ Convert an output tensor from BigGAN in a list of images. 
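+        The generator output is assumed to lie in [-1, 1] and is rescaled to 8-bit values in [0, 255] before conversion.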
+ Params: + obj: tensor or numpy array of shape (batch_size, channels, height, width) + Output: + list of Pillow Images of size (height, width) + """ + try: + import PIL + except ImportError: + raise ImportError("Please install Pillow to use images: pip install Pillow") + + if not isinstance(obj, np.ndarray): + obj = obj.detach().numpy() + + obj = obj.transpose((0, 2, 3, 1)) + obj = np.clip(((obj + 1) / 2.0) * 256, 0, 255) + + img = [] + for i, out in enumerate(obj): + out_array = np.asarray(np.uint8(out), dtype=np.uint8) + img.append(PIL.Image.fromarray(out_array)) + return img + + +def save_as_images(obj, file_name='output'): + """ Convert and save an output tensor from BigGAN in a list of saved images. + Params: + obj: tensor or numpy array of shape (batch_size, channels, height, width) + file_name: path and beggingin of filename to save. + Images will be saved as `file_name_{image_number}.png` + """ + img = convert_to_images(obj) + + for i, out in enumerate(img): + current_file_name = file_name + '_%d.png' % i + logger.info("Saving image to {}".format(current_file_name)) + out.save(current_file_name, 'png') + + +def display_in_terminal(obj): + """ Convert and display an output tensor from BigGAN in the terminal. + This function use `libsixel` and will only work in a libsixel-compatible terminal. + Please refer to https://github.com/saitoha/libsixel for more details. + + Params: + obj: tensor or numpy array of shape (batch_size, channels, height, width) + file_name: path and beggingin of filename to save. + Images will be saved as `file_name_{image_number}.png` + """ + try: + import PIL + from libsixel import (sixel_output_new, sixel_dither_new, sixel_dither_initialize, + sixel_dither_set_palette, sixel_dither_set_pixelformat, + sixel_dither_get, sixel_encode, sixel_dither_unref, + sixel_output_unref, SIXEL_PIXELFORMAT_RGBA8888, + SIXEL_PIXELFORMAT_RGB888, SIXEL_PIXELFORMAT_PAL8, + SIXEL_PIXELFORMAT_G8, SIXEL_PIXELFORMAT_G1) + except ImportError: + raise ImportError("Display in Terminal requires Pillow, libsixel " + "and a libsixel compatible terminal. 
" + "Please read info at https://github.com/saitoha/libsixel " + "and install with pip install Pillow libsixel-python") + + s = BytesIO() + + images = convert_to_images(obj) + widths, heights = zip(*(i.size for i in images)) + + output_width = sum(widths) + output_height = max(heights) + + output_image = PIL.Image.new('RGB', (output_width, output_height)) + + x_offset = 0 + for im in images: + output_image.paste(im, (x_offset,0)) + x_offset += im.size[0] + + try: + data = output_image.tobytes() + except NotImplementedError: + data = output_image.tostring() + output = sixel_output_new(lambda data, s: s.write(data), s) + + try: + if output_image.mode == 'RGBA': + dither = sixel_dither_new(256) + sixel_dither_initialize(dither, data, output_width, output_height, SIXEL_PIXELFORMAT_RGBA8888) + elif output_image.mode == 'RGB': + dither = sixel_dither_new(256) + sixel_dither_initialize(dither, data, output_width, output_height, SIXEL_PIXELFORMAT_RGB888) + elif output_image.mode == 'P': + palette = output_image.getpalette() + dither = sixel_dither_new(256) + sixel_dither_set_palette(dither, palette) + sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_PAL8) + elif output_image.mode == 'L': + dither = sixel_dither_get(SIXEL_BUILTIN_G8) + sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_G8) + elif output_image.mode == '1': + dither = sixel_dither_get(SIXEL_BUILTIN_G1) + sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_G1) + else: + raise RuntimeError('unexpected output_image mode') + try: + sixel_encode(data, output_width, output_height, 1, dither, output) + print(s.getvalue().decode('ascii')) + finally: + sixel_dither_unref(dither) + finally: + sixel_output_unref(output) + + +def one_hot_from_int(int_or_list, batch_size=1): + """ Create a one-hot vector from a class index or a list of class indices. + Params: + int_or_list: int, or list of int, of the imagenet classes (between 0 and 999) + batch_size: batch size. + If int_or_list is an int create a batch of identical classes. + If int_or_list is a list, we should have `len(int_or_list) == batch_size` + Output: + array of shape (batch_size, 1000) + """ + if isinstance(int_or_list, int): + int_or_list = [int_or_list] + + if len(int_or_list) == 1 and batch_size > 1: + int_or_list = [int_or_list[0]] * batch_size + + assert batch_size == len(int_or_list) + + array = np.zeros((batch_size, NUM_CLASSES), dtype=np.float32) + for i, j in enumerate(int_or_list): + array[i, j] = 1.0 + return array + + +def one_hot_from_names(class_name_or_list, batch_size=1): + """ Create a one-hot vector from the name of an imagenet class ('tennis ball', 'daisy', ...). + We use NLTK's wordnet search to try to find the relevant synset of ImageNet and take the first one. + If we can't find it direcly, we look at the hyponyms and hypernyms of the class name. + + Params: + class_name_or_list: string containing the name of an imagenet object or a list of such strings (for a batch). 
+ Output: + array of shape (batch_size, 1000) + """ + try: + from nltk.corpus import wordnet as wn + except ImportError: + raise ImportError("You need to install nltk to use this function") + + if not isinstance(class_name_or_list, (list, tuple)): + class_name_or_list = [class_name_or_list] + else: + batch_size = max(batch_size, len(class_name_or_list)) + + classes = [] + for class_name in class_name_or_list: + class_name = class_name.replace(" ", "_") + + original_synsets = wn.synsets(class_name) + original_synsets = list(filter(lambda s: s.pos() == 'n', original_synsets)) # keep only names + if not original_synsets: + return None + + possible_synsets = list(filter(lambda s: s.offset() in IMAGENET, original_synsets)) + if possible_synsets: + classes.append(IMAGENET[possible_synsets[0].offset()]) + else: + # try hypernyms and hyponyms + possible_synsets = sum([s.hypernyms() + s.hyponyms() for s in original_synsets], []) + possible_synsets = list(filter(lambda s: s.offset() in IMAGENET, possible_synsets)) + if possible_synsets: + classes.append(IMAGENET[possible_synsets[0].offset()]) + + return one_hot_from_int(classes, batch_size=batch_size) + + +IMAGENET = {1440764: 0, 1443537: 1, 1484850: 2, 1491361: 3, 1494475: 4, 1496331: 5, 1498041: 6, 1514668: 7, 1514859: 8, 1518878: 9, 1530575: 10, 1531178: 11, 1532829: 12, 1534433: 13, 1537544: 14, 1558993: 15, 1560419: 16, 1580077: 17, 1582220: 18, 1592084: 19, 1601694: 20, 1608432: 21, 1614925: 22, 1616318: 23, 1622779: 24, 1629819: 25, 1630670: 26, 1631663: 27, 1632458: 28, 1632777: 29, 1641577: 30, 1644373: 31, 1644900: 32, 1664065: 33, 1665541: 34, 1667114: 35, 1667778: 36, 1669191: 37, 1675722: 38, 1677366: 39, 1682714: 40, 1685808: 41, 1687978: 42, 1688243: 43, 1689811: 44, 1692333: 45, 1693334: 46, 1694178: 47, 1695060: 48, 1697457: 49, 1698640: 50, 1704323: 51, 1728572: 52, 1728920: 53, 1729322: 54, 1729977: 55, 1734418: 56, 1735189: 57, 1737021: 58, 1739381: 59, 1740131: 60, 1742172: 61, 1744401: 62, 1748264: 63, 1749939: 64, 1751748: 65, 1753488: 66, 1755581: 67, 1756291: 68, 1768244: 69, 1770081: 70, 1770393: 71, 1773157: 72, 1773549: 73, 1773797: 74, 1774384: 75, 1774750: 76, 1775062: 77, 1776313: 78, 1784675: 79, 1795545: 80, 1796340: 81, 1797886: 82, 1798484: 83, 1806143: 84, 1806567: 85, 1807496: 86, 1817953: 87, 1818515: 88, 1819313: 89, 1820546: 90, 1824575: 91, 1828970: 92, 1829413: 93, 1833805: 94, 1843065: 95, 1843383: 96, 1847000: 97, 1855032: 98, 1855672: 99, 1860187: 100, 1871265: 101, 1872401: 102, 1873310: 103, 1877812: 104, 1882714: 105, 1883070: 106, 1910747: 107, 1914609: 108, 1917289: 109, 1924916: 110, 1930112: 111, 1943899: 112, 1944390: 113, 1945685: 114, 1950731: 115, 1955084: 116, 1968897: 117, 1978287: 118, 1978455: 119, 1980166: 120, 1981276: 121, 1983481: 122, 1984695: 123, 1985128: 124, 1986214: 125, 1990800: 126, 2002556: 127, 2002724: 128, 2006656: 129, 2007558: 130, 2009229: 131, 2009912: 132, 2011460: 133, 2012849: 134, 2013706: 135, 2017213: 136, 2018207: 137, 2018795: 138, 2025239: 139, 2027492: 140, 2028035: 141, 2033041: 142, 2037110: 143, 2051845: 144, 2056570: 145, 2058221: 146, 2066245: 147, 2071294: 148, 2074367: 149, 2077923: 150, 2085620: 151, 2085782: 152, 2085936: 153, 2086079: 154, 2086240: 155, 2086646: 156, 2086910: 157, 2087046: 158, 2087394: 159, 2088094: 160, 2088238: 161, 2088364: 162, 2088466: 163, 2088632: 164, 2089078: 165, 2089867: 166, 2089973: 167, 2090379: 168, 2090622: 169, 2090721: 170, 2091032: 171, 2091134: 172, 2091244: 173, 2091467: 174, 2091635: 175, 2091831: 176, 2092002: 
177, 2092339: 178, 2093256: 179, 2093428: 180, 2093647: 181, 2093754: 182, 2093859: 183, 2093991: 184, 2094114: 185, 2094258: 186, 2094433: 187, 2095314: 188, 2095570: 189, 2095889: 190, 2096051: 191, 2096177: 192, 2096294: 193, 2096437: 194, 2096585: 195, 2097047: 196, 2097130: 197, 2097209: 198, 2097298: 199, 2097474: 200, 2097658: 201, 2098105: 202, 2098286: 203, 2098413: 204, 2099267: 205, 2099429: 206, 2099601: 207, 2099712: 208, 2099849: 209, 2100236: 210, 2100583: 211, 2100735: 212, 2100877: 213, 2101006: 214, 2101388: 215, 2101556: 216, 2102040: 217, 2102177: 218, 2102318: 219, 2102480: 220, 2102973: 221, 2104029: 222, 2104365: 223, 2105056: 224, 2105162: 225, 2105251: 226, 2105412: 227, 2105505: 228, 2105641: 229, 2105855: 230, 2106030: 231, 2106166: 232, 2106382: 233, 2106550: 234, 2106662: 235, 2107142: 236, 2107312: 237, 2107574: 238, 2107683: 239, 2107908: 240, 2108000: 241, 2108089: 242, 2108422: 243, 2108551: 244, 2108915: 245, 2109047: 246, 2109525: 247, 2109961: 248, 2110063: 249, 2110185: 250, 2110341: 251, 2110627: 252, 2110806: 253, 2110958: 254, 2111129: 255, 2111277: 256, 2111500: 257, 2111889: 258, 2112018: 259, 2112137: 260, 2112350: 261, 2112706: 262, 2113023: 263, 2113186: 264, 2113624: 265, 2113712: 266, 2113799: 267, 2113978: 268, 2114367: 269, 2114548: 270, 2114712: 271, 2114855: 272, 2115641: 273, 2115913: 274, 2116738: 275, 2117135: 276, 2119022: 277, 2119789: 278, 2120079: 279, 2120505: 280, 2123045: 281, 2123159: 282, 2123394: 283, 2123597: 284, 2124075: 285, 2125311: 286, 2127052: 287, 2128385: 288, 2128757: 289, 2128925: 290, 2129165: 291, 2129604: 292, 2130308: 293, 2132136: 294, 2133161: 295, 2134084: 296, 2134418: 297, 2137549: 298, 2138441: 299, 2165105: 300, 2165456: 301, 2167151: 302, 2168699: 303, 2169497: 304, 2172182: 305, 2174001: 306, 2177972: 307, 2190166: 308, 2206856: 309, 2219486: 310, 2226429: 311, 2229544: 312, 2231487: 313, 2233338: 314, 2236044: 315, 2256656: 316, 2259212: 317, 2264363: 318, 2268443: 319, 2268853: 320, 2276258: 321, 2277742: 322, 2279972: 323, 2280649: 324, 2281406: 325, 2281787: 326, 2317335: 327, 2319095: 328, 2321529: 329, 2325366: 330, 2326432: 331, 2328150: 332, 2342885: 333, 2346627: 334, 2356798: 335, 2361337: 336, 2363005: 337, 2364673: 338, 2389026: 339, 2391049: 340, 2395406: 341, 2396427: 342, 2397096: 343, 2398521: 344, 2403003: 345, 2408429: 346, 2410509: 347, 2412080: 348, 2415577: 349, 2417914: 350, 2422106: 351, 2422699: 352, 2423022: 353, 2437312: 354, 2437616: 355, 2441942: 356, 2442845: 357, 2443114: 358, 2443484: 359, 2444819: 360, 2445715: 361, 2447366: 362, 2454379: 363, 2457408: 364, 2480495: 365, 2480855: 366, 2481823: 367, 2483362: 368, 2483708: 369, 2484975: 370, 2486261: 371, 2486410: 372, 2487347: 373, 2488291: 374, 2488702: 375, 2489166: 376, 2490219: 377, 2492035: 378, 2492660: 379, 2493509: 380, 2493793: 381, 2494079: 382, 2497673: 383, 2500267: 384, 2504013: 385, 2504458: 386, 2509815: 387, 2510455: 388, 2514041: 389, 2526121: 390, 2536864: 391, 2606052: 392, 2607072: 393, 2640242: 394, 2641379: 395, 2643566: 396, 2655020: 397, 2666196: 398, 2667093: 399, 2669723: 400, 2672831: 401, 2676566: 402, 2687172: 403, 2690373: 404, 2692877: 405, 2699494: 406, 2701002: 407, 2704792: 408, 2708093: 409, 2727426: 410, 2730930: 411, 2747177: 412, 2749479: 413, 2769748: 414, 2776631: 415, 2777292: 416, 2782093: 417, 2783161: 418, 2786058: 419, 2787622: 420, 2788148: 421, 2790996: 422, 2791124: 423, 2791270: 424, 2793495: 425, 2794156: 426, 2795169: 427, 2797295: 428, 2799071: 429, 2802426: 430, 
2804414: 431, 2804610: 432, 2807133: 433, 2808304: 434, 2808440: 435, 2814533: 436, 2814860: 437, 2815834: 438, 2817516: 439, 2823428: 440, 2823750: 441, 2825657: 442, 2834397: 443, 2835271: 444, 2837789: 445, 2840245: 446, 2841315: 447, 2843684: 448, 2859443: 449, 2860847: 450, 2865351: 451, 2869837: 452, 2870880: 453, 2871525: 454, 2877765: 455, 2879718: 456, 2883205: 457, 2892201: 458, 2892767: 459, 2894605: 460, 2895154: 461, 2906734: 462, 2909870: 463, 2910353: 464, 2916936: 465, 2917067: 466, 2927161: 467, 2930766: 468, 2939185: 469, 2948072: 470, 2950826: 471, 2951358: 472, 2951585: 473, 2963159: 474, 2965783: 475, 2966193: 476, 2966687: 477, 2971356: 478, 2974003: 479, 2977058: 480, 2978881: 481, 2979186: 482, 2980441: 483, 2981792: 484, 2988304: 485, 2992211: 486, 2992529: 487, 2999410: 488, 3000134: 489, 3000247: 490, 3000684: 491, 3014705: 492, 3016953: 493, 3017168: 494, 3018349: 495, 3026506: 496, 3028079: 497, 3032252: 498, 3041632: 499, 3042490: 500, 3045698: 501, 3047690: 502, 3062245: 503, 3063599: 504, 3063689: 505, 3065424: 506, 3075370: 507, 3085013: 508, 3089624: 509, 3095699: 510, 3100240: 511, 3109150: 512, 3110669: 513, 3124043: 514, 3124170: 515, 3125729: 516, 3126707: 517, 3127747: 518, 3127925: 519, 3131574: 520, 3133878: 521, 3134739: 522, 3141823: 523, 3146219: 524, 3160309: 525, 3179701: 526, 3180011: 527, 3187595: 528, 3188531: 529, 3196217: 530, 3197337: 531, 3201208: 532, 3207743: 533, 3207941: 534, 3208938: 535, 3216828: 536, 3218198: 537, 3220513: 538, 3223299: 539, 3240683: 540, 3249569: 541, 3250847: 542, 3255030: 543, 3259280: 544, 3271574: 545, 3272010: 546, 3272562: 547, 3290653: 548, 3291819: 549, 3297495: 550, 3314780: 551, 3325584: 552, 3337140: 553, 3344393: 554, 3345487: 555, 3347037: 556, 3355925: 557, 3372029: 558, 3376595: 559, 3379051: 560, 3384352: 561, 3388043: 562, 3388183: 563, 3388549: 564, 3393912: 565, 3394916: 566, 3400231: 567, 3404251: 568, 3417042: 569, 3424325: 570, 3425413: 571, 3443371: 572, 3444034: 573, 3445777: 574, 3445924: 575, 3447447: 576, 3447721: 577, 3450230: 578, 3452741: 579, 3457902: 580, 3459775: 581, 3461385: 582, 3467068: 583, 3476684: 584, 3476991: 585, 3478589: 586, 3481172: 587, 3482405: 588, 3483316: 589, 3485407: 590, 3485794: 591, 3492542: 592, 3494278: 593, 3495258: 594, 3496892: 595, 3498962: 596, 3527444: 597, 3529860: 598, 3530642: 599, 3532672: 600, 3534580: 601, 3535780: 602, 3538406: 603, 3544143: 604, 3584254: 605, 3584829: 606, 3590841: 607, 3594734: 608, 3594945: 609, 3595614: 610, 3598930: 611, 3599486: 612, 3602883: 613, 3617480: 614, 3623198: 615, 3627232: 616, 3630383: 617, 3633091: 618, 3637318: 619, 3642806: 620, 3649909: 621, 3657121: 622, 3658185: 623, 3661043: 624, 3662601: 625, 3666591: 626, 3670208: 627, 3673027: 628, 3676483: 629, 3680355: 630, 3690938: 631, 3691459: 632, 3692522: 633, 3697007: 634, 3706229: 635, 3709823: 636, 3710193: 637, 3710637: 638, 3710721: 639, 3717622: 640, 3720891: 641, 3721384: 642, 3724870: 643, 3729826: 644, 3733131: 645, 3733281: 646, 3733805: 647, 3742115: 648, 3743016: 649, 3759954: 650, 3761084: 651, 3763968: 652, 3764736: 653, 3769881: 654, 3770439: 655, 3770679: 656, 3773504: 657, 3775071: 658, 3775546: 659, 3776460: 660, 3777568: 661, 3777754: 662, 3781244: 663, 3782006: 664, 3785016: 665, 3786901: 666, 3787032: 667, 3788195: 668, 3788365: 669, 3791053: 670, 3792782: 671, 3792972: 672, 3793489: 673, 3794056: 674, 3796401: 675, 3803284: 676, 3804744: 677, 3814639: 678, 3814906: 679, 3825788: 680, 3832673: 681, 3837869: 682, 3838899: 683, 3840681: 
684, 3841143: 685, 3843555: 686, 3854065: 687, 3857828: 688, 3866082: 689, 3868242: 690, 3868863: 691, 3871628: 692, 3873416: 693, 3874293: 694, 3874599: 695, 3876231: 696, 3877472: 697, 3877845: 698, 3884397: 699, 3887697: 700, 3888257: 701, 3888605: 702, 3891251: 703, 3891332: 704, 3895866: 705, 3899768: 706, 3902125: 707, 3903868: 708, 3908618: 709, 3908714: 710, 3916031: 711, 3920288: 712, 3924679: 713, 3929660: 714, 3929855: 715, 3930313: 716, 3930630: 717, 3933933: 718, 3935335: 719, 3937543: 720, 3938244: 721, 3942813: 722, 3944341: 723, 3947888: 724, 3950228: 725, 3954731: 726, 3956157: 727, 3958227: 728, 3961711: 729, 3967562: 730, 3970156: 731, 3976467: 732, 3976657: 733, 3977966: 734, 3980874: 735, 3982430: 736, 3983396: 737, 3991062: 738, 3992509: 739, 3995372: 740, 3998194: 741, 4004767: 742, 4005630: 743, 4008634: 744, 4009552: 745, 4019541: 746, 4023962: 747, 4026417: 748, 4033901: 749, 4033995: 750, 4037443: 751, 4039381: 752, 4040759: 753, 4041544: 754, 4044716: 755, 4049303: 756, 4065272: 757, 4067472: 758, 4069434: 759, 4070727: 760, 4074963: 761, 4081281: 762, 4086273: 763, 4090263: 764, 4099969: 765, 4111531: 766, 4116512: 767, 4118538: 768, 4118776: 769, 4120489: 770, 4125021: 771, 4127249: 772, 4131690: 773, 4133789: 774, 4136333: 775, 4141076: 776, 4141327: 777, 4141975: 778, 4146614: 779, 4147183: 780, 4149813: 781, 4152593: 782, 4153751: 783, 4154565: 784, 4162706: 785, 4179913: 786, 4192698: 787, 4200800: 788, 4201297: 789, 4204238: 790, 4204347: 791, 4208210: 792, 4209133: 793, 4209239: 794, 4228054: 795, 4229816: 796, 4235860: 797, 4238763: 798, 4239074: 799, 4243546: 800, 4251144: 801, 4252077: 802, 4252225: 803, 4254120: 804, 4254680: 805, 4254777: 806, 4258138: 807, 4259630: 808, 4263257: 809, 4264628: 810, 4265275: 811, 4266014: 812, 4270147: 813, 4273569: 814, 4275548: 815, 4277352: 816, 4285008: 817, 4286575: 818, 4296562: 819, 4310018: 820, 4311004: 821, 4311174: 822, 4317175: 823, 4325704: 824, 4326547: 825, 4328186: 826, 4330267: 827, 4332243: 828, 4335435: 829, 4336792: 830, 4344873: 831, 4346328: 832, 4347754: 833, 4350905: 834, 4355338: 835, 4355933: 836, 4356056: 837, 4357314: 838, 4366367: 839, 4367480: 840, 4370456: 841, 4371430: 842, 4371774: 843, 4372370: 844, 4376876: 845, 4380533: 846, 4389033: 847, 4392985: 848, 4398044: 849, 4399382: 850, 4404412: 851, 4409515: 852, 4417672: 853, 4418357: 854, 4423845: 855, 4428191: 856, 4429376: 857, 4435653: 858, 4442312: 859, 4443257: 860, 4447861: 861, 4456115: 862, 4458633: 863, 4461696: 864, 4462240: 865, 4465501: 866, 4467665: 867, 4476259: 868, 4479046: 869, 4482393: 870, 4483307: 871, 4485082: 872, 4486054: 873, 4487081: 874, 4487394: 875, 4493381: 876, 4501370: 877, 4505470: 878, 4507155: 879, 4509417: 880, 4515003: 881, 4517823: 882, 4522168: 883, 4523525: 884, 4525038: 885, 4525305: 886, 4532106: 887, 4532670: 888, 4536866: 889, 4540053: 890, 4542943: 891, 4548280: 892, 4548362: 893, 4550184: 894, 4552348: 895, 4553703: 896, 4554684: 897, 4557648: 898, 4560804: 899, 4562935: 900, 4579145: 901, 4579432: 902, 4584207: 903, 4589890: 904, 4590129: 905, 4591157: 906, 4591713: 907, 4592741: 908, 4596742: 909, 4597913: 910, 4599235: 911, 4604644: 912, 4606251: 913, 4612504: 914, 4613696: 915, 6359193: 916, 6596364: 917, 6785654: 918, 6794110: 919, 6874185: 920, 7248320: 921, 7565083: 922, 7579787: 923, 7583066: 924, 7584110: 925, 7590611: 926, 7613480: 927, 7614500: 928, 7615774: 929, 7684084: 930, 7693725: 931, 7695742: 932, 7697313: 933, 7697537: 934, 7711569: 935, 7714571: 936, 7714990: 937, 
7715103: 938, 7716358: 939, 7716906: 940, 7717410: 941, 7717556: 942, 7718472: 943, 7718747: 944, 7720875: 945, 7730033: 946, 7734744: 947, 7742313: 948, 7745940: 949, 7747607: 950, 7749582: 951, 7753113: 952, 7753275: 953, 7753592: 954, 7754684: 955, 7760859: 956, 7768694: 957, 7802026: 958, 7831146: 959, 7836838: 960, 7860988: 961, 7871810: 962, 7873807: 963, 7875152: 964, 7880968: 965, 7892512: 966, 7920052: 967, 7930864: 968, 7932039: 969, 9193705: 970, 9229709: 971, 9246464: 972, 9256479: 973, 9288635: 974, 9332890: 975, 9399592: 976, 9421951: 977, 9428293: 978, 9468604: 979, 9472597: 980, 9835506: 981, 10148035: 982, 10565667: 983, 11879895: 984, 11939491: 985, 12057211: 986, 12144580: 987, 12267677: 988, 12620546: 989, 12768682: 990, 12985857: 991, 12998815: 992, 13037406: 993, 13040303: 994, 13044778: 995, 13052670: 996, 13054560: 997, 13133613: 998, 15075141: 999} diff --git a/pytorch_pretrained_gans/CIPS/GeneratorsCIPS.py b/pytorch_pretrained_gans/CIPS/GeneratorsCIPS.py new file mode 100644 index 0000000..8338198 --- /dev/null +++ b/pytorch_pretrained_gans/CIPS/GeneratorsCIPS.py @@ -0,0 +1,206 @@ +__all__ = ['CIPSskip', 'CIPSres'] + +import math + +import torch +from torch import nn +import torch.nn.functional as F + +from .blocks import ConstantInput, LFF, StyledConv, ToRGB, PixelNorm, EqualLinear, StyledResBlock + + +class CIPSskip(nn.Module): + def __init__(self, size=256, hidden_size=512, n_mlp=8, style_dim=512, lr_mlp=0.01, + activation=None, channel_multiplier=2, **kwargs): + super(CIPSskip, self).__init__() + + self.size = size + demodulate = True + self.demodulate = demodulate + self.lff = LFF(hidden_size) + self.emb = ConstantInput(hidden_size, size=size) + + self.channels = { + 0: 512, + 1: 512, + 2: 512, + 3: 512, + 4: 256 * channel_multiplier, + 5: 128 * channel_multiplier, + 6: 64 * channel_multiplier, + 7: 32 * channel_multiplier, + 8: 16 * channel_multiplier, + } + + multiplier = 2 + in_channels = int(self.channels[0]) + self.conv1 = StyledConv(int(multiplier * hidden_size), in_channels, 1, style_dim, demodulate=demodulate, + activation=activation) + + self.linears = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.log_size = int(math.log(size, 2)) + + self.n_intermediate = self.log_size - 1 + self.to_rgb_stride = 2 + for i in range(0, self.log_size - 1): + out_channels = self.channels[i] + self.linears.append(StyledConv(in_channels, out_channels, 1, style_dim, + demodulate=demodulate, activation=activation)) + self.linears.append(StyledConv(out_channels, out_channels, 1, style_dim, + demodulate=demodulate, activation=activation)) + self.to_rgbs.append(ToRGB(out_channels, style_dim, upsample=False)) + + in_channels = out_channels + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' + ) + ) + + self.style = nn.Sequential(*layers) + + def forward(self, + coords, + latent, + return_latents=False, + truncation=1, + truncation_latent=None, + input_is_latent=False, + ): + + latent = latent[0] + + if truncation < 1: + latent = truncation_latent + truncation * (latent - truncation_latent) + + if not input_is_latent: + latent = self.style(latent) + + x = self.lff(coords) + + batch_size, _, w, h = coords.shape + if self.training and w == h == self.size: + emb = self.emb(x) + else: + emb = F.grid_sample( + self.emb.input.expand(batch_size, -1, -1, -1), + coords.permute(0, 2, 3, 1).contiguous(), + padding_mode='border', mode='bilinear', + ) + + 
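        # Concatenate the learned Fourier features (x) with the sampled coordinate embedding (emb) along the channel dimension.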
x = torch.cat([x, emb], 1) + + rgb = 0 + + x = self.conv1(x, latent) + for i in range(self.n_intermediate): + for j in range(self.to_rgb_stride): + x = self.linears[i * self.to_rgb_stride + j](x, latent) + + rgb = self.to_rgbs[i](x, latent, rgb) + + if return_latents: + return rgb, latent + else: + return rgb, None + + +class CIPSres(nn.Module): + def __init__(self, size=256, hidden_size=512, n_mlp=8, style_dim=512, lr_mlp=0.01, + activation=None, channel_multiplier=2, **kwargs): + super(CIPSres, self).__init__() + + self.size = size + demodulate = True + self.demodulate = demodulate + self.lff = LFF(int(hidden_size)) + self.emb = ConstantInput(hidden_size, size=size) + + self.channels = { + 0: 512, + 1: 512, + 2: 512, + 3: 512, + 4: 256 * channel_multiplier, + 5: 128 * channel_multiplier, + 6: 64 * channel_multiplier, + 7: 64 * channel_multiplier, + 8: 32 * channel_multiplier, + } + + self.linears = nn.ModuleList() + in_channels = int(self.channels[0]) + multiplier = 2 + self.linears.append(StyledConv(int(multiplier * hidden_size), in_channels, 1, style_dim, demodulate=demodulate, + activation=activation)) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + for i in range(0, self.log_size - 1): + out_channels = self.channels[i] + self.linears.append(StyledResBlock(in_channels, out_channels, 1, style_dim, demodulate=demodulate, + activation=activation)) + in_channels = out_channels + + self.to_rgb_last = ToRGB(in_channels, style_dim, upsample=False) + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' + ) + ) + + self.style = nn.Sequential(*layers) + + def forward(self, + coords, + latent, + return_latents=False, + truncation=1, + truncation_latent=None, + input_is_latent=False, + ): + + latent = latent[0] + + if truncation < 1: + latent = truncation_latent + truncation * (latent - truncation_latent) + + if not input_is_latent: + latent = self.style(latent) + + x = self.lff(coords) + + batch_size, _, w, h = coords.shape + if self.training and w == h == self.size: + emb = self.emb(x) + else: + emb = F.grid_sample( + self.emb.input.expand(batch_size, -1, -1, -1), + coords.permute(0, 2, 3, 1).contiguous(), + padding_mode='border', mode='bilinear', + ) + out = torch.cat([x, emb], 1) + + for con in self.linears: + out = con(out, latent) + + out = self.to_rgb_last(out, latent) + + if return_latents: + return out, latent + else: + return out, None diff --git a/pytorch_pretrained_gans/CIPS/__init__.py b/pytorch_pretrained_gans/CIPS/__init__.py new file mode 100644 index 0000000..b6eb459 --- /dev/null +++ b/pytorch_pretrained_gans/CIPS/__init__.py @@ -0,0 +1,121 @@ +import torch +from typing import NamedTuple + +from .GeneratorsCIPS import CIPSskip, CIPSres + + +class Churches256Arguments(NamedTuple): + """CIPSskip for LSUN-Churches-256""" + Generator = CIPSskip + size = 256 + coords_size = 256 + fc_dim = 512 + latent = 512 + style_dim = 512 + n_mlp = 8 + activation = None + channel_multiplier = 2 + coords_integer_values = False + + +MODELS = { + # Download from https://github.com/saic-mdal/CIPS#pretrained-checkpoints + 'churches': ('/home/luke/projects/experiments/gan-seg/src/segmentation/gans/CIPS/churches_g_ema.pt', Churches256Arguments), +} + + +class GeneratorWrapper(torch.nn.Module): + """ A wrapper to put the GAN in a standard format """ + + def __init__(self, g_ema, args, truncation=0.7, device='cpu'): + super().__init__() 
+ self.G = g_ema.to(device) + self.dim_z = g_ema.style_dim + self.conditional = False + + self.truncation = truncation + self.truncation_latent = get_latent_mean(g_ema, args, device) + self.x_channel, self.y_channel = convert_to_coord_format_unbatched( + args.coords_size, args.coords_size, device, + integer_values=args.coords_integer_values) + self.coords_size = args.coords_size + + def forward(self, z): + x_channel = self.x_channel.repeat(z.size(0), 1, self.coords_size, 1).to(z.device) + y_channel = self.y_channel.repeat(z.size(0), 1, 1, self.coords_size).to(z.device) + converted_full = torch.cat((x_channel, y_channel), dim=1) + sample, _ = self.G( + coords=converted_full, + latent=[z], + return_latents=False, + truncation=self.truncation, + truncation_latent=self.truncation_latent, + input_is_latent=True) + sample = torch.clamp(sample, min=-1, max=1) # I don't know if this is needed, I think it is though + return sample + + +def convert_to_coord_format_unbatched(h, w, device='cpu', integer_values=False): + if integer_values: + x_channel = torch.arange(w, dtype=torch.float, device=device).view(1, 1, 1, -1) + y_channel = torch.arange(h, dtype=torch.float, device=device).view(1, 1, -1, 1) + else: + x_channel = torch.linspace(-1, 1, w, device=device).view(1, 1, 1, -1) + y_channel = torch.linspace(-1, 1, h, device=device).view(1, 1, -1, 1) + return (x_channel, y_channel) + + +def make_cips(model_name='churches', **kwargs) -> torch.nn.Module: + checkpoint_path, args = MODELS[model_name] + g_ema = args.Generator( + size=args.size, + hidden_size=args.fc_dim, + style_dim=args.latent, + n_mlp=args.n_mlp, + activation=args.activation, + channel_multiplier=args.channel_multiplier) + ckpt = torch.load(checkpoint_path, map_location='cpu') + g_ema.load_state_dict(ckpt) + G = GeneratorWrapper(g_ema, args, **kwargs) + return G.eval() + + +@torch.no_grad() +def get_latent_mean(g_ema, args, device): + + # Get sample input + n_sample = 1 + sample_z = torch.randn(n_sample, args.latent, device=device) + x_channel, y_channel = convert_to_coord_format_unbatched(args.coords_size, args.coords_size, device, + integer_values=args.coords_integer_values) + x_channel = x_channel.repeat(sample_z.size(0), 1, args.coords_size, 1).to(device) + y_channel = y_channel.repeat(sample_z.size(0), 1, 1, args.coords_size).to(device) + converted_full = torch.cat((x_channel, y_channel), dim=1) + + # Generate a bunch of times and + latents = [] + samples = [] + for _ in range(100): + sample_z = torch.randn(n_sample, args.latent, device=device) + sample, latent = g_ema(converted_full, [sample_z], return_latents=True) + latents.append(latent.cpu()) + samples.append(sample.cpu()) + samples = torch.cat(samples, 0) + latents = torch.cat(latents, 0) + truncation_latent = latents.mean(0).cuda() + assert len(truncation_latent.shape) == 1 and truncation_latent.size(0) == 512, 'smt wrong' + return truncation_latent + + +if __name__ == '__main__': + # Testing + device = torch.device('cuda') + G = make_cips(device=device) + print('Created G') + print(f'Params: {sum(p.numel() for p in G.parameters()):_}') + z = torch.randn([1, G.dim_z]).to(device) + print(f'z.shape: {z.shape}') + x = G(z) + print(f'x.shape: {x.shape}') + import pdb + pdb.set_trace() diff --git a/pytorch_pretrained_gans/CIPS/blocks.py b/pytorch_pretrained_gans/CIPS/blocks.py new file mode 100644 index 0000000..8ef1a3a --- /dev/null +++ b/pytorch_pretrained_gans/CIPS/blocks.py @@ -0,0 +1,587 @@ +import math +import numpy as np + +import torch +from torch import nn +from torch.nn 
import functional as F + +from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + return out + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + return out + + +class EqualConv2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size) + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out = F.leaky_relu(input, negative_slope=self.negative_slope) + return out * math.sqrt(2) + + 
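# Aside: a minimal sketch (not part of this patch) of the "equalized learning rate"
# trick that the EqualConv2d / EqualLinear layers above implement: weights are kept
# at unit variance and rescaled by 1/sqrt(fan_in) on every forward pass, so all
# layers receive gradients of comparable scale. `DemoEqualLinear` is a hypothetical
# name used only for illustration.
import math

import torch
import torch.nn.functional as F
from torch import nn


class DemoEqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim, lr_mul=1.0):
        super().__init__()
        # Unit-variance initialisation; the per-layer scale is applied at run time instead of at init.
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
        self.bias = nn.Parameter(torch.zeros(out_dim))
        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, x):
        return F.linear(x, self.weight * self.scale, self.bias * self.lr_mul)


# e.g. DemoEqualLinear(512, 512)(torch.randn(4, 512)) returns a tensor of shape (4, 512).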
+class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) + ) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})' + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + 
blur_kernel=[1, 3, 3, 1], + demodulate=True, + activation=None, + downsample=False, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + downsample=downsample, + ) + + self.activation = activation + self.noise = NoiseInjection() + if activation == 'sinrelu': + self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + self.activate = ScaledLeakyReLUSin() + elif activation == 'sin': + self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + self.activate = SinActivation() + else: + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + if self.activation == 'sinrelu' or self.activation == 'sin': + out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.upsample = upsample + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + if self.upsample: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class EqualConvTranspose2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(in_channel, out_channel, kernel_size, kernel_size) + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv_transpose2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]}," + f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" + ) + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + upsample=False, + padding="zero", + ): + layers = [] + + self.padding = 0 + stride = 1 + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + + if upsample: + layers.append( + EqualConvTranspose2d( + in_channel, + out_channel, + kernel_size, + padding=0, + stride=2, + bias=bias and not activate, + ) + ) + + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + else: + if not downsample: + if padding == "zero": + self.padding = (kernel_size - 1) // 2 + + elif padding == "reflect": + padding = (kernel_size - 1) // 2 + + if padding > 0: + layers.append(nn.ReflectionPad2d(padding)) + + self.padding = 0 + + elif padding != "valid": + raise ValueError('Padding should be "zero", "reflect", or "valid"') + + layers.append( + EqualConv2d( + in_channel, + out_channel, + 
kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], kernel_size=3, downsample=True): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, kernel_size) + self.conv2 = ConvLayer(in_channel, out_channel, kernel_size, downsample=downsample) + + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False + ) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class ConLinear(nn.Module): + def __init__(self, ch_in, ch_out, is_first=False, bias=True): + super(ConLinear, self).__init__() + self.conv = nn.Conv2d(ch_in, ch_out, kernel_size=1, padding=0, bias=bias) + if is_first: + nn.init.uniform_(self.conv.weight, -np.sqrt(9 / ch_in), np.sqrt(9 / ch_in)) + else: + nn.init.uniform_(self.conv.weight, -np.sqrt(3 / ch_in), np.sqrt(3 / ch_in)) + + def forward(self, x): + return self.conv(x) + + +class SinActivation(nn.Module): + def __init__(self,): + super(SinActivation, self).__init__() + + def forward(self, x): + return torch.sin(x) + + +class LFF(nn.Module): + def __init__(self, hidden_size, ): + super(LFF, self).__init__() + self.ffm = ConLinear(2, hidden_size, is_first=True) + self.activation = SinActivation() + + def forward(self, x): + x = self.ffm(x) + x = self.activation(x) + return x + + +class ScaledLeakyReLUSin(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out_lr = F.leaky_relu(input[:, ::2], negative_slope=self.negative_slope) + out_sin = torch.sin(input[:, 1::2]) + out = torch.cat([out_lr, out_sin], 1) + return out * math.sqrt(2) + + +class StyledResBlock(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, style_dim, blur_kernel=[1, 3, 3, 1], demodulate=True, + activation=None, upsample=False, downsample=False): + super().__init__() + + self.conv1 = StyledConv(in_channel, out_channel, kernel_size, style_dim, + demodulate=demodulate, activation=activation) + self.conv2 = StyledConv(out_channel, out_channel, kernel_size, style_dim, + demodulate=demodulate, activation=activation, + upsample=upsample, downsample=downsample) + + if downsample or in_channel != out_channel or upsample: + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False, upsample=upsample, + ) + else: + self.skip = None + + def forward(self, input, latent): + out = self.conv1(input, latent) + out = self.conv2(out, latent) + + if self.skip is not None: + skip = self.skip(input) + else: + skip = input + + out = (out + skip) / math.sqrt(2) + + return out diff --git a/pytorch_pretrained_gans/CIPS/op/__init__.py b/pytorch_pretrained_gans/CIPS/op/__init__.py new file mode 100644 index 0000000..d0918d9 --- /dev/null +++ b/pytorch_pretrained_gans/CIPS/op/__init__.py @@ -0,0 +1,2 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu +from .upfirdn2d import upfirdn2d diff --git a/pytorch_pretrained_gans/CIPS/op/fused_act.py b/pytorch_pretrained_gans/CIPS/op/fused_act.py new file mode 100644 index 0000000..39c5bfe --- /dev/null +++ b/pytorch_pretrained_gans/CIPS/op/fused_act.py @@ -0,0 +1,86 @@ +import 
os + +import torch +from torch import nn +from torch.autograd import Function +from torch.utils.cpp_extension import load + + +module_path = os.path.dirname(__file__) +fused = load( + 'fused', + sources=[ + os.path.join(module_path, 'fused_bias_act.cpp'), + os.path.join(module_path, 'fused_bias_act_kernel.cu'), + ], +) + + +class FusedLeakyReLUFunctionBackward(Function): + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused.fused_bias_act( + grad_output, empty, out, 3, 1, negative_slope, scale + ) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused.fused_bias_act( + gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale + ) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( + grad_output, out, ctx.negative_slope, ctx.scale + ) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/pytorch_pretrained_gans/CIPS/op/fused_bias_act.cpp b/pytorch_pretrained_gans/CIPS/op/fused_bias_act.cpp new file mode 100644 index 0000000..a054318 --- /dev/null +++ b/pytorch_pretrained_gans/CIPS/op/fused_bias_act.cpp @@ -0,0 +1,21 @@ +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} \ No newline at end of file diff --git a/pytorch_pretrained_gans/CIPS/op/fused_bias_act_kernel.cu b/pytorch_pretrained_gans/CIPS/op/fused_bias_act_kernel.cu new file mode 100644 index 0000000..8d2f03c --- /dev/null +++ b/pytorch_pretrained_gans/CIPS/op/fused_bias_act_kernel.cu @@ -0,0 +1,99 @@ +// Copyright (c) 2019, NVIDIA Corporation. 
All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} \ No newline at end of file diff --git a/pytorch_pretrained_gans/CIPS/op/upfirdn2d.cpp b/pytorch_pretrained_gans/CIPS/op/upfirdn2d.cpp new file mode 100644 index 0000000..b07aa20 --- /dev/null +++ b/pytorch_pretrained_gans/CIPS/op/upfirdn2d.cpp @@ -0,0 +1,23 @@ +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} \ No newline at end of file diff --git a/pytorch_pretrained_gans/CIPS/op/upfirdn2d.py b/pytorch_pretrained_gans/CIPS/op/upfirdn2d.py new file mode 100644 index 0000000..d4e2fdc --- /dev/null +++ 
b/pytorch_pretrained_gans/CIPS/op/upfirdn2d.py @@ -0,0 +1,187 @@ +import os + +import torch +from torch.autograd import Function +from torch.utils.cpp_extension import load + + +module_path = os.path.dirname(__file__) +upfirdn2d_op = load( + 'upfirdn2d', + sources=[ + os.path.join(module_path, 'upfirdn2d.cpp'), + os.path.join(module_path, 'upfirdn2d_kernel.cu'), + ], +) + + +class UpFirDn2dBackward(Function): + @staticmethod + def forward( + ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size + ): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_op.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_op.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view( + ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] + ) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_op.upfirdn2d( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 + ) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + out = UpFirDn2d.apply( + input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) + ) + + return out + + +def upfirdn2d_native( + input, kernel, up_x, 
up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + + return out[:, ::down_y, ::down_x, :] + diff --git a/pytorch_pretrained_gans/CIPS/op/upfirdn2d_kernel.cu b/pytorch_pretrained_gans/CIPS/op/upfirdn2d_kernel.cu new file mode 100644 index 0000000..871d4fe --- /dev/null +++ b/pytorch_pretrained_gans/CIPS/op/upfirdn2d_kernel.cu @@ -0,0 +1,272 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + + +template +__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + 
__syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + + #pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) + #pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; + } + } + } + } +} + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; + + auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h; + int tile_out_w; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 
&& tile_out_w) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 2: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 3: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 4: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 5: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 6: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + } + }); + + return out; +} \ No newline at end of file diff --git a/pytorch_pretrained_gans/StudioGAN/__init__.py b/pytorch_pretrained_gans/StudioGAN/__init__.py new file mode 100644 index 0000000..f390307 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/__init__.py @@ -0,0 +1,138 @@ +import os +from os.path import join +from pathlib import Path +from omegaconf.omegaconf import OmegaConf +import numpy as np +import torch +import json +from collections.abc import MutableMapping + +from .models import resnet +from .models import big_resnet +from .models import big_resnet_deep + + +# Download here: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN#imagenet-3x128x128 +ROOT = Path(__file__).parent +ARCHS = { + 'resnet': resnet, + 'big_resnet': big_resnet, + 'big_resnet_deep': big_resnet_deep, +} + + +class Config(object): + def __init__(self, dict_): + self.__dict__.update(dict_) + + +def flatten(d): + items = [] + for k, v in d.items(): + if isinstance(v, MutableMapping): + items.extend(flatten(v).items()) + else: + items.append((k, v)) + return dict(items) + + +def truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None): + """ Create a truncated noise vector. + Params: + batch_size: batch size. 
+ dim_z: dimension of z + truncation: truncation value to use + seed: seed for the random generator + Output: + array of shape (batch_size, dim_z) + """ + from scipy.stats import truncnorm + state = None if seed is None else np.random.RandomState(seed) + values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32) + return truncation * values + + +class GeneratorWrapper(torch.nn.Module): + """ A wrapper to put the GAN in a standard format """ + + def __init__(self, Gen, cfgs): + super().__init__() + self.G = Gen + self.dim_z = Gen.z_dim + self.conditional = True + self.num_classes = cfgs.num_classes + + self.truncation = 1.0 + + def forward(self, z, y=None, return_y=False): + if y is not None: + # the model is conditional and the user gives us a class + y = y.to(z.device) + elif self.num_classes is not None: + # the model is conditional but the user does not give us a class + y = self.sample_class(batch_size=z.shape[0], device=z.device) + else: + # the model is unconditional + y = None + x = self.G(z, label=y, evaluation=True) + x = torch.clamp(x, min=-1, max=1) # this shouldn't really be necessary + return (x, y) if return_y else x + + def sample_latent(self, batch_size=None, device='cpu'): + z = torch.randn((batch_size, self.dim_z), device=device) + # z = truncated_noise_sample(truncation=self.truncation, batch_size=batch_size) + # z = torch.from_numpy(z).to(device) + return z + + def sample_class(self, batch_size=None, device='cpu'): + y = torch.randint(low=0, high=self.num_classes, size=(batch_size,), device=device) + return y + + +def get_config_and_checkpoint(root): + paths = list(map(str, root.iterdir())) + checkpoint_path = [p for p in paths if '.pth' in p] + config_path = [p for p in paths if '.json' in p] + assert len(checkpoint_path) == 1, f'no checkpoint found in {root}' + assert len(config_path) == 1, f'no config found in {root}' + checkpoint_path = checkpoint_path[0] + config_path = config_path[0] + with open(config_path) as f: + cfgs = json.load(f) + cfgs = Config(flatten(cfgs)) + cfgs.mixed_precision = False + return cfgs, checkpoint_path + + +def make_studiogan(model_name='SAGAN', dataset='ImageNet') -> torch.nn.Module: + + # Get configs and model checkpoint path + cfgs, checkpoint_path = get_config_and_checkpoint(ROOT / 'configs' / dataset / model_name) + + # From: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN/blob/master/src/loader.py#L90 + Generator = ARCHS[cfgs.architecture].Generator + Gen = Generator( + cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm, cfgs.attention, + cfgs.attention_after_nth_gen_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes, + cfgs.g_init, cfgs.G_depth, cfgs.mixed_precision) + + # Checkpoint + checkpoint = torch.load(checkpoint_path, map_location='cpu') + Gen.load_state_dict(checkpoint['state_dict']) + + # Wrap + G = GeneratorWrapper(Gen, cfgs) + return G.eval() + + +if __name__ == '__main__': + # Testing + device = 'cuda' + G = make_studiogan('BigGAN2048').to(device) + print('Created G') + print(f'Params: {sum(p.numel() for p in G.parameters()):_}') + z = torch.randn([1, G.dim_z]).to(device) + print(f'z.shape: {z.shape}') + x = G(z) + print(f'x.shape: {x.shape}') + print(x.max(), x.min()) diff --git a/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/BigGAN2048/BigGAN2048.json b/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/BigGAN2048/BigGAN2048.json new file mode 100644 index 0000000..c6f4e24 --- /dev/null +++ 
b/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/BigGAN2048/BigGAN2048.json @@ -0,0 +1,112 @@ +{ + "data_processing":{ + "dataset_name": "imagenet", + "data_path": "./data/ILSVRC2012", + "img_size": 128, + "num_classes": 1000, + "batch_size4prcsing": 256, + "chunk_size": 500, + "compression": false + }, + + "train": { + "model": { + "architecture": "big_resnet", + "conditional_strategy": "ProjGAN", + "pos_collected_numerator":false, + "hypersphere_dim": "N/A", + "nonlinear_embed": false, + "normalize_embed": false, + "g_spectral_norm": true, + "d_spectral_norm": true, + "activation_fn": "ReLU", + "attention": true, + "attention_after_nth_gen_block": 4, + "attention_after_nth_dis_block": 1, + "z_dim": 120, + "shared_dim": 128, + "g_conv_dim": 96, + "d_conv_dim": 96, + "G_depth": "N/A", + "D_depth": "N/A" + }, + + "optimization": { + "optimizer": "Adam", + "batch_size": 256, + "accumulation_steps": 8, + "d_lr": 0.0002, + "g_lr": 0.00005, + "momentum": "N/A", + "nesterov": "N/A", + "alpha": "N/A", + "beta1": 0.0, + "beta2": 0.999, + "g_steps_per_iter": 1, + "d_steps_per_iter": 2, + "total_step": 200000 + }, + + "loss_function": { + "adv_loss": "hinge", + + "contrastive_lambda": "N/A", + "margin": "N/A", + "tempering_type": "N/A", + "tempering_step": "N/A", + "start_temperature": "N/A", + "end_temperature": "N/A", + + "weight_clipping_for_dis": false, + "weight_clipping_bound": "N/A", + + "gradient_penalty_for_dis": false, + "gradient_penalty_lambda": "N/A", + + "deep_regret_analysis_for_dis": false, + "regret_penalty_lambda": "N/A", + + "cr": false, + "cr_lambda":"N/A", + + "bcr": false, + "real_lambda": "N/A", + "fake_lambda": "N/A", + + "zcr": false, + "gen_lambda": "N/A", + "dis_lambda": "N/A", + "sigma_noise": "N/A" + }, + + "initialization":{ + "g_init": "ortho", + "d_init": "ortho" + }, + + "training_and_sampling_setting":{ + "random_flip_preprocessing": true, + "diff_aug":false, + + "ada": false, + "ada_target": "N/A", + "ada_length": "N/A", + + "prior": "gaussian", + "truncated_factor": 1, + + "latent_op": false, + "latent_op_rate":"N/A", + "latent_op_step":"N/A", + "latent_op_step4eval":"N/A", + "latent_op_alpha":"N/A", + "latent_op_beta":"N/A", + "latent_norm_reg_weight":"N/A", + + "ema": true, + "ema_decay": 0.9999, + "ema_start": 20000 + } + } + } + \ No newline at end of file diff --git a/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/BigGAN256/BigGAN256.json b/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/BigGAN256/BigGAN256.json new file mode 100644 index 0000000..f8659b4 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/BigGAN256/BigGAN256.json @@ -0,0 +1,112 @@ +{ + "data_processing":{ + "dataset_name": "imagenet", + "data_path": "./data/ILSVRC2012", + "img_size": 128, + "num_classes": 1000, + "batch_size4prcsing": 256, + "chunk_size": 500, + "compression": false + }, + + "train": { + "model": { + "architecture": "big_resnet", + "conditional_strategy": "ProjGAN", + "pos_collected_numerator":false, + "hypersphere_dim": "N/A", + "nonlinear_embed": false, + "normalize_embed": false, + "g_spectral_norm": true, + "d_spectral_norm": true, + "activation_fn": "ReLU", + "attention": true, + "attention_after_nth_gen_block": 4, + "attention_after_nth_dis_block": 1, + "z_dim": 120, + "shared_dim": 128, + "g_conv_dim": 96, + "d_conv_dim": 96, + "G_depth": "N/A", + "D_depth": "N/A" + }, + + "optimization": { + "optimizer": "Adam", + "batch_size": 256, + "accumulation_steps": 1, + "d_lr": 0.0002, + "g_lr": 0.00005, + "momentum": "N/A", + "nesterov": "N/A", 
+ "alpha": "N/A", + "beta1": 0.0, + "beta2": 0.999, + "g_steps_per_iter": 1, + "d_steps_per_iter": 2, + "total_step": 200000 + }, + + "loss_function": { + "adv_loss": "hinge", + + "contrastive_lambda": "N/A", + "margin": "N/A", + "tempering_type": "N/A", + "tempering_step": "N/A", + "start_temperature": "N/A", + "end_temperature": "N/A", + + "weight_clipping_for_dis": false, + "weight_clipping_bound": "N/A", + + "gradient_penalty_for_dis": false, + "gradient_penalty_lambda": "N/A", + + "deep_regret_analysis_for_dis": false, + "regret_penalty_lambda": "N/A", + + "cr": false, + "cr_lambda":"N/A", + + "bcr": false, + "real_lambda": "N/A", + "fake_lambda": "N/A", + + "zcr": false, + "gen_lambda": "N/A", + "dis_lambda": "N/A", + "sigma_noise": "N/A" + }, + + "initialization":{ + "g_init": "ortho", + "d_init": "ortho" + }, + + "training_and_sampling_setting":{ + "random_flip_preprocessing": true, + "diff_aug":false, + + "ada": false, + "ada_target": "N/A", + "ada_length": "N/A", + + "prior": "gaussian", + "truncated_factor": 1, + + "latent_op": false, + "latent_op_rate":"N/A", + "latent_op_step":"N/A", + "latent_op_step4eval":"N/A", + "latent_op_alpha":"N/A", + "latent_op_beta":"N/A", + "latent_norm_reg_weight":"N/A", + + "ema": true, + "ema_decay": 0.9999, + "ema_start": 20000 + } + } + } + diff --git a/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/ContraGAN2048/ContraGAN2048.json b/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/ContraGAN2048/ContraGAN2048.json new file mode 100644 index 0000000..0ae8d36 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/ContraGAN2048/ContraGAN2048.json @@ -0,0 +1,111 @@ +{ + "data_processing":{ + "dataset_name": "imagenet", + "data_path": "./data/ILSVRC2012", + "img_size": 128, + "num_classes": 1000, + "batch_size4prcsing": 256, + "chunk_size": 500, + "compression": false + }, + + "train": { + "model": { + "architecture": "big_resnet", + "conditional_strategy": "ContraGAN", + "pos_collected_numerator": false, + "hypersphere_dim": 1536, + "nonlinear_embed": false, + "normalize_embed": true, + "g_spectral_norm": true, + "d_spectral_norm": true, + "activation_fn": "ReLU", + "attention": true, + "attention_after_nth_gen_block": 4, + "attention_after_nth_dis_block": 1, + "z_dim": 120, + "shared_dim": 128, + "g_conv_dim": 96, + "d_conv_dim": 96, + "G_depth":"N/A", + "D_depth":"N/A" + }, + + "optimization": { + "optimizer": "Adam", + "batch_size": 256, + "accumulation_steps": 8, + "d_lr": 0.0002, + "g_lr": 0.00005, + "momentum": "N/A", + "nesterov": "N/A", + "alpha": "N/A", + "beta1": 0.0, + "beta2": 0.999, + "g_steps_per_iter": 1, + "d_steps_per_iter": 2, + "total_step": 200000 + }, + + "loss_function": { + "adv_loss": "hinge", + + "contrastive_lambda": 1.0, + "margin": 0.0, + "tempering_type": "constant", + "tempering_step": "N/A", + "start_temperature": 1.0, + "end_temperature": 1.0, + + "weight_clipping_for_dis": false, + "weight_clipping_bound": "N/A", + + "gradient_penalty_for_dis": false, + "gradient_penalty_lambda": "N/A", + + "deep_regret_analysis_for_dis": false, + "regret_penalty_lambda": "N/A", + + "cr": false, + "cr_lambda":"N/A", + + "bcr": false, + "real_lambda": "N/A", + "fake_lambda": "N/A", + + "zcr": false, + "gen_lambda": "N/A", + "dis_lambda": "N/A", + "sigma_noise": "N/A" + }, + + "initialization":{ + "g_init": "ortho", + "d_init": "ortho" + }, + + "training_and_sampling_setting":{ + "random_flip_preprocessing": true, + "diff_aug": false, + + "ada": false, + "ada_target": "N/A", + "ada_length": "N/A", + + "prior": 
"gaussian", + "truncated_factor": 1, + + "latent_op": false, + "latent_op_rate":"N/A", + "latent_op_step":"N/A", + "latent_op_step4eval":"N/A", + "latent_op_alpha":"N/A", + "latent_op_beta":"N/A", + "latent_norm_reg_weight":"N/A", + + "ema": true, + "ema_decay": 0.9999, + "ema_start": 20000 + } + } + } diff --git a/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/ContraGAN256/ContraGAN256.json b/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/ContraGAN256/ContraGAN256.json new file mode 100644 index 0000000..d2f55e3 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/ContraGAN256/ContraGAN256.json @@ -0,0 +1,112 @@ + +{ + "data_processing":{ + "dataset_name": "imagenet", + "data_path": "./data/ILSVRC2012", + "img_size": 128, + "num_classes": 1000, + "batch_size4prcsing": 256, + "chunk_size": 500, + "compression": false + }, + + "train": { + "model": { + "architecture": "big_resnet", + "conditional_strategy": "ContraGAN", + "pos_collected_numerator": false, + "hypersphere_dim": 1536, + "nonlinear_embed": false, + "normalize_embed": true, + "g_spectral_norm": true, + "d_spectral_norm": true, + "activation_fn": "ReLU", + "attention": true, + "attention_after_nth_gen_block": 4, + "attention_after_nth_dis_block": 1, + "z_dim": 120, + "shared_dim": 128, + "g_conv_dim": 96, + "d_conv_dim": 96, + "G_depth":"N/A", + "D_depth":"N/A" + }, + + "optimization": { + "optimizer": "Adam", + "batch_size": 256, + "accumulation_steps": 1, + "d_lr": 0.0002, + "g_lr": 0.00005, + "momentum": "N/A", + "nesterov": "N/A", + "alpha": "N/A", + "beta1": 0.0, + "beta2": 0.999, + "g_steps_per_iter": 1, + "d_steps_per_iter": 2, + "total_step": 200000 + }, + + "loss_function": { + "adv_loss": "hinge", + + "contrastive_lambda": 1.0, + "margin": 0.0, + "tempering_type": "constant", + "tempering_step": "N/A", + "start_temperature": 1.0, + "end_temperature": 1.0, + + "weight_clipping_for_dis": false, + "weight_clipping_bound": "N/A", + + "gradient_penalty_for_dis": false, + "gradient_penalty_lambda": "N/A", + + "deep_regret_analysis_for_dis": false, + "regret_penalty_lambda": "N/A", + + "cr": false, + "cr_lambda":"N/A", + + "bcr": false, + "real_lambda": "N/A", + "fake_lambda": "N/A", + + "zcr": false, + "gen_lambda": "N/A", + "dis_lambda": "N/A", + "sigma_noise": "N/A" + }, + + "initialization":{ + "g_init": "ortho", + "d_init": "ortho" + }, + + "training_and_sampling_setting":{ + "random_flip_preprocessing": true, + "diff_aug": false, + + "ada": false, + "ada_target": "N/A", + "ada_length": "N/A", + + "prior": "gaussian", + "truncated_factor": 1, + + "latent_op": false, + "latent_op_rate":"N/A", + "latent_op_step":"N/A", + "latent_op_step4eval":"N/A", + "latent_op_alpha":"N/A", + "latent_op_beta":"N/A", + "latent_norm_reg_weight":"N/A", + + "ema": true, + "ema_decay": 0.9999, + "ema_start": 20000 + } + } + } diff --git a/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/SAGAN/SAGAN.json b/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/SAGAN/SAGAN.json new file mode 100644 index 0000000..e26fa58 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/SAGAN/SAGAN.json @@ -0,0 +1,111 @@ +{ + "data_processing":{ + "dataset_name": "imagenet", + "data_path": "./data/ILSVRC2012", + "img_size": 128, + "num_classes": 1000, + "batch_size4prcsing": 256, + "chunk_size": 500, + "compression": false + }, + + "train": { + "model": { + "architecture": "resnet", + "conditional_strategy": "ProjGAN", + "pos_collected_numerator": false, + "hypersphere_dim": "N/A", + "nonlinear_embed": false, + 
"normalize_embed": false, + "g_spectral_norm": true, + "d_spectral_norm": true, + "activation_fn": "ReLU", + "attention": true, + "attention_after_nth_gen_block": 4, + "attention_after_nth_dis_block": 1, + "z_dim": 128, + "shared_dim": "N/A", + "g_conv_dim": 64, + "d_conv_dim": 64, + "G_depth": "N/A", + "D_depth": "N/A" + }, + + "optimization": { + "optimizer": "Adam", + "batch_size": 256, + "accumulation_steps": 1, + "d_lr": 0.0004, + "g_lr": 0.0001, + "momentum": "N/A", + "nesterov": "N/A", + "alpha": "N/A", + "beta1": 0.0, + "beta2": 0.999, + "g_steps_per_iter": 1, + "d_steps_per_iter": 1, + "total_step": 1000000 + }, + + "loss_function": { + "adv_loss": "hinge", + + "contrastive_lambda": "N/A", + "margin": "N/A", + "tempering_type": "N/A", + "tempering_step": "N/A", + "start_temperature": "N/A", + "end_temperature": "N/A", + + "weight_clipping_for_dis": false, + "weight_clipping_bound": "N/A", + + "gradient_penalty_for_dis": false, + "gradient_penalty_lambda": "N/A", + + "deep_regret_analysis_for_dis": false, + "regret_penalty_lambda": "N/A", + + "cr": false, + "cr_lambda": "N/A", + + "bcr": false, + "real_lambda": "N/A", + "fake_lambda": "N/A", + + "zcr": false, + "gen_lambda": "N/A", + "dis_lambda": "N/A", + "sigma_noise": "N/A" + }, + + "initialization":{ + "g_init": "ortho", + "d_init": "ortho" + }, + + "training_and_sampling_setting":{ + "random_flip_preprocessing": true, + "diff_aug": false, + + "ada": false, + "ada_target": "N/A", + "ada_length": "N/A", + + "prior": "gaussian", + "truncated_factor": 1, + + "latent_op": false, + "latent_op_rate":"N/A", + "latent_op_step":"N/A", + "latent_op_step4eval":"N/A", + "latent_op_alpha":"N/A", + "latent_op_beta":"N/A", + "latent_norm_reg_weight":"N/A", + + "ema": false, + "ema_decay": "N/A", + "ema_start": "N/A" + } + } +} diff --git a/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/SNGAN/SNGAN.json b/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/SNGAN/SNGAN.json new file mode 100644 index 0000000..4e49160 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/SNGAN/SNGAN.json @@ -0,0 +1,111 @@ +{ + "data_processing":{ + "dataset_name": "imagenet", + "data_path": "./data/ILSVRC2012", + "img_size": 128, + "num_classes": 1000, + "batch_size4prcsing": 256, + "chunk_size": 500, + "compression": false + }, + + "train": { + "model": { + "architecture": "resnet", + "conditional_strategy": "ProjGAN", + "pos_collected_numerator": false, + "hypersphere_dim": "N/A", + "nonlinear_embed": false, + "normalize_embed": false, + "g_spectral_norm": false, + "d_spectral_norm": true, + "activation_fn": "ReLU", + "attention": false, + "attention_after_nth_gen_block": "N/A", + "attention_after_nth_dis_block": "N/A", + "z_dim": 128, + "shared_dim": "N/A", + "g_conv_dim": 64, + "d_conv_dim": 64, + "G_depth": "N/A", + "D_depth": "N/A" + }, + + "optimization": { + "optimizer": "Adam", + "batch_size": 256, + "accumulation_steps": 1, + "d_lr": 0.0002, + "g_lr": 0.00005, + "momentum": "N/A", + "nesterov": "N/A", + "alpha": "N/A", + "beta1": 0.0, + "beta2": 0.999, + "g_steps_per_iter": 1, + "d_steps_per_iter": 2, + "total_step": 500000 + }, + + "loss_function": { + "adv_loss": "hinge", + + "contrastive_lambda": "N/A", + "margin": "N/A", + "tempering_type": "N/A", + "tempering_step": "N/A", + "start_temperature": "N/A", + "end_temperature": "N/A", + + "weight_clipping_for_dis": false, + "weight_clipping_bound": "N/A", + + "gradient_penalty_for_dis": false, + "gradient_penalty_lambda": "N/A", + + "deep_regret_analysis_for_dis": false, + 
"regret_penalty_lambda": "N/A", + + "cr": false, + "cr_lambda": "N/A", + + "bcr": false, + "real_lambda": "N/A", + "fake_lambda": "N/A", + + "zcr": false, + "gen_lambda": "N/A", + "dis_lambda": "N/A", + "sigma_noise": "N/A" + }, + + "initialization":{ + "g_init": "ortho", + "d_init": "ortho" + }, + + "training_and_sampling_setting":{ + "random_flip_preprocessing": true, + "diff_aug": false, + + "ada": false, + "ada_target": "N/A", + "ada_length": "N/A", + + "prior": "gaussian", + "truncated_factor": 1, + + "latent_op": false, + "latent_op_rate":"N/A", + "latent_op_step":"N/A", + "latent_op_step4eval":"N/A", + "latent_op_alpha":"N/A", + "latent_op_beta":"N/A", + "latent_norm_reg_weight":"N/A", + + "ema": false, + "ema_decay": "N/A", + "ema_start": "N/A" + } + } +} diff --git a/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/download.sh b/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/download.sh new file mode 100644 index 0000000..848898c --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/configs/ImageNet/download.sh @@ -0,0 +1,23 @@ +# SAGAN +# https://drive.google.com/file/d/1exrZloM2bHYJyU5_v9XUMlNcy5TA6gMN/view?usp=sharing +cd SAGAN +gdown --id 1_RTYZ0RXbVLWufE7bbWPvp8n_QJbA8K0 +cd .. + +# SNGAN +# https://drive.google.com/file/d/1L4Jk9v_vRojdj9ZpLBak8OxoahdsX5yn/view?usp=sharing +cd SNGAN +gdown --id 1L4Jk9v_vRojdj9ZpLBak8OxoahdsX5yn +cd .. + +# BigGAN2048 +# https://drive.google.com/file/d/14VIJUsYcItZrfNk_PcjglNXH_sPyd504/view?usp=sharing +cd BigGAN2048 +gdown --id 16tZIHrXFYFM6mXmEF-4YA7vO1D-s7meq +cd .. + +# ContraGAN256 +# https://drive.google.com/file/d/15ipVwbQpncc678tGdT7VsDcFCstUpP1n/view?usp=sharing +cd ContraGAN256 +gdown --id 15ipVwbQpncc678tGdT7VsDcFCstUpP1n +cd .. diff --git a/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/ACGAN/ACGAN.json b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/ACGAN/ACGAN.json new file mode 100644 index 0000000..ab2e260 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/ACGAN/ACGAN.json @@ -0,0 +1,112 @@ +{ + "data_processing":{ + "dataset_name": "tiny_imagenet", + "data_path": "./data/TINY_ILSVRC2012", + "img_size": 64, + "num_classes": 200, + "batch_size4prcsing": 256, + "chunk_size": 500, + "compression": false + }, + + "train": { + "model": { + "architecture": "resnet", + "conditional_strategy": "ACGAN", + "pos_collected_numerator": false, + "hypersphere_dim": "N/A", + "nonlinear_embed": false, + "normalize_embed": false, + "g_spectral_norm": false, + "d_spectral_norm": false, + "activation_fn": "ReLU", + "attention": false, + "attention_after_nth_gen_block": "N/A", + "attention_after_nth_dis_block": "N/A", + "z_dim": 128, + "shared_dim": "N/A", + "g_conv_dim": 64, + "d_conv_dim": 64, + "G_depth": "N/A", + "D_depth": "N/A" + }, + + "optimization": { + "optimizer": "Adam", + "batch_size": 256, + "accumulation_steps": 1, + "d_lr": 0.0004, + "g_lr": 0.0001, + "momentum": "N/A", + "nesterov": "N/A", + "alpha": "N/A", + "beta1": 0.0, + "beta2": 0.999, + "g_steps_per_iter": 1, + "d_steps_per_iter": 1, + "total_step": 100000 + }, + + "loss_function": { + "adv_loss": "hinge", + + "contrastive_lambda": "N/A", + "margin": "N/A", + "tempering_type": "N/A", + "tempering_step": "N/A", + "start_temperature": "N/A", + "end_temperature": "N/A", + + "weight_clipping_for_dis": false, + "weight_clipping_bound": "N/A", + + "gradient_penalty_for_dis": false, + "gradient_penalty_lambda": "N/A", + + "deep_regret_analysis_for_dis": false, + "regret_penalty_lambda": "N/A", + + "cr": false, + "cr_lambda": "N/A", + + 
"bcr": false, + "real_lambda": "N/A", + "fake_lambda": "N/A", + + "zcr": false, + "gen_lambda": "N/A", + "dis_lambda": "N/A", + "sigma_noise": "N/A" + }, + + "initialization":{ + "g_init": "ortho", + "d_init": "ortho" + }, + + "training_and_sampling_setting":{ + "random_flip_preprocessing": true, + "diff_aug": false, + + "ada": false, + "ada_target": "N/A", + "ada_length": "N/A", + + "prior": "gaussian", + "truncated_factor": 1, + + "latent_op": false, + "latent_op_rate":"N/A", + "latent_op_step":"N/A", + "latent_op_step4eval":"N/A", + "latent_op_alpha":"N/A", + "latent_op_beta":"N/A", + "latent_norm_reg_weight":"N/A", + + "ema": false, + "ema_decay": "N/A", + "ema_start": "N/A" + } + } + } + \ No newline at end of file diff --git a/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/BigGAN/BigGAN.json b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/BigGAN/BigGAN.json new file mode 100644 index 0000000..16083c7 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/BigGAN/BigGAN.json @@ -0,0 +1,111 @@ +{ + "data_processing":{ + "dataset_name": "tiny_imagenet", + "data_path": "./data/TINY_ILSVRC2012", + "img_size": 64, + "num_classes": 200, + "batch_size4prcsing": 256, + "chunk_size": 500, + "compression": false + }, + + "train": { + "model": { + "architecture": "big_resnet", + "conditional_strategy": "ProjGAN", + "pos_collected_numerator": false, + "hypersphere_dim": "N/A", + "nonlinear_embed": false, + "normalize_embed": false, + "g_spectral_norm": true, + "d_spectral_norm": true, + "activation_fn": "ReLU", + "attention": true, + "attention_after_nth_gen_block": 3, + "attention_after_nth_dis_block": 1, + "z_dim": 100, + "shared_dim": 128, + "g_conv_dim": 80, + "d_conv_dim": 80, + "G_depth":"N/A", + "D_depth":"N/A" + }, + + "optimization": { + "optimizer": "Adam", + "batch_size": 1024, + "accumulation_steps": 1, + "d_lr": 0.0004, + "g_lr": 0.0001, + "momentum": "N/A", + "nesterov": "N/A", + "alpha": "N/A", + "beta1": 0.0, + "beta2": 0.999, + "g_steps_per_iter": 1, + "d_steps_per_iter": 1, + "total_step": 100000 + }, + + "loss_function": { + "adv_loss": "hinge", + + "contrastive_lambda":"N/A", + "margin": "N/A", + "tempering_type": "N/A", + "tempering_step": "N/A", + "start_temperature": "N/A", + "end_temperature": "N/A", + + "weight_clipping_for_dis": false, + "weight_clipping_bound": "N/A", + + "gradient_penalty_for_dis": false, + "gradient_penalty_lambda": "N/A", + + "deep_regret_analysis_for_dis": false, + "regret_penalty_lambda": "N/A", + + "cr": false, + "cr_lambda":"N/A", + + "bcr": false, + "real_lambda": "N/A", + "fake_lambda": "N/A", + + "zcr": false, + "gen_lambda": "N/A", + "dis_lambda": "N/A", + "sigma_noise": "N/A" + }, + + "initialization":{ + "g_init": "ortho", + "d_init": "ortho" + }, + + "training_and_sampling_setting":{ + "random_flip_preprocessing": true, + "diff_aug": false, + + "ada": false, + "ada_target": "N/A", + "ada_length": "N/A", + + "prior": "gaussian", + "truncated_factor": 1, + + "latent_op": false, + "latent_op_rate":"N/A", + "latent_op_step":"N/A", + "latent_op_step4eval":"N/A", + "latent_op_alpha":"N/A", + "latent_op_beta":"N/A", + "latent_norm_reg_weight":"N/A", + + "ema": true, + "ema_decay": 0.9999, + "ema_start": 20000 + } + } + } diff --git a/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/ContraGAN/DiffAugGAN(C).json b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/ContraGAN/DiffAugGAN(C).json new file mode 100644 index 0000000..f30ff2a --- /dev/null +++ 
b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/ContraGAN/DiffAugGAN(C).json @@ -0,0 +1,111 @@ +{ + "data_processing":{ + "dataset_name": "tiny_imagenet", + "data_path": "./data/TINY_ILSVRC2012", + "img_size": 64, + "num_classes": 200, + "batch_size4prcsing": 256, + "chunk_size": 500, + "compression": false + }, + + "train": { + "model": { + "architecture": "big_resnet", + "conditional_strategy": "ContraGAN", + "pos_collected_numerator": true, + "hypersphere_dim": 768, + "nonlinear_embed": false, + "normalize_embed": true, + "g_spectral_norm": true, + "d_spectral_norm": true, + "activation_fn": "ReLU", + "attention": true, + "attention_after_nth_gen_block": 3, + "attention_after_nth_dis_block": 1, + "z_dim": 100, + "shared_dim": 128, + "g_conv_dim": 80, + "d_conv_dim": 80, + "G_depth":"N/A", + "D_depth":"N/A" + }, + + "optimization": { + "optimizer": "Adam", + "batch_size": 1024, + "accumulation_steps": 1, + "d_lr": 0.0004, + "g_lr": 0.0001, + "momentum": "N/A", + "nesterov": "N/A", + "alpha": "N/A", + "beta1": 0.0, + "beta2": 0.999, + "g_steps_per_iter": 1, + "d_steps_per_iter": 1, + "total_step": 200000 + }, + + "loss_function": { + "adv_loss": "hinge", + + "contrastive_lambda": 1.0, + "margin": 0.0, + "tempering_type": "constant", + "tempering_step": "N/A", + "start_temperature": 1.0, + "end_temperature": 1.0, + + "weight_clipping_for_dis": false, + "weight_clipping_bound": "N/A", + + "gradient_penalty_for_dis": false, + "gradient_penalty_lambda": "N/A", + + "deep_regret_analysis_for_dis": false, + "regret_penalty_lambda": "N/A", + + "cr": false, + "cr_lambda": "N/A", + + "bcr": false, + "real_lambda": "N/A", + "fake_lambda": "N/A", + + "zcr": false, + "gen_lambda": "N/A", + "dis_lambda": "N/A", + "sigma_noise": "N/A" + }, + + "initialization":{ + "g_init": "ortho", + "d_init": "ortho" + }, + + "training_and_sampling_setting":{ + "random_flip_preprocessing": true, + "diff_aug": true, + + "ada": false, + "ada_target": "N/A", + "ada_length": "N/A", + + "prior": "gaussian", + "truncated_factor": 1, + + "latent_op": false, + "latent_op_rate": "N/A", + "latent_op_step": "N/A", + "latent_op_step4eval": "N/A", + "latent_op_alpha": "N/A", + "latent_op_beta": "N/A", + "latent_norm_reg_weight": "N/A", + + "ema": true, + "ema_decay": 0.9999, + "ema_start": 20000 + } + } + } diff --git a/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/GGAN/GGAN.json b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/GGAN/GGAN.json new file mode 100644 index 0000000..bb5d74e --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/GGAN/GGAN.json @@ -0,0 +1,111 @@ +{ + "data_processing":{ + "dataset_name": "tiny_imagenet", + "data_path": "./data/TINY_ILSVRC2012", + "img_size": 64, + "num_classes": 200, + "batch_size4prcsing": 256, + "chunk_size": 500, + "compression": false + }, + + "train": { + "model": { + "architecture": "resnet", + "conditional_strategy": "no", + "pos_collected_numerator": false, + "hypersphere_dim": "N/A", + "nonlinear_embed": false, + "normalize_embed": false, + "g_spectral_norm": false, + "d_spectral_norm": false, + "activation_fn": "ReLU", + "attention": false, + "attention_after_nth_gen_block": "N/A", + "attention_after_nth_dis_block": "N/A", + "z_dim": 128, + "shared_dim": "N/A", + "g_conv_dim": 64, + "d_conv_dim": 64, + "G_depth": "N/A", + "D_depth": "N/A" + }, + + "optimization": { + "optimizer": "Adam", + "batch_size": 256, + "accumulation_steps": 1, + "d_lr": 0.0004, + "g_lr": 0.0001, + "momentum": "N/A", + "nesterov": "N/A", + "alpha": "N/A", + 
"beta1": 0.0, + "beta2": 0.999, + "g_steps_per_iter": 1, + "d_steps_per_iter": 1, + "total_step": 100000 + }, + + "loss_function": { + "adv_loss": "hinge", + + "contrastive_lambda": "N/A", + "margin": "N/A", + "tempering_type": "N/A", + "tempering_step": "N/A", + "start_temperature": "N/A", + "end_temperature": "N/A", + + "weight_clipping_for_dis": false, + "weight_clipping_bound": "N/A", + + "gradient_penalty_for_dis": false, + "gradient_penalty_lambda": "N/A", + + "deep_regret_analysis_for_dis": false, + "regret_penalty_lambda": "N/A", + + "cr": false, + "cr_lambda": "N/A", + + "bcr": false, + "real_lambda": "N/A", + "fake_lambda": "N/A", + + "zcr": false, + "gen_lambda": "N/A", + "dis_lambda": "N/A", + "sigma_noise": "N/A" + }, + + "initialization":{ + "g_init": "ortho", + "d_init": "ortho" + }, + + "training_and_sampling_setting":{ + "random_flip_preprocessing": true, + "diff_aug": false, + + "ada": false, + "ada_target": "N/A", + "ada_length": "N/A", + + "prior": "gaussian", + "truncated_factor": 1, + + "latent_op": false, + "latent_op_rate":"N/A", + "latent_op_step":"N/A", + "latent_op_step4eval":"N/A", + "latent_op_alpha":"N/A", + "latent_op_beta":"N/A", + "latent_norm_reg_weight":"N/A", + + "ema": false, + "ema_decay": "N/A", + "ema_start": "N/A" + } + } + } \ No newline at end of file diff --git a/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/LSGAN/LSGAN.json b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/LSGAN/LSGAN.json new file mode 100644 index 0000000..f86302a --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/LSGAN/LSGAN.json @@ -0,0 +1,111 @@ +{ + "data_processing":{ + "dataset_name": "tiny_imagenet", + "data_path": "./data/TINY_ILSVRC2012", + "img_size": 64, + "num_classes": 200, + "batch_size4prcsing": 256, + "chunk_size": 500, + "compression": false + }, + + "train": { + "model": { + "architecture": "resnet", + "conditional_strategy": "no", + "pos_collected_numerator": false, + "hypersphere_dim": "N/A", + "nonlinear_embed": false, + "normalize_embed": false, + "g_spectral_norm": false, + "d_spectral_norm": false, + "activation_fn": "ReLU", + "attention": false, + "attention_after_nth_gen_block": "N/A", + "attention_after_nth_dis_block": "N/A", + "z_dim": 128, + "shared_dim": "N/A", + "g_conv_dim": 64, + "d_conv_dim": 64, + "G_depth": "N/A", + "D_depth": "N/A" + }, + + "optimization": { + "optimizer": "Adam", + "batch_size": 256, + "accumulation_steps": 1, + "d_lr": 0.0004, + "g_lr": 0.0001, + "momentum": "N/A", + "nesterov": "N/A", + "alpha": "N/A", + "beta1": 0.0, + "beta2": 0.999, + "g_steps_per_iter": 1, + "d_steps_per_iter": 1, + "total_step": 100000 + }, + + "loss_function": { + "adv_loss": "least_square", + + "contrastive_lambda": "N/A", + "margin": "N/A", + "tempering_type": "N/A", + "tempering_step": "N/A", + "start_temperature": "N/A", + "end_temperature": "N/A", + + "weight_clipping_for_dis": false, + "weight_clipping_bound": "N/A", + + "gradient_penalty_for_dis": false, + "gradient_penalty_lambda": "N/A", + + "deep_regret_analysis_for_dis": false, + "regret_penalty_lambda": "N/A", + + "cr": false, + "cr_lambda": "N/A", + + "bcr": false, + "real_lambda": "N/A", + "fake_lambda": "N/A", + + "zcr": false, + "gen_lambda": "N/A", + "dis_lambda": "N/A", + "sigma_noise": "N/A" + }, + + "initialization":{ + "g_init": "ortho", + "d_init": "ortho" + }, + + "training_and_sampling_setting":{ + "random_flip_preprocessing": true, + "diff_aug": false, + + "ada": false, + "ada_target": "N/A", + "ada_length": "N/A", + + "prior": 
"gaussian", + "truncated_factor": 1, + + "latent_op": false, + "latent_op_rate":"N/A", + "latent_op_step":"N/A", + "latent_op_step4eval":"N/A", + "latent_op_alpha":"N/A", + "latent_op_beta":"N/A", + "latent_norm_reg_weight":"N/A", + + "ema": false, + "ema_decay": "N/A", + "ema_start": "N/A" + } + } + } \ No newline at end of file diff --git a/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/ProjGAN/ProjGAN.json b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/ProjGAN/ProjGAN.json new file mode 100644 index 0000000..368514c --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/ProjGAN/ProjGAN.json @@ -0,0 +1,111 @@ +{ + "data_processing":{ + "dataset_name": "tiny_imagenet", + "data_path": "./data/TINY_ILSVRC2012", + "img_size": 64, + "num_classes": 200, + "batch_size4prcsing": 256, + "chunk_size": 500, + "compression": false + }, + + "train": { + "model": { + "architecture": "resnet", + "conditional_strategy": "ProjGAN", + "pos_collected_numerator": false, + "hypersphere_dim": "N/A", + "nonlinear_embed": false, + "normalize_embed": false, + "g_spectral_norm": false, + "d_spectral_norm": false, + "activation_fn": "ReLU", + "attention": false, + "attention_after_nth_gen_block": "N/A", + "attention_after_nth_dis_block": "N/A", + "z_dim": 128, + "shared_dim": "N/A", + "g_conv_dim": 64, + "d_conv_dim": 64, + "G_depth": "N/A", + "D_depth": "N/A" + }, + + "optimization": { + "optimizer": "Adam", + "batch_size": 256, + "accumulation_steps": 1, + "d_lr": 0.0004, + "g_lr": 0.0001, + "momentum": "N/A", + "nesterov": "N/A", + "alpha": "N/A", + "beta1": 0.0, + "beta2": 0.999, + "g_steps_per_iter": 1, + "d_steps_per_iter": 1, + "total_step": 100000 + }, + + "loss_function": { + "adv_loss": "hinge", + + "contrastive_lambda": "N/A", + "margin": "N/A", + "tempering_type": "N/A", + "tempering_step": "N/A", + "start_temperature": "N/A", + "end_temperature": "N/A", + + "weight_clipping_for_dis": false, + "weight_clipping_bound": "N/A", + + "gradient_penalty_for_dis": false, + "gradient_penalty_lambda": "N/A", + + "deep_regret_analysis_for_dis": false, + "regret_penalty_lambda": "N/A", + + "cr": false, + "cr_lambda": "N/A", + + "bcr": false, + "real_lambda": "N/A", + "fake_lambda": "N/A", + + "zcr": false, + "gen_lambda": "N/A", + "dis_lambda": "N/A", + "sigma_noise": "N/A" + }, + + "initialization":{ + "g_init": "ortho", + "d_init": "ortho" + }, + + "training_and_sampling_setting":{ + "random_flip_preprocessing": true, + "diff_aug": false, + + "ada": false, + "ada_target": "N/A", + "ada_length": "N/A", + + "prior": "gaussian", + "truncated_factor": 1, + + "latent_op": false, + "latent_op_rate": "N/A", + "latent_op_step": "N/A", + "latent_op_step4eval": "N/A", + "latent_op_alpha": "N/A", + "latent_op_beta": "N/A", + "latent_norm_reg_weight": "N/A", + + "ema": false, + "ema_decay": "N/A", + "ema_start": "N/A" + } + } +} diff --git a/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/SAGAN/SAGAN.json b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/SAGAN/SAGAN.json new file mode 100644 index 0000000..49078cc --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/SAGAN/SAGAN.json @@ -0,0 +1,111 @@ +{ + "data_processing":{ + "dataset_name": "tiny_imagenet", + "data_path": "./data/TINY_ILSVRC2012", + "img_size": 64, + "num_classes": 200, + "batch_size4prcsing": 256, + "chunk_size": 500, + "compression": false + }, + + "train": { + "model": { + "architecture": "resnet", + "conditional_strategy": "ProjGAN", + "pos_collected_numerator": false, + 
"hypersphere_dim": "N/A", + "nonlinear_embed": false, + "normalize_embed": false, + "g_spectral_norm": true, + "d_spectral_norm": true, + "activation_fn": "ReLU", + "attention": true, + "attention_after_nth_gen_block": 3, + "attention_after_nth_dis_block": 1, + "z_dim": 128, + "shared_dim": "N/A", + "g_conv_dim": 64, + "d_conv_dim": 64, + "G_depth": "N/A", + "D_depth": "N/A" + }, + + "optimization": { + "optimizer": "Adam", + "batch_size": 256, + "accumulation_steps": 1, + "d_lr": 0.0004, + "g_lr": 0.0001, + "momentum": "N/A", + "nesterov": "N/A", + "alpha": "N/A", + "beta1": 0.0, + "beta2": 0.999, + "g_steps_per_iter": 1, + "d_steps_per_iter": 1, + "total_step": 100000 + }, + + "loss_function": { + "adv_loss": "hinge", + + "contrastive_lambda": "N/A", + "margin": "N/A", + "tempering_type": "N/A", + "tempering_step": "N/A", + "start_temperature": "N/A", + "end_temperature": "N/A", + + "weight_clipping_for_dis": false, + "weight_clipping_bound": "N/A", + + "gradient_penalty_for_dis": false, + "gradient_penalty_lambda": "N/A", + + "deep_regret_analysis_for_dis": false, + "regret_penalty_lambda": "N/A", + + "cr": false, + "cr_lambda": "N/A", + + "bcr": false, + "real_lambda": "N/A", + "fake_lambda": "N/A", + + "zcr": false, + "gen_lambda": "N/A", + "dis_lambda": "N/A", + "sigma_noise": "N/A" + }, + + "initialization":{ + "g_init": "ortho", + "d_init": "ortho" + }, + + "training_and_sampling_setting":{ + "random_flip_preprocessing": true, + "diff_aug": false, + + "ada": false, + "ada_target": "N/A", + "ada_length": "N/A", + + "prior": "gaussian", + "truncated_factor": 1, + + "latent_op": false, + "latent_op_rate":"N/A", + "latent_op_step":"N/A", + "latent_op_step4eval":"N/A", + "latent_op_alpha":"N/A", + "latent_op_beta":"N/A", + "latent_norm_reg_weight":"N/A", + + "ema": false, + "ema_decay": "N/A", + "ema_start": "N/A" + } + } +} diff --git a/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/SNGAN/SNGAN.json b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/SNGAN/SNGAN.json new file mode 100644 index 0000000..bb6bb6b --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/SNGAN/SNGAN.json @@ -0,0 +1,111 @@ +{ + "data_processing":{ + "dataset_name": "tiny_imagenet", + "data_path": "./data/TINY_ILSVRC2012", + "img_size": 64, + "num_classes": 200, + "batch_size4prcsing": 256, + "chunk_size": 500, + "compression": false + }, + + "train": { + "model": { + "architecture": "resnet", + "conditional_strategy": "ProjGAN", + "pos_collected_numerator": false, + "hypersphere_dim": "N/A", + "nonlinear_embed": false, + "normalize_embed": false, + "g_spectral_norm": false, + "d_spectral_norm": true, + "activation_fn": "ReLU", + "attention": false, + "attention_after_nth_gen_block": "N/A", + "attention_after_nth_dis_block": "N/A", + "z_dim": 128, + "shared_dim": "N/A", + "g_conv_dim": 64, + "d_conv_dim": 64, + "G_depth": "N/A", + "D_depth": "N/A" + }, + + "optimization": { + "optimizer": "Adam", + "batch_size": 256, + "accumulation_steps": 1, + "d_lr": 0.0004, + "g_lr": 0.0001, + "momentum": "N/A", + "nesterov": "N/A", + "alpha": "N/A", + "beta1": 0.0, + "beta2": 0.999, + "g_steps_per_iter": 1, + "d_steps_per_iter": 1, + "total_step": 100000 + }, + + "loss_function": { + "adv_loss": "hinge", + + "contrastive_lambda": "N/A", + "margin": "N/A", + "tempering_type": "N/A", + "tempering_step": "N/A", + "start_temperature": "N/A", + "end_temperature": "N/A", + + "weight_clipping_for_dis": false, + "weight_clipping_bound": "N/A", + + "gradient_penalty_for_dis": false, + 
"gradient_penalty_lambda": "N/A", + + "deep_regret_analysis_for_dis": false, + "regret_penalty_lambda": "N/A", + + "cr": false, + "cr_lambda": "N/A", + + "bcr": false, + "real_lambda": "N/A", + "fake_lambda": "N/A", + + "zcr": false, + "gen_lambda": "N/A", + "dis_lambda": "N/A", + "sigma_noise": "N/A" + }, + + "initialization":{ + "g_init": "ortho", + "d_init": "ortho" + }, + + "training_and_sampling_setting":{ + "random_flip_preprocessing": true, + "diff_aug": false, + + "ada": false, + "ada_target": "N/A", + "ada_length": "N/A", + + "prior": "gaussian", + "truncated_factor": 1, + + "latent_op": false, + "latent_op_rate":"N/A", + "latent_op_step":"N/A", + "latent_op_step4eval":"N/A", + "latent_op_alpha":"N/A", + "latent_op_beta":"N/A", + "latent_norm_reg_weight":"N/A", + + "ema": false, + "ema_decay": "N/A", + "ema_start": "N/A" + } + } +} diff --git a/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/WGAN-WC/WGAN-WC.json b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/WGAN-WC/WGAN-WC.json new file mode 100644 index 0000000..cbee459 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/WGAN-WC/WGAN-WC.json @@ -0,0 +1,111 @@ +{ + "data_processing":{ + "dataset_name": "tiny_imagenet", + "data_path": "./data/TINY_ILSVRC2012", + "img_size": 64, + "num_classes": 200, + "batch_size4prcsing": 256, + "chunk_size": 500, + "compression": false + }, + + "train": { + "model": { + "architecture": "resnet", + "conditional_strategy": "no", + "pos_collected_numerator": false, + "hypersphere_dim": "N/A", + "nonlinear_embed": false, + "normalize_embed": false, + "g_spectral_norm": false, + "d_spectral_norm": false, + "activation_fn": "ReLU", + "attention": false, + "attention_after_nth_gen_block": "N/A", + "attention_after_nth_dis_block": "N/A", + "z_dim": 128, + "shared_dim": "N/A", + "g_conv_dim": 64, + "d_conv_dim": 64, + "G_depth": "N/A", + "D_depth": "N/A" + }, + + "optimization": { + "optimizer": "Adam", + "batch_size": 256, + "accumulation_steps": 1, + "d_lr": 0.0004, + "g_lr": 0.0001, + "momentum": "N/A", + "nesterov": "N/A", + "alpha": "N/A", + "beta1": 0.0, + "beta2": 0.999, + "g_steps_per_iter": 1, + "d_steps_per_iter": 1, + "total_step": 100000 + }, + + "loss_function": { + "adv_loss": "wasserstein", + + "contrastive_lambda": "N/A", + "margin": "N/A", + "tempering_type": "N/A", + "tempering_step": "N/A", + "start_temperature": "N/A", + "end_temperature": "N/A", + + "weight_clipping_for_dis": false, + "weight_clipping_bound": "N/A", + + "gradient_penalty_for_dis": true, + "gradient_penalty_lambda": 10.0, + + "deep_regret_analysis_for_dis": false, + "regret_penalty_lambda": "N/A", + + "cr": false, + "cr_lambda": "N/A", + + "bcr": false, + "real_lambda": "N/A", + "fake_lambda": "N/A", + + "zcr": false, + "gen_lambda": "N/A", + "dis_lambda": "N/A", + "sigma_noise": "N/A" + }, + + "initialization":{ + "g_init": "ortho", + "d_init": "ortho" + }, + + "training_and_sampling_setting":{ + "random_flip_preprocessing": true, + "diff_aug": false, + + "ada": false, + "ada_target": "N/A", + "ada_length": "N/A", + + "prior": "gaussian", + "truncated_factor": 1, + + "latent_op": false, + "latent_op_rate":"N/A", + "latent_op_step":"N/A", + "latent_op_step4eval":"N/A", + "latent_op_alpha":"N/A", + "latent_op_beta":"N/A", + "latent_norm_reg_weight":"N/A", + + "ema": false, + "ema_decay": "N/A", + "ema_start": "N/A" + } + } +} \ No newline at end of file diff --git a/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/download.sh 
b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/download.sh new file mode 100644 index 0000000..ea2a21f --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/download.sh @@ -0,0 +1,72 @@ + +# LSGAN +# https://drive.google.com/file/d/1Wa5CrUAoxgW730Z1MXg7QJ-O2Y_XiIOC/view?usp=sharing +mkdir LSGAN +cd LSGAN +# gdown --id 1Wa5CrUAoxgW730Z1MXg7QJ-O2Y_XiIOC +wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/LSGAN.json +cd .. + +# GGAN +# https://drive.google.com/file/d/1U5644ZhZUdoJDUPLELoQHtPA9qOdZIXS/view?usp=sharing +mkdir GGAN +cd GGAN +# gdown --id 1U5644ZhZUdoJDUPLELoQHtPA9qOdZIXS +wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/GGAN.json +cd .. + +# WGAN-WC +# https://drive.google.com/file/d/1TbWjWx8PhSHKmh-gv3WTYybSKWpoj8_u/view?usp=sharing +mkdir WGAN-WC +cd WGAN-WC +# gdown --id 1TbWjWx8PhSHKmh-gv3WTYybSKWpoj8_u +wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/WGAN-WC.json +cd .. + +# ACGAN +# https://drive.google.com/file/d/14JkiZLONLXAP1JCfixlSPkbKxK5mG1cF/view?usp=sharing +mkdir ACGAN +cd ACGAN +# gdown --id 14JkiZLONLXAP1JCfixlSPkbKxK5mG1cF +wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/ACGAN.json +cd .. + +# ProjGAN +# https://drive.google.com/file/d/1mRtit-GFIHjD--YLG-PzhThKkyOW7zoI/view?usp=sharing +mkdir ProjGAN +cd ProjGAN +# gdown --id 1mRtit-GFIHjD--YLG-PzhThKkyOW7zoI +wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/ProjGAN.json +cd .. + +# SNGAN +# https://drive.google.com/file/d/1xHrk4bt0Xbatvt3hs4RoMM3E6BCBUAmw/view?usp=sharing +mkdir SNGAN +cd SNGAN +# gdown --id 1xHrk4bt0Xbatvt3hs4RoMM3E6BCBUAmw +wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/SNGAN.json +cd .. + +# SAGAN +# https://drive.google.com/file/d/1vaEwUqUF_qC5uUBRNW413vt_8QMYfuoN/view?usp=sharing +mkdir SAGAN +cd SAGAN +# gdown --id 1vaEwUqUF_qC5uUBRNW413vt_8QMYfuoN +wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/SAGAN.json +cd .. + +# BigGAN +# https://drive.google.com/file/d/16FqpBcB318De2HM7XS6UNFT7zs-3XD6e/view?usp=sharing +mkdir BigGAN +cd BigGAN +# gdown --id 16FqpBcB318De2HM7XS6UNFT7zs-3XD6e +wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/BigGAN.json +cd .. + +# ContraGAN +# https://drive.google.com/file/d/1NKcNjtg51rfmFvTSMmZlkweMAlQJhYUA/view?usp=sharing +mkdir ContraGAN +cd ContraGAN +# gdown --id 1NKcNjtg51rfmFvTSMmZlkweMAlQJhYUA +wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/DiffAugGAN(C).json +cd .. 
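The configs fetched above all share one layout: a "data_processing" block (dataset name, img_size, num_classes) and a "train" block whose "model" section carries the generator/discriminator hyperparameters (z_dim, conv dims, spectral-norm flags). The following is a minimal sketch of reading one downloaded file and pulling out just those fields; the helper name and the default path are illustrative only and are not part of the patch.

import json

def load_studiogan_config(path="SNGAN/SNGAN.json"):
    # Pull out only the fields the TinyImageNet models in this patch actually consume.
    with open(path) as f:
        cfg = json.load(f)
    data, model = cfg["data_processing"], cfg["train"]["model"]
    return {
        "architecture": model["architecture"],  # "resnet" for the TinyImageNet configs above
        "img_size": data["img_size"],           # 64 for Tiny ImageNet
        "num_classes": data["num_classes"],     # 200 for Tiny ImageNet
        "z_dim": model["z_dim"],                # 128 in every config above
    }

if __name__ == "__main__":
    print(load_studiogan_config())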
diff --git a/pytorch_pretrained_gans/StudioGAN/loader.py b/pytorch_pretrained_gans/StudioGAN/loader.py new file mode 100644 index 0000000..52e03fe --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/loader.py @@ -0,0 +1,298 @@ +# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN +# The MIT License (MIT) +# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details + +# src/loader.py + + +import glob +import os +import random +from os.path import dirname, abspath, exists, join +from torchlars import LARS + +from data_utils.load_dataset import * +from metrics.inception_network import InceptionV3 +from metrics.prepare_inception_moments import prepare_inception_moments +from .utils.log import make_checkpoint_dir, make_logger +from .utils.losses import * +from .utils.load_checkpoint import load_checkpoint +from .utils.misc import * +from .utils.biggan_utils import ema, ema_DP_SyncBN +from sync_batchnorm.batchnorm import convert_model +from worker import make_worker + +import torch +from torch.utils.data import DataLoader +from torch.nn import DataParallel +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + + +def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_config, model_config, hdf5_path_train): + cfgs = dict2clsattr(train_config, model_config) + + assert cfgs.bn_stat_OnTheFly * cfgs.standing_statistics == 0,\ + "You can't turn on train_statistics and standing_statistics simultaneously." + if cfgs.train_configs['train'] * cfgs.standing_statistics: + print("When training, StudioGAN does not apply standing_statistics for evaluation. " + + "After training is done, StudioGAN will accumulate batchnorm statistics and evaluate the trained model") + + prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path, mu, sigma, inception_model = None, 0, 0, None, None, None, None, None + + if cfgs.distributed_data_parallel: + global_rank = cfgs.nr * (gpus_per_node) + local_rank + print("Use GPU: {} for training.".format(global_rank)) + setup(global_rank, world_size) + torch.cuda.set_device(local_rank) + else: + global_rank = local_rank + + writer = SummaryWriter(log_dir=join('./logs', run_name)) if local_rank == 0 else None + if local_rank == 0: + logger = make_logger(run_name, None) + logger.info('Run name : {run_name}'.format(run_name=run_name)) + logger.info(train_config) + logger.info(model_config) + else: + logger = None + + ##### load dataset ##### + if local_rank == 0: + logger.info('Load train datasets...') + train_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=True, download=True, resize_size=cfgs.img_size, + hdf5_path=hdf5_path_train, random_flip=cfgs.random_flip_preprocessing) + if cfgs.reduce_train_dataset < 1.0: + num_train = int(cfgs.reduce_train_dataset * len(train_dataset)) + train_dataset, _ = torch.utils.data.random_split(train_dataset, [num_train, len(train_dataset) - num_train]) + if local_rank == 0: + logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset))) + + if local_rank == 0: + logger.info('Load {mode} datasets...'.format(mode=cfgs.eval_type)) + eval_mode = True if cfgs.eval_type == 'train' else False + eval_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=eval_mode, download=True, resize_size=cfgs.img_size, + hdf5_path=None, random_flip=False) + if local_rank == 0: + logger.info('Eval dataset size : {dataset_size}'.format(dataset_size=len(eval_dataset))) + + if 
cfgs.distributed_data_parallel: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + cfgs.batch_size = cfgs.batch_size // world_size + else: + train_sampler = None + + train_dataloader = DataLoader(train_dataset, batch_size=cfgs.batch_size, shuffle=(train_sampler is None), pin_memory=True, + num_workers=cfgs.num_workers, sampler=train_sampler, drop_last=True) + eval_dataloader = DataLoader(eval_dataset, batch_size=cfgs.batch_size, shuffle=False, + pin_memory=True, num_workers=cfgs.num_workers, drop_last=False) + + ##### build model ##### + if local_rank == 0: + logger.info('Build model...') + module = __import__('models.{architecture}'.format(architecture=cfgs.architecture), fromlist=['something']) + if local_rank == 0: + logger.info('Modules are located on models.{architecture}.'.format(architecture=cfgs.architecture)) + Gen = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm, cfgs.attention, + cfgs.attention_after_nth_gen_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes, + cfgs.g_init, cfgs.G_depth, cfgs.mixed_precision).to(local_rank) + + Dis = module.Discriminator(cfgs.img_size, cfgs.d_conv_dim, cfgs.d_spectral_norm, cfgs.attention, cfgs.attention_after_nth_dis_block, + cfgs.activation_fn, cfgs.conditional_strategy, cfgs.hypersphere_dim, cfgs.num_classes, cfgs.nonlinear_embed, + cfgs.normalize_embed, cfgs.d_init, cfgs.D_depth, cfgs.mixed_precision).to(local_rank) + + if cfgs.ema: + if local_rank == 0: + logger.info('Prepare EMA for G with decay of {}.'.format(cfgs.ema_decay)) + Gen_copy = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm, cfgs.attention, + cfgs.attention_after_nth_gen_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes, + initialize=False, G_depth=cfgs.G_depth, mixed_precision=cfgs.mixed_precision).to(local_rank) + if not cfgs.distributed_data_parallel and world_size > 1 and cfgs.synchronized_bn: + Gen_ema = ema_DP_SyncBN(Gen, Gen_copy, cfgs.ema_decay, cfgs.ema_start) + else: + Gen_ema = ema(Gen, Gen_copy, cfgs.ema_decay, cfgs.ema_start) + else: + Gen_copy, Gen_ema = None, None + + if local_rank == 0: + logger.info(count_parameters(Gen)) + if local_rank == 0: + logger.info(Gen) + + if local_rank == 0: + logger.info(count_parameters(Dis)) + if local_rank == 0: + logger.info(Dis) + + # define loss functions and optimizers + G_loss = {'vanilla': loss_dcgan_gen, 'least_square': loss_lsgan_gen, + 'hinge': loss_hinge_gen, 'wasserstein': loss_wgan_gen} + D_loss = {'vanilla': loss_dcgan_dis, 'least_square': loss_lsgan_dis, + 'hinge': loss_hinge_dis, 'wasserstein': loss_wgan_dis} + + if cfgs.optimizer == "SGD": + G_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, Gen.parameters()), + cfgs.g_lr, momentum=cfgs.momentum, nesterov=cfgs.nesterov) + D_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, Dis.parameters()), + cfgs.d_lr, momentum=cfgs.momentum, nesterov=cfgs.nesterov) + elif cfgs.optimizer == "RMSprop": + G_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, Gen.parameters()), + cfgs.g_lr, momentum=cfgs.momentum, alpha=cfgs.alpha) + D_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, Dis.parameters()), + cfgs.d_lr, momentum=cfgs.momentum, alpha=cfgs.alpha) + elif cfgs.optimizer == "Adam": + G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Gen.parameters()), + cfgs.g_lr, [cfgs.beta1, cfgs.beta2], eps=1e-6) + D_optimizer = 
torch.optim.Adam(filter(lambda p: p.requires_grad, Dis.parameters()), + cfgs.d_lr, [cfgs.beta1, cfgs.beta2], eps=1e-6) + else: + raise NotImplementedError + + if cfgs.LARS_optimizer: + G_optimizer = LARS(optimizer=G_optimizer, eps=1e-8, trust_coef=0.001) + D_optimizer = LARS(optimizer=D_optimizer, eps=1e-8, trust_coef=0.001) + + ##### load checkpoints if needed ##### + if cfgs.checkpoint_folder is None: + checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name) + else: + when = "current" if cfgs.load_current is True else "best" + if not exists(abspath(cfgs.checkpoint_folder)): + raise NotADirectoryError + checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name) + g_checkpoint_dir = glob.glob(join(checkpoint_dir, "model=G-{when}-weights-step*.pth".format(when=when)))[0] + d_checkpoint_dir = glob.glob(join(checkpoint_dir, "model=D-{when}-weights-step*.pth".format(when=when)))[0] + Gen, G_optimizer, trained_seed, run_name, step, prev_ada_p = load_checkpoint(Gen, G_optimizer, g_checkpoint_dir) + Dis, D_optimizer, trained_seed, run_name, step, prev_ada_p, best_step, best_fid, best_fid_checkpoint_path =\ + load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True) + if local_rank == 0: + logger = make_logger(run_name, None) + if cfgs.ema: + g_ema_checkpoint_dir = glob.glob( + join(checkpoint_dir, "model=G_ema-{when}-weights-step*.pth".format(when=when)))[0] + Gen_copy = load_checkpoint(Gen_copy, None, g_ema_checkpoint_dir, ema=True) + Gen_ema.source, Gen_ema.target = Gen, Gen_copy + + writer = SummaryWriter(log_dir=join('./logs', run_name)) if global_rank == 0 else None + if cfgs.train_configs['train']: + assert cfgs.seed == trained_seed, "Seed for sampling random numbers should be same!" + + if local_rank == 0: + logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir)) + if local_rank == 0: + logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir)) + if cfgs.freeze_layers > -1: + prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None + + ##### wrap models with DP and convert BN to Sync BN ##### + if world_size > 1: + if cfgs.distributed_data_parallel: + if cfgs.synchronized_bn: + process_group = torch.distributed.new_group([w for w in range(world_size)]) + Gen = torch.nn.SyncBatchNorm.convert_sync_batchnorm(Gen, process_group) + Dis = torch.nn.SyncBatchNorm.convert_sync_batchnorm(Dis, process_group) + if cfgs.ema: + Gen_copy = torch.nn.SyncBatchNorm.convert_sync_batchnorm(Gen_copy, process_group) + + Gen = DDP(Gen, device_ids=[local_rank]) + Dis = DDP(Dis, device_ids=[local_rank]) + if cfgs.ema: + Gen_copy = DDP(Gen_copy, device_ids=[local_rank]) + else: + Gen = DataParallel(Gen, output_device=local_rank) + Dis = DataParallel(Dis, output_device=local_rank) + if cfgs.ema: + Gen_copy = DataParallel(Gen_copy, output_device=local_rank) + + if cfgs.synchronized_bn: + Gen = convert_model(Gen).to(local_rank) + Dis = convert_model(Dis).to(local_rank) + if cfgs.ema: + Gen_copy = convert_model(Gen_copy).to(local_rank) + + ##### load the inception network and prepare first/secend moments for calculating FID ##### + if cfgs.eval: + inception_model = InceptionV3().to(local_rank) + if world_size > 1 and cfgs.distributed_data_parallel: + toggle_grad(inception_model, on=True) + inception_model = DDP(inception_model, device_ids=[local_rank], + broadcast_buffers=False, find_unused_parameters=True) + elif world_size > 1 and cfgs.distributed_data_parallel is False: + inception_model = DataParallel(inception_model, 
output_device=local_rank) + else: + pass + + mu, sigma = prepare_inception_moments(dataloader=eval_dataloader, + generator=Gen, + eval_mode=cfgs.eval_type, + inception_model=inception_model, + splits=1, + run_name=run_name, + logger=logger, + device=local_rank) + + worker = make_worker( + cfgs=cfgs, + run_name=run_name, + best_step=best_step, + logger=logger, + writer=writer, + n_gpus=world_size, + gen_model=Gen, + dis_model=Dis, + inception_model=inception_model, + Gen_copy=Gen_copy, + Gen_ema=Gen_ema, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + G_optimizer=G_optimizer, + D_optimizer=D_optimizer, + G_loss=G_loss[cfgs.adv_loss], + D_loss=D_loss[cfgs.adv_loss], + prev_ada_p=prev_ada_p, + global_rank=global_rank, + local_rank=local_rank, + bn_stat_OnTheFly=cfgs.bn_stat_OnTheFly, + checkpoint_dir=checkpoint_dir, + mu=mu, + sigma=sigma, + best_fid=best_fid, + best_fid_checkpoint_path=best_fid_checkpoint_path, + ) + + if cfgs.train_configs['train']: + step = worker.train(current_step=step, total_step=cfgs.total_step) + + if cfgs.eval: + is_save = worker.evaluation(step=step, standing_statistics=cfgs.standing_statistics, + standing_step=cfgs.standing_step) + + if cfgs.save_images: + worker.save_images(is_generate=True, png=True, npz=True, + standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step) + + if cfgs.image_visualization: + worker.run_image_visualization(nrow=cfgs.nrow, ncol=cfgs.ncol, + standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step) + + if cfgs.k_nearest_neighbor: + worker.run_nearest_neighbor(nrow=cfgs.nrow, ncol=cfgs.ncol, + standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step) + + if cfgs.interpolation: + assert cfgs.architecture in [ + "big_resnet", "biggan_deep"], "StudioGAN does not support interpolation analysis except for biggan and biggan_deep." 
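+        # Two sweeps follow: the first call holds the latent code fixed and varies the
+        # class input (fix_z=True, fix_y=False); the second holds the class fixed and
+        # varies the latent code (fix_z=False, fix_y=True).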
+ worker.run_linear_interpolation(nrow=cfgs.nrow, ncol=cfgs.ncol, fix_z=True, fix_y=False, + standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step) + worker.run_linear_interpolation(nrow=cfgs.nrow, ncol=cfgs.ncol, fix_z=False, fix_y=True, + standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step) + + if cfgs.frequency_analysis: + worker.run_frequency_analysis(num_images=len(train_dataset) // cfgs.num_classes, + standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step) + + if cfgs.tsne_analysis: + worker.run_tsne(dataloader=eval_dataloader, + standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step) diff --git a/pytorch_pretrained_gans/StudioGAN/main.py b/pytorch_pretrained_gans/StudioGAN/main.py new file mode 100644 index 0000000..75d7864 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/main.py @@ -0,0 +1,141 @@ +# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN +# The MIT License (MIT) +# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details + +# src/main.py + + +import json +import os +import sys +import warnings +from argparse import ArgumentParser + +from .utils.misc import * +from .utils.make_hdf5 import make_hdf5 +from .utils.log import make_run_name +from loader import prepare_train_eval + +import torch +from torch.backends import cudnn +import torch.multiprocessing as mp + + +RUN_NAME_FORMAT = ( + "{framework}-" + "{phase}-" + "{timestamp}" +) + + +def main(): + parser = ArgumentParser(add_help=False) + parser.add_argument('-c', '--config_path', type=str, default='./src/configs/CIFAR10/ContraGAN.json') + parser.add_argument('--checkpoint_folder', type=str, default=None) + parser.add_argument('-current', '--load_current', action='store_true', + help='whether you load the current or best checkpoint') + parser.add_argument('--log_output_path', type=str, default=None) + + parser.add_argument('-DDP', '--distributed_data_parallel', action='store_true') + parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N') + parser.add_argument('-nr', '--nr', default=0, type=int, help='ranking within the nodes') + + parser.add_argument('--seed', type=int, default=-1, help='seed for generating random numbers') + parser.add_argument('--num_workers', type=int, default=8, help='') + parser.add_argument('-sync_bn', '--synchronized_bn', action='store_true', + help='whether turn on synchronized batchnorm') + parser.add_argument('-mpc', '--mixed_precision', action='store_true', + help='whether turn on mixed precision training') + parser.add_argument('-LARS', '--LARS_optimizer', action='store_true', help='whether turn on LARS optimizer') + parser.add_argument('-rm_API', '--disable_debugging_API', action='store_true', + help='whether disable pytorch autograd debugging mode') + + parser.add_argument('--reduce_train_dataset', type=float, default=1.0, help='control the number of train dataset') + parser.add_argument('-stat_otf', '--bn_stat_OnTheFly', action='store_true', + help='when evaluating, use the statistics of a batch') + parser.add_argument('-std_stat', '--standing_statistics', action='store_true') + parser.add_argument('--standing_step', type=int, default=-1, help='# of steps for accumulation batchnorm') + parser.add_argument('--freeze_layers', type=int, default=-1, help='# of layers for freezing discriminator') + + parser.add_argument('-l', '--load_all_data_in_memory', action='store_true') + parser.add_argument('-t', '--train', 
action='store_true') + parser.add_argument('-e', '--eval', action='store_true') + parser.add_argument('-s', '--save_images', action='store_true') + parser.add_argument('-iv', '--image_visualization', action='store_true', + help='select whether conduct image visualization') + parser.add_argument('-knn', '--k_nearest_neighbor', action='store_true', + help='select whether conduct k-nearest neighbor analysis') + parser.add_argument('-itp', '--interpolation', action='store_true', help='whether conduct interpolation analysis') + parser.add_argument('-fa', '--frequency_analysis', action='store_true', help='whether conduct frequency analysis') + parser.add_argument('-tsne', '--tsne_analysis', action='store_true', help='whether conduct tsne analysis') + parser.add_argument('--nrow', type=int, default=10, help='number of rows to plot image canvas') + parser.add_argument('--ncol', type=int, default=8, help='number of cols to plot image canvas') + + parser.add_argument('--print_every', type=int, default=100, help='control log interval') + parser.add_argument('--save_every', type=int, default=2000, help='control evaluation and save interval') + parser.add_argument('--eval_type', type=str, default='test', help='[train/valid/test]') + args = parser.parse_args() + + if not args.train and \ + not args.eval and \ + not args.save_images and \ + not args.image_visualization and \ + not args.k_nearest_neighbor and \ + not args.interpolation and \ + not args.frequency_analysis and \ + not args.tsne_analysis: + parser.print_help(sys.stderr) + sys.exit(1) + + if args.config_path is not None: + with open(args.config_path) as f: + model_config = json.load(f) + train_config = vars(args) + else: + raise NotImplementedError + + if model_config['data_processing']['dataset_name'] == 'cifar10': + assert train_config['eval_type'] in ['train', 'test'], "Cifar10 does not contain dataset for validation." + elif model_config['data_processing']['dataset_name'] in ['imagenet', 'tiny_imagenet', 'custom']: + assert train_config['eval_type'] == 'train' or train_config['eval_type'] == 'valid', \ + "StudioGAN dose not support the evalutation protocol that uses the test dataset on imagenet, tiny imagenet, and custom datasets" + + if train_config['distributed_data_parallel']: + msg = "StudioGAN does not support image visualization, k_nearest_neighbor, interpolation, frequency, and tsne analysis with DDP. " +\ + "Please change DDP with a single GPU training or DataParallel instead." + assert train_config['image_visualization'] + train_config['k_nearest_neighbor'] + train_config['interpolation'] +\ + train_config['frequency_analysis'] + train_config['tsne_analysis'] == 0, msg + + hdf5_path_train = make_hdf5(model_config['data_processing'], train_config, mode="train") \ + if train_config['load_all_data_in_memory'] else None + + if train_config['seed'] == -1: + cudnn.benchmark, cudnn.deterministic = True, False + else: + fix_all_seed(train_config['seed']) + cudnn.benchmark, cudnn.deterministic = False, True + + gpus_per_node, rank = torch.cuda.device_count(), torch.cuda.current_device() + world_size = gpus_per_node * train_config['nodes'] + + if world_size == 1: + warnings.warn('You have chosen a specific GPU. 
This will completely disable data parallelism.') + + if train_config['disable_debugging_API']: + torch.autograd.set_detect_anomaly(False) + check_flag_0(model_config['train']['optimization']['batch_size'], world_size, train_config['freeze_layers'], train_config['checkpoint_folder'], + model_config['train']['model']['architecture'], model_config['data_processing']['img_size']) + + run_name = make_run_name(RUN_NAME_FORMAT, framework=train_config['config_path'].split('/')[-1][:-5], phase='train') + + if train_config['distributed_data_parallel'] and world_size > 1: + print("Train the models through DistributedDataParallel (DDP) mode.") + mp.spawn(prepare_train_eval, nprocs=gpus_per_node, args=(gpus_per_node, world_size, run_name, + train_config, model_config, hdf5_path_train)) + else: + prepare_train_eval(rank, gpus_per_node, world_size, run_name, train_config, model_config, + hdf5_path_train=hdf5_path_train) + + +if __name__ == '__main__': + main() diff --git a/pytorch_pretrained_gans/StudioGAN/models/big_resnet.py b/pytorch_pretrained_gans/StudioGAN/models/big_resnet.py new file mode 100644 index 0000000..1c8cb8f --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/models/big_resnet.py @@ -0,0 +1,441 @@ +# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN +# The MIT License (MIT) +# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details + +# models/big_resnet.py + + +from ..utils.model_ops import * +from ..utils.misc import * + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GenBlock(nn.Module): + def __init__(self, in_channels, out_channels, g_spectral_norm, activation_fn, conditional_bn, z_dims_after_concat): + super(GenBlock, self).__init__() + self.conditional_bn = conditional_bn + + if self.conditional_bn: + self.bn1 = ConditionalBatchNorm2d_for_skip_and_shared(num_features=in_channels, z_dims_after_concat=z_dims_after_concat, + spectral_norm=g_spectral_norm) + self.bn2 = ConditionalBatchNorm2d_for_skip_and_shared(num_features=out_channels, z_dims_after_concat=z_dims_after_concat, + spectral_norm=g_spectral_norm) + else: + self.bn1 = batchnorm_2d(in_features=in_channels) + self.bn2 = batchnorm_2d(in_features=out_channels) + + if activation_fn == "ReLU": + self.activation = nn.ReLU(inplace=True) + elif activation_fn == "Leaky_ReLU": + self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif activation_fn == "ELU": + self.activation = nn.ELU(alpha=1.0, inplace=True) + elif activation_fn == "GELU": + self.activation = nn.GELU() + else: + raise NotImplementedError + + if g_spectral_norm: + self.conv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=1, stride=1, padding=0) + self.conv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + self.conv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + else: + self.conv2d0 = conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=1, stride=1, padding=0) + self.conv2d1 = conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + self.conv2d2 = conv2d(in_channels=out_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + + def forward(self, x, label): + x0 = x + if self.conditional_bn: + x = self.bn1(x, label) + else: + x = self.bn1(x) + + x = self.activation(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') # upsample + x = 
self.conv2d1(x) + if self.conditional_bn: + x = self.bn2(x, label) + else: + x = self.bn2(x) + x = self.activation(x) + x = self.conv2d2(x) + + x0 = F.interpolate(x0, scale_factor=2, mode='nearest') # upsample + x0 = self.conv2d0(x0) + + out = x + x0 + return out + + +class Generator(nn.Module): + """Generator.""" + + def __init__(self, z_dim, shared_dim, img_size, g_conv_dim, g_spectral_norm, attention, attention_after_nth_gen_block, activation_fn, + conditional_strategy, num_classes, initialize, G_depth, mixed_precision): + super(Generator, self).__init__() + g_in_dims_collection = {"32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4], + "64": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2], + "128": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2], + "256": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2], + "512": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim]} + + g_out_dims_collection = {"32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4], + "64": [g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim], + "128": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim], + "256": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim], + "512": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim, g_conv_dim]} + bottom_collection = {"32": 4, "64": 4, "128": 4, "256": 4, "512": 4} + + self.z_dim = z_dim + self.shared_dim = shared_dim + self.num_classes = num_classes + self.mixed_precision = mixed_precision + conditional_bn = True if conditional_strategy in [ + "ACGAN", "ProjGAN", "ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN"] else False + + self.in_dims = g_in_dims_collection[str(img_size)] + self.out_dims = g_out_dims_collection[str(img_size)] + self.bottom = bottom_collection[str(img_size)] + self.n_blocks = len(self.in_dims) + self.chunk_size = z_dim // (self.n_blocks + 1) + self.z_dims_after_concat = self.chunk_size + self.shared_dim + assert self.z_dim % (self.n_blocks + 1) == 0, "z_dim should be divided by the number of blocks " + + if g_spectral_norm: + self.linear0 = snlinear(in_features=self.chunk_size, + out_features=self.in_dims[0] * self.bottom * self.bottom) + else: + self.linear0 = linear(in_features=self.chunk_size, out_features=self.in_dims[0] * self.bottom * self.bottom) + + self.shared = embedding(self.num_classes, self.shared_dim) + + self.blocks = [] + for index in range(self.n_blocks): + self.blocks += [[GenBlock(in_channels=self.in_dims[index], + out_channels=self.out_dims[index], + g_spectral_norm=g_spectral_norm, + activation_fn=activation_fn, + conditional_bn=conditional_bn, + z_dims_after_concat=self.z_dims_after_concat)]] + + if index + 1 == attention_after_nth_gen_block and attention is True: + self.blocks += [[Self_Attn(self.out_dims[index], g_spectral_norm)]] + + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + + self.bn4 = batchnorm_2d(in_features=self.out_dims[-1]) + + if activation_fn == "ReLU": + self.activation = nn.ReLU(inplace=True) + elif activation_fn == "Leaky_ReLU": + self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif activation_fn == "ELU": + self.activation = nn.ELU(alpha=1.0, inplace=True) + elif activation_fn == "GELU": + self.activation = nn.GELU() + else: + raise NotImplementedError + + if g_spectral_norm: + self.conv2d5 = 
snconv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1) + else: + self.conv2d5 = conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1) + + self.tanh = nn.Tanh() + + # Weight init + if initialize is not False: + init_weights(self.modules, initialize) + + def forward(self, z, label, shared_label=None, evaluation=False): + with torch.cuda.amp.autocast() if self.mixed_precision is True and evaluation is False else dummy_context_mgr() as mp: + zs = torch.split(z, self.chunk_size, 1) + z = zs[0] + if shared_label is None: + shared_label = self.shared(label) + else: + pass + labels = [torch.cat([shared_label, item], 1) for item in zs[1:]] + + act = self.linear0(z) + act = act.view(-1, self.in_dims[0], self.bottom, self.bottom) + counter = 0 + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + if isinstance(block, Self_Attn): + act = block(act) + else: + act = block(act, labels[counter]) + counter += 1 + + act = self.bn4(act) + act = self.activation(act) + act = self.conv2d5(act) + out = self.tanh(act) + return out + + +class DiscOptBlock(nn.Module): + def __init__(self, in_channels, out_channels, d_spectral_norm, activation_fn): + super(DiscOptBlock, self).__init__() + self.d_spectral_norm = d_spectral_norm + + if d_spectral_norm: + self.conv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=1, stride=1, padding=0) + self.conv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + self.conv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + else: + self.conv2d0 = conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=1, stride=1, padding=0) + self.conv2d1 = conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + self.conv2d2 = conv2d(in_channels=out_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + + self.bn0 = batchnorm_2d(in_features=in_channels) + self.bn1 = batchnorm_2d(in_features=out_channels) + + if activation_fn == "ReLU": + self.activation = nn.ReLU(inplace=True) + elif activation_fn == "Leaky_ReLU": + self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif activation_fn == "ELU": + self.activation = nn.ELU(alpha=1.0, inplace=True) + elif activation_fn == "GELU": + self.activation = nn.GELU() + else: + raise NotImplementedError + + self.average_pooling = nn.AvgPool2d(2) + + def forward(self, x): + x0 = x + x = self.conv2d1(x) + if self.d_spectral_norm is False: + x = self.bn1(x) + x = self.activation(x) + x = self.conv2d2(x) + x = self.average_pooling(x) + + x0 = self.average_pooling(x0) + if self.d_spectral_norm is False: + x0 = self.bn0(x0) + x0 = self.conv2d0(x0) + + out = x + x0 + return out + + +class DiscBlock(nn.Module): + def __init__(self, in_channels, out_channels, d_spectral_norm, activation_fn, downsample=True): + super(DiscBlock, self).__init__() + self.d_spectral_norm = d_spectral_norm + self.downsample = downsample + + if activation_fn == "ReLU": + self.activation = nn.ReLU(inplace=True) + elif activation_fn == "Leaky_ReLU": + self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif activation_fn == "ELU": + self.activation = nn.ELU(alpha=1.0, inplace=True) + elif activation_fn == "GELU": + self.activation = nn.GELU() + else: + raise NotImplementedError + + self.ch_mismatch = False + if in_channels != out_channels: + 
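+            # Record whether the skip path needs a channel change: the 1x1 shortcut
+            # conv (conv2d0) is only created, and only applied in forward(), when the
+            # channel count differs or the block downsamples.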
self.ch_mismatch = True + + if d_spectral_norm: + if self.ch_mismatch or downsample: + self.conv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=1, stride=1, padding=0) + self.conv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + self.conv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + else: + if self.ch_mismatch or downsample: + self.conv2d0 = conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=1, stride=1, padding=0) + self.conv2d1 = conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + self.conv2d2 = conv2d(in_channels=out_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + + if self.ch_mismatch or downsample: + self.bn0 = batchnorm_2d(in_features=in_channels) + self.bn1 = batchnorm_2d(in_features=in_channels) + self.bn2 = batchnorm_2d(in_features=out_channels) + + self.average_pooling = nn.AvgPool2d(2) + + def forward(self, x): + x0 = x + + if self.d_spectral_norm is False: + x = self.bn1(x) + x = self.activation(x) + x = self.conv2d1(x) + if self.d_spectral_norm is False: + x = self.bn2(x) + x = self.activation(x) + x = self.conv2d2(x) + if self.downsample: + x = self.average_pooling(x) + + if self.downsample or self.ch_mismatch: + if self.d_spectral_norm is False: + x0 = self.bn0(x0) + x0 = self.conv2d0(x0) + if self.downsample: + x0 = self.average_pooling(x0) + + out = x + x0 + return out + + +class Discriminator(nn.Module): + """Discriminator.""" + + def __init__(self, img_size, d_conv_dim, d_spectral_norm, attention, attention_after_nth_dis_block, activation_fn, conditional_strategy, + hypersphere_dim, num_classes, nonlinear_embed, normalize_embed, initialize, D_depth, mixed_precision): + super(Discriminator, self).__init__() + d_in_dims_collection = {"32": [3] + [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2], + "64": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8], + "128": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16], + "256": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16], + "512": [3] + [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16]} + + d_out_dims_collection = {"32": [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2], + "64": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16], + "128": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16], + "256": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16], + "512": [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16]} + + d_down = {"32": [True, True, False, False], + "64": [True, True, True, True, False], + "128": [True, True, True, True, True, False], + "256": [True, True, True, True, True, True, False], + "512": [True, True, True, True, True, True, True, False]} + + self.nonlinear_embed = nonlinear_embed + self.normalize_embed = normalize_embed + self.conditional_strategy = conditional_strategy + self.mixed_precision = mixed_precision + + self.in_dims = d_in_dims_collection[str(img_size)] + self.out_dims = d_out_dims_collection[str(img_size)] + down = d_down[str(img_size)] + + self.blocks = [] + for index in 
range(len(self.in_dims)): + if index == 0: + self.blocks += [[DiscOptBlock(in_channels=self.in_dims[index], + out_channels=self.out_dims[index], + d_spectral_norm=d_spectral_norm, + activation_fn=activation_fn)]] + else: + self.blocks += [[DiscBlock(in_channels=self.in_dims[index], + out_channels=self.out_dims[index], + d_spectral_norm=d_spectral_norm, + activation_fn=activation_fn, + downsample=down[index])]] + + if index + 1 == attention_after_nth_dis_block and attention is True: + self.blocks += [[Self_Attn(self.out_dims[index], d_spectral_norm)]] + + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + + if activation_fn == "ReLU": + self.activation = nn.ReLU(inplace=True) + elif activation_fn == "Leaky_ReLU": + self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif activation_fn == "ELU": + self.activation = nn.ELU(alpha=1.0, inplace=True) + elif activation_fn == "GELU": + self.activation = nn.GELU() + else: + raise NotImplementedError + + if d_spectral_norm: + self.linear1 = snlinear(in_features=self.out_dims[-1], out_features=1) + if self.conditional_strategy in ['ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN']: + self.linear2 = snlinear(in_features=self.out_dims[-1], out_features=hypersphere_dim) + if self.nonlinear_embed: + self.linear3 = snlinear(in_features=hypersphere_dim, out_features=hypersphere_dim) + self.embedding = sn_embedding(num_classes, hypersphere_dim) + elif self.conditional_strategy == 'ProjGAN': + self.embedding = sn_embedding(num_classes, self.out_dims[-1]) + elif self.conditional_strategy == 'ACGAN': + self.linear4 = snlinear(in_features=self.out_dims[-1], out_features=num_classes) + else: + pass + else: + self.linear1 = linear(in_features=self.out_dims[-1], out_features=1) + if self.conditional_strategy in ['ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN']: + self.linear2 = linear(in_features=self.out_dims[-1], out_features=hypersphere_dim) + if self.nonlinear_embed: + self.linear3 = linear(in_features=hypersphere_dim, out_features=hypersphere_dim) + self.embedding = embedding(num_classes, hypersphere_dim) + elif self.conditional_strategy == 'ProjGAN': + self.embedding = embedding(num_classes, self.out_dims[-1]) + elif self.conditional_strategy == 'ACGAN': + self.linear4 = linear(in_features=self.out_dims[-1], out_features=num_classes) + else: + pass + + # Weight init + if initialize is not False: + init_weights(self.modules, initialize) + + def forward(self, x, label, evaluation=False): + with torch.cuda.amp.autocast() if self.mixed_precision is True and evaluation is False else dummy_context_mgr() as mp: + h = x + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + h = block(h) + h = self.activation(h) + h = torch.sum(h, dim=[2, 3]) + + if self.conditional_strategy == 'no': + authen_output = torch.squeeze(self.linear1(h)) + return authen_output + + elif self.conditional_strategy in ['ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN']: + authen_output = torch.squeeze(self.linear1(h)) + cls_proxy = self.embedding(label) + cls_embed = self.linear2(h) + if self.nonlinear_embed: + cls_embed = self.linear3(self.activation(cls_embed)) + if self.normalize_embed: + cls_proxy = F.normalize(cls_proxy, dim=1) + cls_embed = F.normalize(cls_embed, dim=1) + return cls_proxy, cls_embed, authen_output + + elif self.conditional_strategy == 'ProjGAN': + authen_output = torch.squeeze(self.linear1(h)) + proj = torch.sum(torch.mul(self.embedding(label), h), 1) + return proj + authen_output + + elif self.conditional_strategy == 
'ACGAN': + authen_output = torch.squeeze(self.linear1(h)) + cls_output = self.linear4(h) + return cls_output, authen_output + + else: + raise NotImplementedError diff --git a/pytorch_pretrained_gans/StudioGAN/models/big_resnet_deep.py b/pytorch_pretrained_gans/StudioGAN/models/big_resnet_deep.py new file mode 100644 index 0000000..5adec4c --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/models/big_resnet_deep.py @@ -0,0 +1,382 @@ +# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN +# The MIT License (MIT) +# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details + +# models/big_resnet_deep.py + + +from ..utils.model_ops import * +from ..utils.misc import * + +import torch +import torch.nn as nn +import torch.nn.functional as F + + + +class GenBlock(nn.Module): + def __init__(self, in_channels, out_channels, g_spectral_norm, activation_fn, conditional_bn, z_dims_after_concat, + upsample, channel_ratio=4): + super(GenBlock, self).__init__() + self.conditional_bn = conditional_bn + self.in_channels, self.out_channels = in_channels, out_channels + self.upsample = upsample + self.hidden_channels = self.in_channels//channel_ratio + + if self.conditional_bn: + self.bn1 = ConditionalBatchNorm2d_for_skip_and_shared(num_features=in_channels, z_dims_after_concat=z_dims_after_concat, + spectral_norm=g_spectral_norm) + self.bn2 = ConditionalBatchNorm2d_for_skip_and_shared(num_features=self.hidden_channels, z_dims_after_concat=z_dims_after_concat, + spectral_norm=g_spectral_norm) + self.bn3 = ConditionalBatchNorm2d_for_skip_and_shared(num_features=self.hidden_channels, z_dims_after_concat=z_dims_after_concat, + spectral_norm=g_spectral_norm) + self.bn4 = ConditionalBatchNorm2d_for_skip_and_shared(num_features=self.hidden_channels, z_dims_after_concat=z_dims_after_concat, + spectral_norm=g_spectral_norm) + else: + self.bn1 = batchnorm_2d(in_features=in_channels) + self.bn2 = batchnorm_2d(in_features=self.hidden_channels) + self.bn3 = batchnorm_2d(in_features=self.hidden_channels) + self.bn4 = batchnorm_2d(in_features=self.hidden_channels) + + if activation_fn == "ReLU": + self.activation = nn.ReLU(inplace=True) + elif activation_fn == "Leaky_ReLU": + self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif activation_fn == "ELU": + self.activation = nn.ELU(alpha=1.0, inplace=True) + elif activation_fn == "GELU": + self.activation = nn.GELU() + else: + raise NotImplementedError + + if g_spectral_norm: + self.conv2d1 = snconv2d(in_channels=in_channels, out_channels=self.hidden_channels, kernel_size=1, stride=1, padding=0) + self.conv2d2 = snconv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=3, stride=1, padding=1) + self.conv2d3 = snconv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=3, stride=1, padding=1) + self.conv2d4 = snconv2d(in_channels=self.hidden_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0) + else: + self.conv2d1 = conv2d(in_channels=in_channels, out_channels=self.hidden_channels, kernel_size=1, stride=1, padding=0) + self.conv2d2 = conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=3, stride=1, padding=1) + self.conv2d3 = conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=3, stride=1, padding=1) + self.conv2d4 = conv2d(in_channels=self.hidden_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0) + + + def forward(self, x, label): + 
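+        # In this deep variant the skip connection handles channel reduction by slicing:
+        # when the block shrinks the channel count, the shortcut keeps only the first
+        # out_channels feature maps instead of using a learned 1x1 projection.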
if self.in_channels != self.out_channels: + x0 = x[:, :self.out_channels] + else: + x0 = x + + x = self.conv2d1(self.activation(self.bn1(x, label))) + x = self.activation(self.bn2(x, label)) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') # upsample + x = self.conv2d2(x) + x = self.conv2d3(self.activation(self.bn3(x, label))) + x = self.conv2d4(self.activation(self.bn4(x, label))) + + if self.upsample: + x0 = F.interpolate(x0, scale_factor=2, mode='nearest') # upsample + out = x + x0 + return out + + +class Generator(nn.Module): + """Generator.""" + def __init__(self, z_dim, shared_dim, img_size, g_conv_dim, g_spectral_norm, attention, attention_after_nth_gen_block, activation_fn, + conditional_strategy, num_classes, initialize, G_depth, mixed_precision): + super(Generator, self).__init__() + g_in_dims_collection = {"32": [g_conv_dim*4, g_conv_dim*4, g_conv_dim*4], + "64": [g_conv_dim*16, g_conv_dim*8, g_conv_dim*4, g_conv_dim*2], + "128": [g_conv_dim*16, g_conv_dim*16, g_conv_dim*8, g_conv_dim*4, g_conv_dim*2], + "256": [g_conv_dim*16, g_conv_dim*16, g_conv_dim*8, g_conv_dim*8, g_conv_dim*4, g_conv_dim*2], + "512": [g_conv_dim*16, g_conv_dim*16, g_conv_dim*8, g_conv_dim*8, g_conv_dim*4, g_conv_dim*2, g_conv_dim]} + + g_out_dims_collection = {"32": [g_conv_dim*4, g_conv_dim*4, g_conv_dim*4], + "64": [g_conv_dim*8, g_conv_dim*4, g_conv_dim*2, g_conv_dim], + "128": [g_conv_dim*16, g_conv_dim*8, g_conv_dim*4, g_conv_dim*2, g_conv_dim], + "256": [g_conv_dim*16, g_conv_dim*8, g_conv_dim*8, g_conv_dim*4, g_conv_dim*2, g_conv_dim], + "512": [g_conv_dim*16, g_conv_dim*8, g_conv_dim*8, g_conv_dim*4, g_conv_dim*2, g_conv_dim, g_conv_dim]} + bottom_collection = {"32": 4, "64": 4, "128": 4, "256": 4, "512": 4} + + self.z_dim = z_dim + self.shared_dim = shared_dim + self.num_classes = num_classes + self.mixed_precision = mixed_precision + conditional_bn = True if conditional_strategy in ["ACGAN", "ProjGAN", "ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN"] else False + + self.in_dims = g_in_dims_collection[str(img_size)] + self.out_dims = g_out_dims_collection[str(img_size)] + self.bottom = bottom_collection[str(img_size)] + self.n_blocks = len(self.in_dims) + self.z_dims_after_concat = self.z_dim + self.shared_dim + + if g_spectral_norm: + self.linear0 = snlinear(in_features=self.z_dims_after_concat, out_features=self.in_dims[0]*self.bottom*self.bottom) + else: + self.linear0 = linear(in_features=self.z_dims_after_concat, out_features=self.in_dims[0]*self.bottom*self.bottom) + + self.shared = embedding(self.num_classes, self.shared_dim) + + self.blocks = [] + for index in range(self.n_blocks): + self.blocks += [[GenBlock(in_channels=self.in_dims[index], + out_channels=self.in_dims[index] if g_index == 0 else self.out_dims[index], + g_spectral_norm=g_spectral_norm, + activation_fn=activation_fn, + conditional_bn=conditional_bn, + z_dims_after_concat=self.z_dims_after_concat, + upsample=True if g_index == (G_depth-1) else False)] + for g_index in range(G_depth)] + + if index+1 == attention_after_nth_gen_block and attention is True: + self.blocks += [[Self_Attn(self.out_dims[index], g_spectral_norm)]] + + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + + self.bn4 = batchnorm_2d(in_features=self.out_dims[-1]) + + if activation_fn == "ReLU": + self.activation = nn.ReLU(inplace=True) + elif activation_fn == "Leaky_ReLU": + self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif activation_fn == "ELU": + self.activation = nn.ELU(alpha=1.0, 
inplace=True) + elif activation_fn == "GELU": + self.activation = nn.GELU() + else: + raise NotImplementedError + + if g_spectral_norm: + self.conv2d5 = snconv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1) + else: + self.conv2d5 = conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1) + + self.tanh = nn.Tanh() + + # Weight init + if initialize is not False: + init_weights(self.modules, initialize) + + def forward(self, z, label, shared_label=None, evaluation=False): + with torch.cuda.amp.autocast() if self.mixed_precision is True and evaluation is False else dummy_context_mgr() as mp: + if shared_label is None: + shared_label = self.shared(label) + else: + pass + z = torch.cat([shared_label, z], 1) + + act = self.linear0(z) + act = act.view(-1, self.in_dims[0], self.bottom, self.bottom) + counter = 0 + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + if isinstance(block, Self_Attn): + act = block(act) + else: + act = block(act, z) + counter +=1 + + act = self.bn4(act) + act = self.activation(act) + act = self.conv2d5(act) + out = self.tanh(act) + return out + + +class DiscBlock(nn.Module): + def __init__(self, in_channels, out_channels, d_spectral_norm, activation_fn, downsample=True, channel_ratio=4): + super(DiscBlock, self).__init__() + self.downsample = downsample + self.d_spectral_norm = d_spectral_norm + hidden_channels = out_channels//channel_ratio + + if activation_fn == "ReLU": + self.activation = nn.ReLU(inplace=True) + elif activation_fn == "Leaky_ReLU": + self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif activation_fn == "ELU": + self.activation = nn.ELU(alpha=1.0, inplace=True) + elif activation_fn == "GELU": + raise NotImplementedError + + if self.d_spectral_norm: + self.conv2d1 = snconv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=1, stride=1, padding=0) + self.conv2d2 = snconv2d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1) + self.conv2d3 = snconv2d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1) + self.conv2d4 = snconv2d(in_channels=hidden_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0) + else: + self.conv2d1 = conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=1, stride=1, padding=0) + self.conv2d2 = conv2d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1) + self.conv2d3 = conv2d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1) + self.conv2d4 = conv2d(in_channels=hidden_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0) + + self.learnable_sc = True if (in_channels != out_channels) else False + if self.learnable_sc: + if self.d_spectral_norm: + self.conv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels-in_channels, kernel_size=1, stride=1, padding=0) + else: + self.conv2d0 = conv2d(in_channels=in_channels, out_channels=out_channels-in_channels, kernel_size=1, stride=1, padding=0) + + if self.downsample: + self.average_pooling = nn.AvgPool2d(2) + + + def forward(self, x): + x0 = x + + x = self.activation(x) + x = self.conv2d1(x) + + x = self.conv2d2(self.activation(x)) + x = self.conv2d3(self.activation(x)) + x = self.activation(x) + + if self.downsample: + x = self.average_pooling(x) + + x = self.conv2d4(x) + + if self.downsample: + x0 = 
self.average_pooling(x0) + if self.learnable_sc: + x0 = torch.cat([x0, self.conv2d0(x0)], 1) + + out = x + x0 + return out + + +class Discriminator(nn.Module): + """Discriminator.""" + def __init__(self, img_size, d_conv_dim, d_spectral_norm, attention, attention_after_nth_dis_block, activation_fn, conditional_strategy, + hypersphere_dim, num_classes, nonlinear_embed, normalize_embed, initialize, D_depth, mixed_precision): + super(Discriminator, self).__init__() + d_in_dims_collection = {"32": [3] + [d_conv_dim*2, d_conv_dim*2, d_conv_dim*2], + "64": [3] +[d_conv_dim, d_conv_dim*2, d_conv_dim*4, d_conv_dim*8], + "128": [3] +[d_conv_dim, d_conv_dim*2, d_conv_dim*4, d_conv_dim*8, d_conv_dim*16], + "256": [3] +[d_conv_dim, d_conv_dim*2, d_conv_dim*4, d_conv_dim*8, d_conv_dim*8, d_conv_dim*16], + "512": [3] +[d_conv_dim, d_conv_dim, d_conv_dim*2, d_conv_dim*4, d_conv_dim*8, d_conv_dim*8, d_conv_dim*16]} + + d_out_dims_collection = {"32": [d_conv_dim*2, d_conv_dim*2, d_conv_dim*2, d_conv_dim*2], + "64": [d_conv_dim, d_conv_dim*2, d_conv_dim*4, d_conv_dim*8, d_conv_dim*16], + "128": [d_conv_dim, d_conv_dim*2, d_conv_dim*4, d_conv_dim*8, d_conv_dim*16, d_conv_dim*16], + "256": [d_conv_dim, d_conv_dim*2, d_conv_dim*4, d_conv_dim*8, d_conv_dim*8, d_conv_dim*16, d_conv_dim*16], + "512": [d_conv_dim, d_conv_dim, d_conv_dim*2, d_conv_dim*4, d_conv_dim*8, d_conv_dim*8, d_conv_dim*16, d_conv_dim*16]} + + d_down = {"32": [False, True, True, True, True], + "64": [False, True, True, True, True, True], + "128": [False, True, True, True, True, True], + "256": [False, True, True, True, True, True, True], + "512": [False, True, True, True, True, True, True, True]} + + self.nonlinear_embed = nonlinear_embed + self.normalize_embed = normalize_embed + self.conditional_strategy = conditional_strategy + self.mixed_precision = mixed_precision + + self.in_dims = d_in_dims_collection[str(img_size)] + self.out_dims = d_out_dims_collection[str(img_size)] + down = d_down[str(img_size)] + + if d_spectral_norm: + self.input_conv = snconv2d(in_channels=self.in_dims[0], out_channels=self.out_dims[0], kernel_size=3, stride=1, padding=1) + else: + self.input_conv = conv2d(in_channels=self.in_dims[0], out_channels=self.out_dims[0], kernel_size=3, stride=1, padding=1) + + self.blocks = [] + for index in range(len(self.in_dims)): + if index == 0: + self.blocks += [[self.input_conv]] + else: + self.blocks += [[DiscBlock(in_channels=self.in_dims[index] if d_index==0 else self.out_dims[index], + out_channels=self.out_dims[index], + d_spectral_norm=d_spectral_norm, + activation_fn=activation_fn, + downsample=True if down[index] and d_index==0 else False)] + for d_index in range(D_depth)] + + if index == attention_after_nth_dis_block and attention is True: + self.blocks += [[Self_Attn(self.out_dims[index], d_spectral_norm)]] + + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + + if activation_fn == "ReLU": + self.activation = nn.ReLU(inplace=True) + elif activation_fn == "Leaky_ReLU": + self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif activation_fn == "ELU": + self.activation = nn.ELU(alpha=1.0, inplace=True) + elif activation_fn == "GELU": + self.activation = nn.GELU() + else: + raise NotImplementedError + + if d_spectral_norm: + self.linear1 = snlinear(in_features=self.out_dims[-1], out_features=1) + if self.conditional_strategy in ['ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN']: + self.linear2 = snlinear(in_features=self.out_dims[-1], out_features=hypersphere_dim) + if 
self.nonlinear_embed: + self.linear3 = snlinear(in_features=hypersphere_dim, out_features=hypersphere_dim) + self.embedding = sn_embedding(num_classes, hypersphere_dim) + elif self.conditional_strategy == 'ProjGAN': + self.embedding = sn_embedding(num_classes, self.out_dims[-1]) + elif self.conditional_strategy == 'ACGAN': + self.linear4 = snlinear(in_features=self.out_dims[-1], out_features=num_classes) + else: + pass + else: + self.linear1 = linear(in_features=self.out_dims[-1], out_features=1) + if self.conditional_strategy in ['ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN']: + self.linear2 = linear(in_features=self.out_dims[-1], out_features=hypersphere_dim) + if self.nonlinear_embed: + self.linear3 = linear(in_features=hypersphere_dim, out_features=hypersphere_dim) + self.embedding = embedding(num_classes, hypersphere_dim) + elif self.conditional_strategy == 'ProjGAN': + self.embedding = embedding(num_classes, self.out_dims[-1]) + elif self.conditional_strategy == 'ACGAN': + self.linear4 = linear(in_features=self.out_dims[-1], out_features=num_classes) + else: + pass + + # Weight init + if initialize is not False: + init_weights(self.modules, initialize) + + + def forward(self, x, label, evaluation=False): + with torch.cuda.amp.autocast() if self.mixed_precision is True and evaluation is False else dummy_context_mgr() as mp: + h = x + + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + h = block(h) + h = self.activation(h) + h = torch.sum(h, dim=[2,3]) + + if self.conditional_strategy == 'no': + authen_output = torch.squeeze(self.linear1(h)) + return authen_output + + elif self.conditional_strategy in ['ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN']: + authen_output = torch.squeeze(self.linear1(h)) + cls_proxy = self.embedding(label) + cls_embed = self.linear2(h) + if self.nonlinear_embed: + cls_embed = self.linear3(self.activation(cls_embed)) + if self.normalize_embed: + cls_proxy = F.normalize(cls_proxy, dim=1) + cls_embed = F.normalize(cls_embed, dim=1) + return cls_proxy, cls_embed, authen_output + + elif self.conditional_strategy == 'ProjGAN': + authen_output = torch.squeeze(self.linear1(h)) + proj = torch.sum(torch.mul(self.embedding(label), h), 1) + return proj + authen_output + + elif self.conditional_strategy == 'ACGAN': + authen_output = torch.squeeze(self.linear1(h)) + cls_output = self.linear4(h) + return cls_output, authen_output + + else: + raise NotImplementedError diff --git a/pytorch_pretrained_gans/StudioGAN/models/resnet.py b/pytorch_pretrained_gans/StudioGAN/models/resnet.py new file mode 100644 index 0000000..e58adac --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/models/resnet.py @@ -0,0 +1,422 @@ +# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN +# The MIT License (MIT) +# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details + +# models/resnet.py + + +from ..utils.model_ops import * +from ..utils.misc import * + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GenBlock(nn.Module): + def __init__(self, in_channels, out_channels, g_spectral_norm, activation_fn, conditional_bn, num_classes): + super(GenBlock, self).__init__() + self.conditional_bn = conditional_bn + + if self.conditional_bn: + self.bn1 = ConditionalBatchNorm2d(num_features=in_channels, num_classes=num_classes, + spectral_norm=g_spectral_norm) + self.bn2 = ConditionalBatchNorm2d(num_features=out_channels, num_classes=num_classes, + spectral_norm=g_spectral_norm) + else: + self.bn1 = 
batchnorm_2d(in_features=in_channels) + self.bn2 = batchnorm_2d(in_features=out_channels) + + if activation_fn == "ReLU": + self.activation = nn.ReLU(inplace=True) + elif activation_fn == "Leaky_ReLU": + self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif activation_fn == "ELU": + self.activation = nn.ELU(alpha=1.0, inplace=True) + elif activation_fn == "GELU": + self.activation = nn.GELU() + else: + raise NotImplementedError + + if g_spectral_norm: + self.conv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=1, stride=1, padding=0) + self.conv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + self.conv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + else: + self.conv2d0 = conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=1, stride=1, padding=0) + self.conv2d1 = conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + self.conv2d2 = conv2d(in_channels=out_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + + def forward(self, x, label): + x0 = x + + if self.conditional_bn: + x = self.bn1(x, label) + else: + x = self.bn1(x) + x = self.activation(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv2d1(x) + if self.conditional_bn: + x = self.bn2(x, label) + else: + x = self.bn2(x) + x = self.activation(x) + x = self.conv2d2(x) + + x0 = F.interpolate(x0, scale_factor=2, mode='nearest') + x0 = self.conv2d0(x0) + + out = x + x0 + return out + + +class Generator(nn.Module): + """Generator.""" + + def __init__(self, z_dim, shared_dim, img_size, g_conv_dim, g_spectral_norm, attention, attention_after_nth_gen_block, activation_fn, + conditional_strategy, num_classes, initialize, G_depth, mixed_precision): + super(Generator, self).__init__() + g_in_dims_collection = {"32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4], + "64": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2], + "128": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2], + "256": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2], + "512": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim]} + + g_out_dims_collection = {"32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4], + "64": [g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim], + "128": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim], + "256": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim], + "512": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim, g_conv_dim]} + bottom_collection = {"32": 4, "64": 4, "128": 4, "256": 4, "512": 4} + + self.z_dim = z_dim + self.num_classes = num_classes + self.mixed_precision = mixed_precision + conditional_bn = True if conditional_strategy in [ + "ACGAN", "ProjGAN", "ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN"] else False + + self.in_dims = g_in_dims_collection[str(img_size)] + self.out_dims = g_out_dims_collection[str(img_size)] + self.bottom = bottom_collection[str(img_size)] + + if g_spectral_norm: + self.linear0 = snlinear(in_features=self.z_dim, out_features=self.in_dims[0] * self.bottom * self.bottom) + else: + self.linear0 = linear(in_features=self.z_dim, out_features=self.in_dims[0] 
* self.bottom * self.bottom) + + self.blocks = [] + for index in range(len(self.in_dims)): + self.blocks += [[GenBlock(in_channels=self.in_dims[index], + out_channels=self.out_dims[index], + g_spectral_norm=g_spectral_norm, + activation_fn=activation_fn, + conditional_bn=conditional_bn, + num_classes=self.num_classes)]] + + if index + 1 == attention_after_nth_gen_block and attention is True: + self.blocks += [[Self_Attn(self.out_dims[index], g_spectral_norm)]] + + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + + self.bn4 = batchnorm_2d(in_features=self.out_dims[-1]) + + if activation_fn == "ReLU": + self.activation = nn.ReLU(inplace=True) + elif activation_fn == "Leaky_ReLU": + self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif activation_fn == "ELU": + self.activation = nn.ELU(alpha=1.0, inplace=True) + elif activation_fn == "GELU": + self.activation = nn.GELU() + else: + raise NotImplementedError + + if g_spectral_norm: + self.conv2d5 = snconv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1) + else: + self.conv2d5 = conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1) + + self.tanh = nn.Tanh() + + # Weight init + if initialize is not False: + init_weights(self.modules, initialize) + + def forward(self, z, label, evaluation=False): + with torch.cuda.amp.autocast() if self.mixed_precision is True and evaluation is False else dummy_context_mgr() as mp: + act = self.linear0(z) + act = act.view(-1, self.in_dims[0], self.bottom, self.bottom) + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + if isinstance(block, Self_Attn): + act = block(act) + else: + act = block(act, label) + act = self.bn4(act) + act = self.activation(act) + act = self.conv2d5(act) + out = self.tanh(act) + return out + + +class DiscOptBlock(nn.Module): + def __init__(self, in_channels, out_channels, d_spectral_norm, activation_fn): + super(DiscOptBlock, self).__init__() + self.d_spectral_norm = d_spectral_norm + + if d_spectral_norm: + self.conv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=1, stride=1, padding=0) + self.conv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + self.conv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + else: + self.conv2d0 = conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=1, stride=1, padding=0) + self.conv2d1 = conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + self.conv2d2 = conv2d(in_channels=out_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + + self.bn0 = batchnorm_2d(in_features=in_channels) + self.bn1 = batchnorm_2d(in_features=out_channels) + + if activation_fn == "ReLU": + self.activation = nn.ReLU(inplace=True) + elif activation_fn == "Leaky_ReLU": + self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif activation_fn == "ELU": + self.activation = nn.ELU(alpha=1.0, inplace=True) + elif activation_fn == "GELU": + self.activation = nn.GELU() + else: + raise NotImplementedError + + self.average_pooling = nn.AvgPool2d(2) + + def forward(self, x): + x0 = x + + x = self.conv2d1(x) + if self.d_spectral_norm is False: + x = self.bn1(x) + x = self.activation(x) + x = self.conv2d2(x) + x = self.average_pooling(x) + + x0 = self.average_pooling(x0) + if self.d_spectral_norm is False: 
+ x0 = self.bn0(x0) + x0 = self.conv2d0(x0) + + out = x + x0 + return out + + +class DiscBlock(nn.Module): + def __init__(self, in_channels, out_channels, d_spectral_norm, activation_fn, downsample=True): + super(DiscBlock, self).__init__() + self.d_spectral_norm = d_spectral_norm + self.downsample = downsample + + if activation_fn == "ReLU": + self.activation = nn.ReLU(inplace=True) + elif activation_fn == "Leaky_ReLU": + self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif activation_fn == "ELU": + self.activation = nn.ELU(alpha=1.0, inplace=True) + elif activation_fn == "GELU": + self.activation = nn.GELU() + else: + raise NotImplementedError + + self.ch_mismatch = False + if in_channels != out_channels: + self.ch_mismatch = True + + if d_spectral_norm: + if self.ch_mismatch or downsample: + self.conv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=1, stride=1, padding=0) + self.conv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + self.conv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + else: + if self.ch_mismatch or downsample: + self.conv2d0 = conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=1, stride=1, padding=0) + self.conv2d1 = conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + self.conv2d2 = conv2d(in_channels=out_channels, out_channels=out_channels, + kernel_size=3, stride=1, padding=1) + + if self.ch_mismatch or downsample: + self.bn0 = batchnorm_2d(in_features=in_channels) + self.bn1 = batchnorm_2d(in_features=in_channels) + self.bn2 = batchnorm_2d(in_features=out_channels) + + self.average_pooling = nn.AvgPool2d(2) + + def forward(self, x): + x0 = x + if self.d_spectral_norm is False: + x = self.bn1(x) + x = self.activation(x) + x = self.conv2d1(x) + if self.d_spectral_norm is False: + x = self.bn2(x) + x = self.activation(x) + x = self.conv2d2(x) + if self.downsample: + x = self.average_pooling(x) + + if self.downsample or self.ch_mismatch: + if self.d_spectral_norm is False: + x0 = self.bn0(x0) + x0 = self.conv2d0(x0) + if self.downsample: + x0 = self.average_pooling(x0) + + out = x + x0 + return out + + +class Discriminator(nn.Module): + """Discriminator.""" + + def __init__(self, img_size, d_conv_dim, d_spectral_norm, attention, attention_after_nth_dis_block, activation_fn, conditional_strategy, + hypersphere_dim, num_classes, nonlinear_embed, normalize_embed, initialize, D_depth, mixed_precision): + super(Discriminator, self).__init__() + d_in_dims_collection = {"32": [3] + [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2], + "64": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8], + "128": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16], + "256": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16], + "512": [3] + [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16]} + + d_out_dims_collection = {"32": [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2], + "64": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16], + "128": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16], + "256": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16], + "512": 
[d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16]} + + d_down = {"32": [True, True, False, False], + "64": [True, True, True, True, False], + "128": [True, True, True, True, True, False], + "256": [True, True, True, True, True, True, False], + "512": [True, True, True, True, True, True, True, False]} + + self.nonlinear_embed = nonlinear_embed + self.normalize_embed = normalize_embed + self.conditional_strategy = conditional_strategy + self.mixed_precision = mixed_precision + + self.in_dims = d_in_dims_collection[str(img_size)] + self.out_dims = d_out_dims_collection[str(img_size)] + down = d_down[str(img_size)] + + self.blocks = [] + for index in range(len(self.in_dims)): + if index == 0: + self.blocks += [[DiscOptBlock(in_channels=self.in_dims[index], + out_channels=self.out_dims[index], + d_spectral_norm=d_spectral_norm, + activation_fn=activation_fn)]] + else: + self.blocks += [[DiscBlock(in_channels=self.in_dims[index], + out_channels=self.out_dims[index], + d_spectral_norm=d_spectral_norm, + activation_fn=activation_fn, + downsample=down[index])]] + + if index + 1 == attention_after_nth_dis_block and attention is True: + self.blocks += [[Self_Attn(self.out_dims[index], d_spectral_norm)]] + + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + + if activation_fn == "ReLU": + self.activation = nn.ReLU(inplace=True) + elif activation_fn == "Leaky_ReLU": + self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif activation_fn == "ELU": + self.activation = nn.ELU(alpha=1.0, inplace=True) + elif activation_fn == "GELU": + self.activation = nn.GELU() + else: + raise NotImplementedError + + if d_spectral_norm: + self.linear1 = snlinear(in_features=self.out_dims[-1], out_features=1) + if self.conditional_strategy in ['ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN']: + self.linear2 = snlinear(in_features=self.out_dims[-1], out_features=hypersphere_dim) + if self.nonlinear_embed: + self.linear3 = snlinear(in_features=hypersphere_dim, out_features=hypersphere_dim) + self.embedding = sn_embedding(num_classes, hypersphere_dim) + elif self.conditional_strategy == 'ProjGAN': + self.embedding = sn_embedding(num_classes, self.out_dims[-1]) + elif self.conditional_strategy == 'ACGAN': + self.linear4 = snlinear(in_features=self.out_dims[-1], out_features=num_classes) + else: + pass + else: + self.linear1 = linear(in_features=self.out_dims[-1], out_features=1) + if self.conditional_strategy in ['ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN']: + self.linear2 = linear(in_features=self.out_dims[-1], out_features=hypersphere_dim) + if self.nonlinear_embed: + self.linear3 = linear(in_features=hypersphere_dim, out_features=hypersphere_dim) + self.embedding = embedding(num_classes, hypersphere_dim) + elif self.conditional_strategy == 'ProjGAN': + self.embedding = embedding(num_classes, self.out_dims[-1]) + elif self.conditional_strategy == 'ACGAN': + self.linear4 = linear(in_features=self.out_dims[-1], out_features=num_classes) + else: + pass + + # Weight init + if initialize is not False: + init_weights(self.modules, initialize) + + def forward(self, x, label, evaluation=False): + with torch.cuda.amp.autocast() if self.mixed_precision is True and evaluation is False else dummy_context_mgr() as mp: + h = x + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + h = block(h) + h = self.activation(h) + h = torch.sum(h, dim=[2, 3]) + + if self.conditional_strategy == 'no': + 
authen_output = torch.squeeze(self.linear1(h)) + return authen_output + + elif self.conditional_strategy in ['ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN']: + authen_output = torch.squeeze(self.linear1(h)) + cls_proxy = self.embedding(label) + cls_embed = self.linear2(h) + if self.nonlinear_embed: + cls_embed = self.linear3(self.activation(cls_embed)) + if self.normalize_embed: + cls_proxy = F.normalize(cls_proxy, dim=1) + cls_embed = F.normalize(cls_embed, dim=1) + return cls_proxy, cls_embed, authen_output + + elif self.conditional_strategy == 'ProjGAN': + authen_output = torch.squeeze(self.linear1(h)) + proj = torch.sum(torch.mul(self.embedding(label), h), 1) + return authen_output + proj + + elif self.conditional_strategy == 'ACGAN': + authen_output = torch.squeeze(self.linear1(h)) + cls_output = self.linear4(h) + return cls_output, authen_output + + else: + raise NotImplementedError diff --git a/pytorch_pretrained_gans/StudioGAN/sync_batchnorm/batchnorm.py b/pytorch_pretrained_gans/StudioGAN/sync_batchnorm/batchnorm.py new file mode 100644 index 0000000..3c83edd --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/sync_batchnorm/batchnorm.py @@ -0,0 +1,421 @@ +""" +-*- coding: utf-8 -*- +File : batchnorm.py +Author : Jiayuan Mao +Email : maojiayuan@gmail.com +Date : 27/01/2018 + +This file is part of Synchronized-BatchNorm-PyTorch. +https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +Distributed under MIT License. + +MIT License + +Copyright (c) 2018 Jiayuan MAO + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+""" + + +import collections +import contextlib + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm + +try: + from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast +except ImportError: + ReduceAddCoalesced = Broadcast = None + +try: + from jactorch.parallel.comm import SyncMaster + from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback +except ImportError: + from .comm import SyncMaster + from .replicate import DataParallelWithCallback + +__all__ = [ + 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', + 'patch_sync_batchnorm', 'convert_model' +] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dimensions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): + assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' + + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + + if not self.track_running_stats: + import warnings + warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.') + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. 
+ # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + if hasattr(torch, 'no_grad'): + with torch.no_grad(): + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + else: + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
Default: ``True`` + + Shape:: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. 
+ + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + + +@contextlib.contextmanager +def patch_sync_batchnorm(): + import torch.nn as nn + + backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + + nn.BatchNorm1d = SynchronizedBatchNorm1d + nn.BatchNorm2d = SynchronizedBatchNorm2d + nn.BatchNorm3d = SynchronizedBatchNorm3d + + yield + + nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup + + +def convert_model(module): + """Traverse the input module and its child recursively + and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d + to SynchronizedBatchNorm*N*d + + Args: + module: the input module needs to be convert to SyncBN model + + Examples: + >>> import torch.nn as nn + >>> import torchvision + >>> # m is a standard pytorch model + >>> m = torchvision.models.resnet18(True) + >>> m = nn.DataParallel(m) + >>> # after convert, m is using SyncBN + >>> m = convert_model(m) + """ + if isinstance(module, torch.nn.DataParallel): + mod = module.module + mod = convert_model(mod) + mod = DataParallelWithCallback(mod, device_ids=module.device_ids) + return mod + + mod = module + for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, + torch.nn.modules.batchnorm.BatchNorm2d, + torch.nn.modules.batchnorm.BatchNorm3d], + [SynchronizedBatchNorm1d, + SynchronizedBatchNorm2d, + SynchronizedBatchNorm3d]): + if isinstance(module, pth_module): + mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) + mod.running_mean = 
module.running_mean + mod.running_var = module.running_var + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + + for name, child in module.named_children(): + mod.add_module(name, convert_model(child)) + + return mod diff --git a/pytorch_pretrained_gans/StudioGAN/sync_batchnorm/batchnorm_reimpl.py b/pytorch_pretrained_gans/StudioGAN/sync_batchnorm/batchnorm_reimpl.py new file mode 100644 index 0000000..738cc0a --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/sync_batchnorm/batchnorm_reimpl.py @@ -0,0 +1,99 @@ +""" +-*- coding: utf-8 -*- +File : batchnorm_reimpl.py +Author : Jiayuan Mao +Email : maojiayuan@gmail.com +Date : 27/01/2018 + +This file is part of Synchronized-BatchNorm-PyTorch. +https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +Distributed under MIT License. + +MIT License + +Copyright (c) 2018 Jiayuan MAO + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + + +import torch +import torch.nn as nn +import torch.nn.init as init + +__all__ = ['BatchNorm2dReimpl'] + + +class BatchNorm2dReimpl(nn.Module): + """ + A re-implementation of batch normalization, used for testing the numerical + stability. 
+ + Author: acgtyrant + See also: + https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 + """ + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super().__init__() + + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.weight = nn.Parameter(torch.empty(num_features)) + self.bias = nn.Parameter(torch.empty(num_features)) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_running_stats(self): + self.running_mean.zero_() + self.running_var.fill_(1) + + def reset_parameters(self): + self.reset_running_stats() + init.uniform_(self.weight) + init.zeros_(self.bias) + + def forward(self, input_): + batchsize, channels, height, width = input_.size() + numel = batchsize * height * width + input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) + sum_ = input_.sum(1) + sum_of_square = input_.pow(2).sum(1) + mean = sum_ / numel + sumvar = sum_of_square - sum_ * mean + + self.running_mean = ( + (1 - self.momentum) * self.running_mean + + self.momentum * mean.detach() + ) + unbias_var = sumvar / (numel - 1) + self.running_var = ( + (1 - self.momentum) * self.running_var + + self.momentum * unbias_var.detach() + ) + + bias_var = sumvar / numel + inv_std = 1 / (bias_var + self.eps).pow(0.5) + output = ( + (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * + self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) + + return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() + diff --git a/pytorch_pretrained_gans/StudioGAN/sync_batchnorm/comm.py b/pytorch_pretrained_gans/StudioGAN/sync_batchnorm/comm.py new file mode 100644 index 0000000..7469635 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/sync_batchnorm/comm.py @@ -0,0 +1,162 @@ +""" +-*- coding: utf-8 -*- +File : comm.py +Author : Jiayuan Mao +Email : maojiayuan@gmail.com +Date : 27/01/2018 + +This file is part of Synchronized-BatchNorm-PyTorch. +https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +Distributed under MIT License. + +MIT License + +Copyright (c) 2018 Jiayuan MAO + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. 
Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' 
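        # A rough sketch of the round trip implemented here, assuming a single
        # slave replica (device id 1) and an illustrative callback that simply
        # echoes each (identifier, message) pair back; the two run_* calls run
        # on different threads in practice, so they are shown commented out:
        #
        #   >>> master = SyncMaster(lambda msgs: [(i, m) for i, m in msgs])
        #   >>> pipe = master.register_slave(1)            # done during replication
        #   >>> # slave thread:  reply1 = pipe.run_slave(stats_from_device_1)
        #   >>> # master thread: reply0 = master.run_master(stats_from_device_0)
        #
        # Each slave reply is delivered through the FutureResult registered for
        # that device, and the trailing queue.get() handshake below leaves the
        # queue empty for the next forward pass.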
+ + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/pytorch_pretrained_gans/StudioGAN/sync_batchnorm/replicate.py b/pytorch_pretrained_gans/StudioGAN/sync_batchnorm/replicate.py new file mode 100644 index 0000000..6dcd3e3 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/sync_batchnorm/replicate.py @@ -0,0 +1,119 @@ +""" +-*- coding: utf-8 -*- +File : replicate.py +Author : Jiayuan Mao +Email : maojiayuan@gmail.com +Date : 27/01/2018 + +This file is part of Synchronized-BatchNorm-PyTorch. +https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +Distributed under MIT License. + +MIT License + +Copyright (c) 2018 Jiayuan MAO + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. 
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/pytorch_pretrained_gans/StudioGAN/sync_batchnorm/unittest.py b/pytorch_pretrained_gans/StudioGAN/sync_batchnorm/unittest.py new file mode 100644 index 0000000..c6ac864 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/sync_batchnorm/unittest.py @@ -0,0 +1,54 @@ +""" +-*- coding: utf-8 -*- +File : unittest.py +Author : Jiayuan Mao +Email : maojiayuan@gmail.com +Date : 27/01/2018 + +This file is part of Synchronized-BatchNorm-PyTorch. +https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +Distributed under MIT License. + +MIT License + +Copyright (c) 2018 Jiayuan MAO + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+""" + + +import unittest +import torch + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, x, y): + adiff = float((x - y).abs().max()) + if (y == 0).all(): + rdiff = 'NaN' + else: + rdiff = float((adiff / y).abs().max()) + + message = ( + 'Tensor close check failed\n' + 'adiff={}\n' + 'rdiff={}\n' + ).format(adiff, rdiff) + self.assertTrue(torch.allclose(x, y), message) + diff --git a/pytorch_pretrained_gans/StudioGAN/utils/ada.py b/pytorch_pretrained_gans/StudioGAN/utils/ada.py new file mode 100644 index 0000000..4d52b9d --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/ada.py @@ -0,0 +1,415 @@ +""" +MIT License +Copyright (c) 2019 Kim Seonghyeon +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + + +import math +import numpy as np + +from .ada_op import upfirdn2d + +import torch +from torch.nn import functional as F + + +SYM6 = ( + 0.015404109327027373, + 0.0034907120842174702, + -0.11799011114819057, + -0.048311742585633, + 0.4910559419267466, + 0.787641141030194, + 0.3379294217276218, + -0.07263752278646252, + -0.021060292512300564, + 0.04472490177066578, + 0.0017677118642428036, + -0.007800708325034148, +) + + +def translate_mat(t_x, t_y): + batch = t_x.shape[0] + + mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) + translate = torch.stack((t_x, t_y), 1) + mat[:, :2, 2] = translate + + return mat + + +def rotate_mat(theta): + batch = theta.shape[0] + + mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) + sin_t = torch.sin(theta) + cos_t = torch.cos(theta) + rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2) + mat[:, :2, :2] = rot + + return mat + + +def scale_mat(s_x, s_y): + batch = s_x.shape[0] + + mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) + mat[:, 0, 0] = s_x + mat[:, 1, 1] = s_y + + return mat + + +def translate3d_mat(t_x, t_y, t_z): + batch = t_x.shape[0] + + mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) + translate = torch.stack((t_x, t_y, t_z), 1) + mat[:, :3, 3] = translate + + return mat + + +def rotate3d_mat(axis, theta): + batch = theta.shape[0] + + u_x, u_y, u_z = axis + + eye = torch.eye(3).unsqueeze(0) + cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0) + outer = torch.tensor(axis) + outer = (outer.unsqueeze(1) * outer).unsqueeze(0) + + sin_t = torch.sin(theta).view(-1, 1, 1) + cos_t = torch.cos(theta).view(-1, 1, 1) + + rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer + + eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) + eye_4[:, :3, :3] = rot + + return eye_4 + + +def 
scale3d_mat(s_x, s_y, s_z): + batch = s_x.shape[0] + + mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) + mat[:, 0, 0] = s_x + mat[:, 1, 1] = s_y + mat[:, 2, 2] = s_z + + return mat + + +def luma_flip_mat(axis, i): + batch = i.shape[0] + + eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) + axis = torch.tensor(axis + (0,)) + flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1) + + return eye - flip + + +def saturation_mat(axis, i): + batch = i.shape[0] + + eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) + axis = torch.tensor(axis + (0,)) + axis = torch.ger(axis, axis) + saturate = axis + (eye - axis) * i.view(-1, 1, 1) + + return saturate + + +def lognormal_sample(size, mean=0, std=1): + return torch.empty(size).log_normal_(mean=mean, std=std) + + +def category_sample(size, categories): + category = torch.tensor(categories) + sample = torch.randint(high=len(categories), size=(size,)) + + return category[sample] + + +def uniform_sample(size, low, high): + return torch.empty(size).uniform_(low, high) + + +def normal_sample(size, mean=0, std=1): + return torch.empty(size).normal_(mean, std) + + +def bernoulli_sample(size, p): + return torch.empty(size).bernoulli_(p) + + +def random_mat_apply(p, transform, prev, eye): + size = transform.shape[0] + select = bernoulli_sample(size, p).view(size, 1, 1) + select_transform = select * transform + (1 - select) * eye + + return select_transform @ prev + + +def sample_affine(p, size, height, width): + G = torch.eye(3).unsqueeze(0).repeat(size, 1, 1) + eye = G + + # flip + param = category_sample(size, (0, 1)) + Gc = scale_mat(1 - 2.0 * param, torch.ones(size)) + G = random_mat_apply(p, Gc, G, eye) + # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n') + + # 90 rotate + param = category_sample(size, (0, 3)) + Gc = rotate_mat(-math.pi / 2 * param) + G = random_mat_apply(p, Gc, G, eye) + # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n') + + # integer translate + param = uniform_sample(size, -0.125, 0.125) + param_height = torch.round(param * height) / height + param_width = torch.round(param * width) / width + Gc = translate_mat(param_width, param_height) + G = random_mat_apply(p, Gc, G, eye) + # print('integer translate', G, translate_mat(param_width, param_height), sep='\n') + + # isotropic scale + param = lognormal_sample(size, std=0.2 * math.log(2)) + Gc = scale_mat(param, param) + G = random_mat_apply(p, Gc, G, eye) + # print('isotropic scale', G, scale_mat(param, param), sep='\n') + + p_rot = 1 - math.sqrt(1 - p) + + # pre-rotate + param = uniform_sample(size, -math.pi, math.pi) + Gc = rotate_mat(-param) + G = random_mat_apply(p_rot, Gc, G, eye) + # print('pre-rotate', G, rotate_mat(-param), sep='\n') + + # anisotropic scale + param = lognormal_sample(size, std=0.2 * math.log(2)) + Gc = scale_mat(param, 1 / param) + G = random_mat_apply(p, Gc, G, eye) + # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n') + + # post-rotate + param = uniform_sample(size, -math.pi, math.pi) + Gc = rotate_mat(-param) + G = random_mat_apply(p_rot, Gc, G, eye) + # print('post-rotate', G, rotate_mat(-param), sep='\n') + + # fractional translate + param = normal_sample(size, std=0.125) + Gc = translate_mat(param, param) + G = random_mat_apply(p, Gc, G, eye) + # print('fractional translate', G, translate_mat(param, param), sep='\n') + + return G + + +def sample_color(p, size): + C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1) + eye = C + axis_val = 1 / math.sqrt(3) + axis = (axis_val, axis_val, axis_val) + + 
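    # Each colour operation below is expressed as a 4x4 homogeneous matrix acting
    # on RGB values (brightness = translation, contrast = scaling, luma flip /
    # hue rotation / saturation = transforms about the grey axis defined above),
    # and is composed into C only with probability p via random_mat_apply.
    # A minimal usage sketch, assuming a single random 8x8 RGB image:
    #
    #   >>> img = torch.rand(1, 3, 8, 8)
    #   >>> C = sample_color(0.5, 1)          # one sampled 4x4 colour matrix
    #   >>> apply_color(img, C).shape
    #   torch.Size([1, 3, 8, 8])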
# brightness + param = normal_sample(size, std=0.2) + Cc = translate3d_mat(param, param, param) + C = random_mat_apply(p, Cc, C, eye) + + # contrast + param = lognormal_sample(size, std=0.5 * math.log(2)) + Cc = scale3d_mat(param, param, param) + C = random_mat_apply(p, Cc, C, eye) + + # luma flip + param = category_sample(size, (0, 1)) + Cc = luma_flip_mat(axis, param) + C = random_mat_apply(p, Cc, C, eye) + + # hue rotation + param = uniform_sample(size, -math.pi, math.pi) + Cc = rotate3d_mat(axis, param) + C = random_mat_apply(p, Cc, C, eye) + + # saturation + param = lognormal_sample(size, std=1 * math.log(2)) + Cc = saturation_mat(axis, param) + C = random_mat_apply(p, Cc, C, eye) + + return C + + +def make_grid(shape, x0, x1, y0, y1, device): + n, c, h, w = shape + grid = torch.empty(n, h, w, 3, device=device) + grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device) + grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1) + grid[:, :, :, 2] = 1 + + return grid + + +def affine_grid(grid, mat): + n, h, w, _ = grid.shape + return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2) + + +def get_padding(G, height, width, pad_k): + extreme = ( + G[:, :2, :] + @ torch.tensor([(-1.0, -1, 1), (-1, 1, 1), (1, -1, 1), (1, 1, 1)]).t() + ) + + size = torch.tensor((width, height)) + + pad_low = ( + ((extreme.min(-1).values + 1) * size) + .clamp(max=0) + .abs() + .ceil() + .max(0) + .values.to(torch.int64) + .tolist() + ) + pad_high = ( + (extreme.max(-1).values * size - size) + .clamp(min=0) + .ceil() + .max(0) + .values.to(torch.int64) + .tolist() + ) + + h_pad_lth = np.clip([pad_low[0], pad_high[0]], a_max=height - pad_k - 1, a_min=-100000) + w_pad_lth = np.clip([pad_low[1], pad_high[1]], a_max=width - pad_k - 1, a_min=-100000) + + return int(h_pad_lth[0]), int(h_pad_lth[1]), int(w_pad_lth[0]), int(w_pad_lth[1]) + + +def try_sample_affine_and_pad(img, p, pad_k, G=None): + batch, _, height, width = img.shape + + G_try = G + + if G is None: + G_try = sample_affine(p, batch, height, width) + + pad_x1, pad_x2, pad_y1, pad_y2 = get_padding( + torch.inverse(G_try), height, width, pad_k, + ) + + img_pad = F.pad( + img, + (pad_x1 + pad_k, pad_x2 + pad_k, pad_y1 + pad_k, pad_y2 + pad_k), + mode="reflect", + ) + + return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2) + + +def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6): + kernel = antialiasing_kernel + len_k = len(kernel) + pad_k = (len_k + 1) // 2 + + kernel = torch.as_tensor(kernel) + kernel = torch.ger(kernel, kernel).to(img) + kernel_flip = torch.flip(kernel, (0, 1)) + + img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad( + img, p, pad_k, G + ) + + p_ux1 = pad_x1 + p_ux2 = pad_x2 + 1 + p_uy1 = pad_y1 + p_uy2 = pad_y2 + 1 + w_p = img_pad.shape[3] - len_k + 1 + h_p = img_pad.shape[2] - len_k + 1 + h_o = img.shape[2] + w_o = img.shape[3] + img_2x = upfirdn2d(img_pad, kernel_flip, up=2) + + grid = make_grid( + img_2x.shape, + -2 * p_ux1 / w_o - 1, + 2 * (w_p - p_ux1) / w_o - 1, + -2 * p_uy1 / h_o - 1, + 2 * (h_p - p_uy1) / h_o - 1, + device=img_2x.device, + ).to(img_2x) + grid = affine_grid(grid, torch.inverse(G)[:, :2, :].to(img_2x)) + grid = grid * torch.tensor( + [w_o / w_p, h_o / h_p], device=grid.device + ) + torch.tensor( + [(w_o + 2 * p_ux1) / w_p - 1, (h_o + 2 * p_uy1) / h_p - 1], device=grid.device + ) + + img_affine = F.grid_sample( + img_2x, grid, mode="bilinear", align_corners=False, padding_mode="zeros" + ) + + img_down = upfirdn2d(img_affine, kernel, down=2) + + 
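    # img_down is back at the input scale (the 2x upsample has been undone by
    # the down=2 filtering) but still carries the reflection padding added in
    # try_sample_affine_and_pad; the slicing below crops that padding off so the
    # output matches the input height and width. A minimal usage sketch, assuming
    # a CUDA device (for the compiled upfirdn2d op) and two 64x64 images:
    #
    #   >>> x = torch.randn(2, 3, 64, 64, device='cuda')
    #   >>> y, G = random_apply_affine(x, 0.5)
    #   >>> y.shape, G.shape
    #   (torch.Size([2, 3, 64, 64]), torch.Size([2, 3, 3]))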
end_y = -pad_y2 - 1 + if end_y == 0: + end_y = img_down.shape[2] + + end_x = -pad_x2 - 1 + if end_x == 0: + end_x = img_down.shape[3] + + img = img_down[:, :, pad_y1:end_y, pad_x1:end_x] + + return img, G + + +def apply_color(img, mat): + batch = img.shape[0] + img = img.permute(0, 2, 3, 1) + mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3) + mat_add = mat[:, :3, 3].view(batch, 1, 1, 3) + img = img @ mat_mul + mat_add + img = img.permute(0, 3, 1, 2) + + return img + + +def random_apply_color(img, p, C=None): + if C is None: + C = sample_color(p, img.shape[0]) + + img = apply_color(img, C.to(img)) + + return img, C + + +def augment(img, p, transform_matrix=(None, None)): + img, G = random_apply_affine(img, p, transform_matrix[0]) + img, C = random_apply_color(img, p, transform_matrix[1]) + + return img, (G, C) diff --git a/pytorch_pretrained_gans/StudioGAN/utils/ada_op/__init__.py b/pytorch_pretrained_gans/StudioGAN/utils/ada_op/__init__.py new file mode 100755 index 0000000..d0918d9 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/ada_op/__init__.py @@ -0,0 +1,2 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu +from .upfirdn2d import upfirdn2d diff --git a/pytorch_pretrained_gans/StudioGAN/utils/ada_op/fused_act.py b/pytorch_pretrained_gans/StudioGAN/utils/ada_op/fused_act.py new file mode 100755 index 0000000..1777eb2 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/ada_op/fused_act.py @@ -0,0 +1,122 @@ +""" +MIT License + +Copyright (c) 2019 Kim Seonghyeon + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+""" + + +import os + +import torch +from torch import nn +from torch.nn import functional as F +from torch.autograd import Function +from torch.utils.cpp_extension import load + + +module_path = os.path.dirname(__file__) +fused = load( + "fused", + sources=[ + os.path.join(module_path, "fused_bias_act.cpp"), + os.path.join(module_path, "fused_bias_act_kernel.cu"), + ], +) + + +class FusedLeakyReLUFunctionBackward(Function): + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused.fused_bias_act( + grad_output, empty, out, 3, 1, negative_slope, scale + ) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused.fused_bias_act( + gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale + ) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( + grad_output, out, ctx.negative_slope, ctx.scale + ) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + if input.device.type == "cpu": + rest_dim = [1] * (input.ndim - bias.ndim - 1) + return ( + F.leaky_relu( + input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 + ) + * scale + ) + + else: + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/pytorch_pretrained_gans/StudioGAN/utils/ada_op/fused_bias_act.cpp b/pytorch_pretrained_gans/StudioGAN/utils/ada_op/fused_bias_act.cpp new file mode 100755 index 0000000..8b8759e --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/ada_op/fused_bias_act.cpp @@ -0,0 +1,46 @@ +/* +MIT License + +Copyright (c) 2019 Kim Seonghyeon + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + + +#include <torch/extension.h> + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} diff --git a/pytorch_pretrained_gans/StudioGAN/utils/ada_op/fused_bias_act_kernel.cu b/pytorch_pretrained_gans/StudioGAN/utils/ada_op/fused_bias_act_kernel.cu new file mode 100755 index 0000000..1c095f4 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/ada_op/fused_bias_act_kernel.cu @@ -0,0 +1,99 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include <torch/types.h> + +#include <ATen/ATen.h> +#include <ATen/AccumulateType.h> +#include <ATen/cuda/CUDAContext.h> +#include <ATen/cuda/CUDAApplyUtils.cuh> + +#include <cuda.h> +#include <cuda_runtime.h> + + +template <typename scalar_t> +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( + y.data_ptr<scalar_t>(), + x.data_ptr<scalar_t>(), + b.data_ptr<scalar_t>(), + ref.data_ptr<scalar_t>(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} diff --git a/pytorch_pretrained_gans/StudioGAN/utils/ada_op/upfirdn2d.cpp b/pytorch_pretrained_gans/StudioGAN/utils/ada_op/upfirdn2d.cpp new file mode 100755 index 0000000..8a319c3 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/ada_op/upfirdn2d.cpp @@ -0,0 +1,48 @@ +/* +MIT License + +Copyright (c) 2019 Kim Seonghyeon + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.
+*/ + + +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} diff --git a/pytorch_pretrained_gans/StudioGAN/utils/ada_op/upfirdn2d.py b/pytorch_pretrained_gans/StudioGAN/utils/ada_op/upfirdn2d.py new file mode 100755 index 0000000..c17c3f6 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/ada_op/upfirdn2d.py @@ -0,0 +1,225 @@ +""" +MIT License + +Copyright (c) 2019 Kim Seonghyeon + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+""" + + +import os + +import torch +from torch.nn import functional as F +from torch.autograd import Function +from torch.utils.cpp_extension import load + + +module_path = os.path.dirname(__file__) +upfirdn2d_op = load( + "upfirdn2d", + sources=[ + os.path.join(module_path, "upfirdn2d.cpp"), + os.path.join(module_path, "upfirdn2d_kernel.cu"), + ], +) + + +class UpFirDn2dBackward(Function): + @staticmethod + def forward( + ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size + ): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_op.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_op.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view( + ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] + ) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_op.upfirdn2d( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 + ) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if input.device.type == "cpu": + out = upfirdn2d_native( + input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] + ) + + else: + out = UpFirDn2d.apply( + input, kernel, (up, 
up), (down, down), (pad[0], pad[1], pad[0], pad[1]) + ) + + return out + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/pytorch_pretrained_gans/StudioGAN/utils/ada_op/upfirdn2d_kernel.cu b/pytorch_pretrained_gans/StudioGAN/utils/ada_op/upfirdn2d_kernel.cu new file mode 100755 index 0000000..a22510f --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/ada_op/upfirdn2d_kernel.cu @@ -0,0 +1,369 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + +template +__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; + int out_y = minor_idx / p.minor_dim; + minor_idx -= out_y * p.minor_dim; + int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; + int major_idx_base = blockIdx.z * p.loop_major; + + if (out_x_base >= p.out_w || out_y >= p.out_h || + major_idx_base >= p.major_dim) { + return; + } + + int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; + int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); + int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; + int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major && major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, out_x = out_x_base; + loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { + int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; + int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); + int w = min(max(floor_div(mid_x + 
p.kernel_w, p.up_x), 0), p.in_w) - in_x; + int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; + + const scalar_t *x_p = + &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + + minor_idx]; + const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; + int x_px = p.minor_dim; + int k_px = -p.up_x; + int x_py = p.in_w * p.minor_dim; + int k_py = -p.up_y * p.kernel_w; + + scalar_t v = 0.0f; + + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += static_cast(*x_p) * static_cast(*k_p); + x_p += x_px; + k_p += k_px; + } + + x_p += x_py - w * x_px; + k_p += k_py - w * k_px; + } + + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } +} + +template +__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | + major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; + tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major & major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; + loop_x < p.loop_x & tile_out_x < p.out_w; + loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; + in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * + p.minor_dim + + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; + out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + +#pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for (int x = 0; x 
< kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } + } + } +} + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / + p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / + p.down_x; + + auto out = + at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h = -1; + int tile_out_w = -1; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w > 0) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } else { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 4; + block_size = dim3(4, 32, 1); + grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, + (p.out_w - 1) / (p.loop_x * block_size.y) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 2: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 3: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 4: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 5: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 6: + upfirdn2d_kernel + 
<<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + default: + upfirdn2d_kernel_large<<>>( + out.data_ptr(), x.data_ptr(), + k.data_ptr(), p); + } + }); + + return out; +} diff --git a/pytorch_pretrained_gans/StudioGAN/utils/biggan_utils.py b/pytorch_pretrained_gans/StudioGAN/utils/biggan_utils.py new file mode 100644 index 0000000..7301ea1 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/biggan_utils.py @@ -0,0 +1,105 @@ +""" +this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch + +MIT License + +Copyright (c) 2019 Andy Brock + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + + +import random + +from .sample import sample_latents + +import torch + + +class ema(object): + def __init__(self, source, target, decay=0.9999, start_itr=0): + self.source = source + self.target = target + self.decay = decay + # Optional parameter indicating what iteration to start the decay at + self.start_itr = start_itr + # Initialize target's params to be source's + self.source_dict = self.source.state_dict() + self.target_dict = self.target.state_dict() + print('Initializing EMA parameters to be source parameters...') + with torch.no_grad(): + for key in self.source_dict: + self.target_dict[key].data.copy_(self.source_dict[key].data) + self.target_dict[key].requires_grad = False + + def update(self, itr=None): + # If an iteration counter is provided and itr is less than the start itr, + # peg the ema weights to the underlying weights. + if itr >= 0 and itr < self.start_itr: + decay = 0.0 + else: + decay = self.decay + with torch.no_grad(): + for key in self.source_dict: + self.target_dict[key].data.copy_(self.target_dict[key].data * decay + + self.source_dict[key].data * (1 - decay)) + + +class ema_DP_SyncBN(object): + def __init__(self, source, target, decay=0.9999, start_itr=0): + self.source = source + self.target = target + self.decay = decay + self.start_itr = start_itr + # Initialize target's params to be source's + print('Initializing EMA parameters to be source parameters...') + for key in self.source.state_dict(): + self.target.state_dict()[key].data.copy_(self.source.state_dict()[key].data) + self.target.state_dict()[key].requires_grad = False + + def update(self, itr=None): + # If an iteration counter is provided and itr is less than the start itr, + # peg the ema weights to the underlying weights. 
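+ # Note: as in ema.update above, the comparison below assumes itr is an integer; calling update() with the default itr=None would raise a TypeError on "itr >= 0".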
+ if itr >= 0 and itr < self.start_itr: + decay = 0.0 + else: + decay = self.decay + + for key in self.source.state_dict(): + data = self.target.state_dict()[key].data * decay + \ + self.source.state_dict()[key].detach().data * (1. - decay) + self.target.state_dict()[key].data.copy_(data) + + +def ortho(model, strength=1e-4, blacklist=[]): + with torch.no_grad(): + for param in model.parameters(): + # Only apply this to parameters with at least 2 axes, and not in the blacklist + if len(param.shape) < 2 or any([param is item for item in blacklist]): + continue + w = param.view(param.shape[0], -1) + grad = (2 * torch.mm(torch.mm(w, w.t()) + * (1. - torch.eye(w.shape[0], device=w.device)), w)) + param.grad.data += strength * grad.view(param.shape) + + +# Interp function; expects x0 and x1 to be of shape (shape0, 1, rest_of_shape..) +def interp(x0, x1, num_midpoints): + lerp = torch.linspace(0, 1.0, num_midpoints + 2, device='cuda').to(x0.dtype) + return ((x0 * (1 - lerp.view(1, -1, 1))) + (x1 * lerp.view(1, -1, 1))) diff --git a/pytorch_pretrained_gans/StudioGAN/utils/cr_diff_aug.py b/pytorch_pretrained_gans/StudioGAN/utils/cr_diff_aug.py new file mode 100644 index 0000000..39e174e --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/cr_diff_aug.py @@ -0,0 +1,50 @@ +# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN +# The MIT License (MIT) +# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details + +# src/utils/cr_diff_aug.py + + +import random + +import torch +import torch.nn.functional as F + + + +def CR_DiffAug(x, flip=True, translation=True): + if flip: + x = random_flip(x, 0.5) + if translation: + x = random_translation(x, 1/8) + if flip or translation: + x = x.contiguous() + return x + + +def random_flip(x, p): + x_out = x.clone() + n, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] + flip_prob = torch.FloatTensor(n, 1).uniform_(0.0, 1.0) + flip_mask = flip_prob < p + flip_mask = flip_mask.type(torch.bool).view(n, 1, 1, 1).repeat(1, c, h, w).to(x.device) + x_out[flip_mask] = torch.flip(x[flip_mask].view(-1, c, h, w), [3]).view(-1) + return x_out + + +def random_translation(x, ratio): + max_t_x, max_t_y = int(x.shape[2]*ratio), int(x.shape[3]*ratio) + t_x = torch.randint(-max_t_x, max_t_x + 1, size = [x.shape[0], 1, 1], device=x.device) + t_y = torch.randint(-max_t_y, max_t_y + 1, size = [x.shape[0], 1, 1], device=x.device) + + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.shape[0], dtype=torch.long, device=x.device), + torch.arange(x.shape[2], dtype=torch.long, device=x.device), + torch.arange(x.shape[3], dtype=torch.long, device=x.device), + ) + + grid_x = (grid_x + t_x) + max_t_x + grid_y = (grid_y + t_y) + max_t_y + x_pad = F.pad(input=x, pad=[max_t_y, max_t_y, max_t_x, max_t_x], mode='reflect') + x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) + return x diff --git a/pytorch_pretrained_gans/StudioGAN/utils/diff_aug.py b/pytorch_pretrained_gans/StudioGAN/utils/diff_aug.py new file mode 100644 index 0000000..e700abb --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/diff_aug.py @@ -0,0 +1,105 @@ +""" +Copyright (c) 2020, Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + +import torch +import torch.nn.functional as F + + + +### Differentiable Augmentation for Data-Efficient GAN Training (https://arxiv.org/abs/2006.10738) +### Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han +### https://github.com/mit-han-lab/data-efficient-gans + + +def DiffAugment(x, policy='', channels_first=True): + if policy: + if not channels_first: + x = x.permute(0, 3, 1, 2) + for p in policy.split(','): + for f in AUGMENT_FNS[p]: + x = f(x) + if not channels_first: + x = x.permute(0, 2, 3, 1) + x = x.contiguous() + return x + + +def rand_brightness(x): + x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) + return x + + +def rand_saturation(x): + x_mean = x.mean(dim=1, keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean + return x + + +def rand_contrast(x): + x_mean = x.mean(dim=[1, 2, 3], keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean + return x + + +def rand_translation(x, ratio=0.125): + shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) + translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(x.size(2), dtype=torch.long, device=x.device), + torch.arange(x.size(3), dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) + grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) + x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) + x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) + return x + + +def rand_cutout(x, ratio=0.5): + cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) + offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), 
dtype=torch.long, device=x.device), + torch.arange(cutout_size[0], dtype=torch.long, device=x.device), + torch.arange(cutout_size[1], dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) + grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) + mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) + mask[grid_batch, grid_x, grid_y] = 0 + x = x * mask.unsqueeze(1) + return x + + +AUGMENT_FNS = { + 'color': [rand_brightness, rand_saturation, rand_contrast], + 'translation': [rand_translation], + 'cutout': [rand_cutout], +} diff --git a/pytorch_pretrained_gans/StudioGAN/utils/load_checkpoint.py b/pytorch_pretrained_gans/StudioGAN/utils/load_checkpoint.py new file mode 100644 index 0000000..0c7fd88 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/load_checkpoint.py @@ -0,0 +1,38 @@ +# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN +# The MIT License (MIT) +# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details + +# src/utils/load_checkpoint.py + + +import os + +import torch + + + +def load_checkpoint(model, optimizer, filename, metric=False, ema=False): + start_step = 0 + if ema: + checkpoint = torch.load(filename) + model.load_state_dict(checkpoint['state_dict']) + return model + else: + checkpoint = torch.load(filename) + seed = checkpoint['seed'] + run_name = checkpoint['run_name'] + start_step = checkpoint['step'] + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + ada_p = checkpoint['ada_p'] + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.cuda() + + if metric: + best_step = checkpoint['best_step'] + best_fid = checkpoint['best_fid'] + best_fid_checkpoint_path = checkpoint['best_fid_checkpoint_path'] + return model, optimizer, seed, run_name, start_step, ada_p, best_step, best_fid, best_fid_checkpoint_path + return model, optimizer, seed, run_name, start_step, ada_p diff --git a/pytorch_pretrained_gans/StudioGAN/utils/log.py b/pytorch_pretrained_gans/StudioGAN/utils/log.py new file mode 100644 index 0000000..4e4b8e5 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/log.py @@ -0,0 +1,54 @@ +# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN +# The MIT License (MIT) +# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details + +# src/utils/log.py + + +import json +import os +import logging +from os.path import dirname, abspath, exists, join +from datetime import datetime + + + +def make_run_name(format, framework, phase): + return format.format( + framework=framework, + phase=phase, + timestamp=datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + ) + + +def make_logger(run_name, log_output): + if log_output is not None: + run_name = log_output.split('/')[-1].split('.')[0] + logger = logging.getLogger(run_name) + logger.propagate = False + log_filepath = log_output if log_output is not None else join('logs', f'{run_name}.log') + + log_dir = dirname(abspath(log_filepath)) + if not exists(log_dir): + os.makedirs(log_dir) + + if not logger.handlers: # execute only if logger doesn't already exist + file_handler = logging.FileHandler(log_filepath, 'a', 'utf-8') + stream_handler = logging.StreamHandler(os.sys.stdout) + + formatter = logging.Formatter('[%(levelname)s] %(asctime)s > %(message)s', datefmt='%Y-%m-%d 
%H:%M:%S') + + file_handler.setFormatter(formatter) + stream_handler.setFormatter(formatter) + + logger.addHandler(file_handler) + logger.addHandler(stream_handler) + logger.setLevel(logging.INFO) + return logger + + +def make_checkpoint_dir(checkpoint_dir, run_name): + checkpoint_dir = checkpoint_dir if checkpoint_dir is not None else join('checkpoints', run_name) + if not exists(abspath(checkpoint_dir)): + os.makedirs(checkpoint_dir) + return checkpoint_dir diff --git a/pytorch_pretrained_gans/StudioGAN/utils/losses.py b/pytorch_pretrained_gans/StudioGAN/utils/losses.py new file mode 100644 index 0000000..af0a7b0 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/losses.py @@ -0,0 +1,316 @@ +# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN +# The MIT License (MIT) +# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details + +# src/utils/losses.py + + +import numpy as np + +from .model_ops import snlinear, linear + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import DataParallel +from torch import autograd + + +# DCGAN loss +def loss_dcgan_dis(dis_out_real, dis_out_fake): + device = dis_out_real.get_device() + ones = torch.ones_like(dis_out_real, device=device, requires_grad=False) + dis_loss = -torch.mean(nn.LogSigmoid()(dis_out_real) + nn.LogSigmoid()(ones - dis_out_fake)) + return dis_loss + + +def loss_dcgan_gen(gen_out_fake): + return -torch.mean(nn.LogSigmoid()(gen_out_fake)) + + +def loss_lsgan_dis(dis_out_real, dis_out_fake): + dis_loss = 0.5 * (dis_out_real - torch.ones_like(dis_out_real))**2 + 0.5 * (dis_out_fake)**2 + return dis_loss.mean() + + +def loss_lsgan_gen(dis_out_fake): + gen_loss = 0.5 * (dis_out_fake - torch.ones_like(dis_out_fake))**2 + return gen_loss.mean() + + +def loss_hinge_dis(dis_out_real, dis_out_fake): + return torch.mean(F.relu(1. - dis_out_real)) + torch.mean(F.relu(1. 
+ dis_out_fake)) + + +def loss_hinge_gen(gen_out_fake): + return -torch.mean(gen_out_fake) + + +def loss_wgan_dis(dis_out_real, dis_out_fake): + return torch.mean(dis_out_fake - dis_out_real) + + +def loss_wgan_gen(gen_out_fake): + return -torch.mean(gen_out_fake) + + +def latent_optimise(zs, fake_labels, gen_model, dis_model, conditional_strategy, latent_op_step, latent_op_rate, + latent_op_alpha, latent_op_beta, trans_cost, default_device): + batch_size = zs.shape[0] + for step in range(latent_op_step): + drop_mask = (torch.FloatTensor(batch_size, 1).uniform_() > 1 - latent_op_rate).to(default_device) + z_gradients, z_gradients_norm = calc_derv( + zs, fake_labels, dis_model, conditional_strategy, default_device, gen_model) + delta_z = latent_op_alpha * z_gradients / (latent_op_beta + z_gradients_norm) + zs = torch.clamp(zs + drop_mask * delta_z, -1.0, 1.0) + + if trans_cost: + if step == 0: + transport_cost = (delta_z.norm(2, dim=1)**2).mean() + else: + transport_cost += (delta_z.norm(2, dim=1)**2).mean() + return zs, trans_cost + else: + return zs + + +def set_temperature(conditional_strategy, tempering_type, start_temperature, end_temperature, step_count, tempering_step, total_step): + if conditional_strategy == 'ContraGAN': + if tempering_type == 'continuous': + t = start_temperature + step_count * (end_temperature - start_temperature) / total_step + elif tempering_type == 'discrete': + tempering_interval = total_step // (tempering_step + 1) + t = start_temperature + \ + (step_count // tempering_interval) * (end_temperature - start_temperature) / tempering_step + else: + t = start_temperature + else: + t = 'no' + return t + + +class Cross_Entropy_loss(torch.nn.Module): + def __init__(self, in_features, out_features, spectral_norm=True): + super(Cross_Entropy_loss, self).__init__() + + if spectral_norm: + self.layer = snlinear(in_features=in_features, out_features=out_features, bias=True) + else: + self.layer = linear(in_features=in_features, out_features=out_features, bias=True) + self.ce_loss = torch.nn.CrossEntropyLoss() + + def forward(self, embeds, labels): + logits = self.layer(embeds) + return self.ce_loss(logits, labels) + + +class Conditional_Contrastive_loss(torch.nn.Module): + def __init__(self, device, batch_size, pos_collected_numerator): + super(Conditional_Contrastive_loss, self).__init__() + self.device = device + self.batch_size = batch_size + self.pos_collected_numerator = pos_collected_numerator + self.calculate_similarity_matrix = self._calculate_similarity_matrix() + self.cosine_similarity = torch.nn.CosineSimilarity(dim=-1) + + def _calculate_similarity_matrix(self): + return self._cosine_simililarity_matrix + + def remove_diag(self, M): + h, w = M.shape + assert h == w, "h and w should be same" + mask = np.ones((h, w)) - np.eye(h) + mask = torch.from_numpy(mask) + mask = (mask).type(torch.bool).to(self.device) + return M[mask].view(h, -1) + + def _cosine_simililarity_matrix(self, x, y): + v = self.cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) + return v + + def forward(self, inst_embed, proxy, negative_mask, labels, temperature, margin): + similarity_matrix = self.calculate_similarity_matrix(inst_embed, inst_embed) + instance_zone = torch.exp((self.remove_diag(similarity_matrix) - margin) / temperature) + + inst2proxy_positive = torch.exp((self.cosine_similarity(inst_embed, proxy) - margin) / temperature) + if self.pos_collected_numerator: + mask_4_remove_negatives = negative_mask[labels] + mask_4_remove_negatives = self.remove_diag(mask_4_remove_negatives) + 
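+ # When pos_collected_numerator is set, the same-class instance similarities selected by this mask are added to the instance-to-proxy similarity in the numerator; otherwise only the proxy term is used.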
inst2inst_positives = instance_zone * mask_4_remove_negatives + + numerator = inst2proxy_positive + inst2inst_positives.sum(dim=1) + else: + numerator = inst2proxy_positive + + denomerator = torch.cat([torch.unsqueeze(inst2proxy_positive, dim=1), instance_zone], dim=1).sum(dim=1) + criterion = -torch.log(temperature * (numerator / denomerator)).mean() + return criterion + + +class Proxy_NCA_loss(torch.nn.Module): + def __init__(self, device, embedding_layer, num_classes, batch_size): + super(Proxy_NCA_loss, self).__init__() + self.device = device + self.embedding_layer = embedding_layer + self.num_classes = num_classes + self.batch_size = batch_size + self.cosine_similarity = torch.nn.CosineSimilarity(dim=-1) + + def _get_positive_proxy_mask(self, labels): + labels = labels.detach().cpu().numpy() + rvs_one_hot_target = np.ones([self.num_classes, self.num_classes]) - np.eye(self.num_classes) + rvs_one_hot_target = rvs_one_hot_target[labels] + mask = torch.from_numpy((rvs_one_hot_target)).type(torch.bool) + return mask.to(self.device) + + def forward(self, inst_embed, proxy, labels): + all_labels = torch.tensor([c for c in range(self.num_classes)]).type(torch.long).to(self.device) + positive_proxy_mask = self._get_positive_proxy_mask(labels) + negative_proxies = torch.exp(torch.mm(inst_embed, self.embedding_layer(all_labels).T)) * positive_proxy_mask + + inst2proxy_positive = torch.exp(self.cosine_similarity(inst_embed, proxy)) + numerator = inst2proxy_positive + denomerator = negative_proxies.sum(dim=1) + criterion = -torch.log(numerator / denomerator).mean() + return criterion + + +class NT_Xent_loss(torch.nn.Module): + def __init__(self, device, batch_size, use_cosine_similarity=True): + super(NT_Xent_loss, self).__init__() + self.device = device + self.batch_size = batch_size + self.softmax = torch.nn.Softmax(dim=-1) + self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool) + self.similarity_function = self._get_similarity_function(use_cosine_similarity) + self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") + + def _get_similarity_function(self, use_cosine_similarity): + if use_cosine_similarity: + self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1) + return self._cosine_simililarity + else: + return self._dot_simililarity + + def _get_correlated_mask(self): + diag = np.eye(2 * self.batch_size) + l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size) + l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size) + mask = torch.from_numpy((diag + l1 + l2)) + mask = (1 - mask).type(torch.bool) + return mask.to(self.device) + + @staticmethod + def _dot_simililarity(x, y): + v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2) + # x shape: (N, 1, C) + # y shape: (1, C, 2N) + # v shape: (N, 2N) + return v + + def _cosine_simililarity(self, x, y): + # x shape: (N, 1, C) + # y shape: (1, 2N, C) + # v shape: (N, 2N) + v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) + return v + + def forward(self, zis, zjs, temperature): + representations = torch.cat([zjs, zis], dim=0) + + similarity_matrix = self.similarity_function(representations, representations) + + # filter out the scores from the positive samples + l_pos = torch.diag(similarity_matrix, self.batch_size) + r_pos = torch.diag(similarity_matrix, -self.batch_size) + positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1) + + negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1) + + logits = 
torch.cat((positives, negatives), dim=1) + logits /= temperature + + labels = torch.zeros(2 * self.batch_size).to(self.device).long() + loss = self.criterion(logits, labels) + return loss / (2 * self.batch_size) + + +def calc_derv4gp(netD, conditional_strategy, real_data, fake_data, real_labels, device): + batch_size, c, h, w = real_data.shape + alpha = torch.rand(batch_size, 1) + alpha = alpha.expand(batch_size, real_data.nelement() // batch_size).contiguous().view(batch_size, c, h, w) + alpha = alpha.to(device) + + real_data = real_data.to(device) + + interpolates = alpha * real_data + ((1 - alpha) * fake_data) + interpolates = interpolates.to(device) + interpolates = autograd.Variable(interpolates, requires_grad=True) + + if conditional_strategy in ['ContraGAN', "Proxy_NCA_GAN", "NT_Xent_GAN"]: + _, _, disc_interpolates = netD(interpolates, real_labels) + elif conditional_strategy in ['ProjGAN', 'no']: + disc_interpolates = netD(interpolates, real_labels) + elif conditional_strategy == 'ACGAN': + _, disc_interpolates = netD(interpolates, real_labels) + else: + raise NotImplementedError + + gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates, + grad_outputs=torch.ones(disc_interpolates.size()).to(device), + create_graph=True, retain_graph=True, only_inputs=True)[0] + gradients = gradients.view(gradients.size(0), -1) + + gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() + return gradient_penalty + + +def calc_derv4dra(netD, conditional_strategy, real_data, real_labels, device): + batch_size, c, h, w = real_data.shape + alpha = torch.rand(batch_size, 1, 1, 1) + alpha = alpha.to(device) + + real_data = real_data.to(device) + differences = 0.5 * real_data.std() * torch.rand(real_data.size()).to(device) + + interpolates = real_data + (alpha * differences) + interpolates = interpolates.to(device) + interpolates = autograd.Variable(interpolates, requires_grad=True) + + if conditional_strategy in ['ContraGAN', "Proxy_NCA_GAN", "NT_Xent_GAN"]: + _, _, disc_interpolates = netD(interpolates, real_labels) + elif conditional_strategy in ['ProjGAN', 'no']: + disc_interpolates = netD(interpolates, real_labels) + elif conditional_strategy == 'ACGAN': + _, disc_interpolates = netD(interpolates, real_labels) + else: + raise NotImplementedError + + gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates, + grad_outputs=torch.ones(disc_interpolates.size()).to(device), + create_graph=True, retain_graph=True, only_inputs=True)[0] + gradients = gradients.view(gradients.size(0), -1) + + gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() + return gradient_penalty + + +def calc_derv(inputs, labels, netD, conditional_strategy, device, netG=None): + zs = autograd.Variable(inputs, requires_grad=True) + fake_images = netG(zs, labels) + + if conditional_strategy in ['ContraGAN', "Proxy_NCA_GAN", "NT_Xent_GAN"]: + _, _, dis_out_fake = netD(fake_images, labels) + elif conditional_strategy in ['ProjGAN', 'no']: + dis_out_fake = netD(fake_images, labels) + elif conditional_strategy == 'ACGAN': + _, dis_out_fake = netD(fake_images, labels) + else: + raise NotImplementedError + + gradients = autograd.grad(outputs=dis_out_fake, inputs=zs, + grad_outputs=torch.ones(dis_out_fake.size()).to(device), + create_graph=True, retain_graph=True, only_inputs=True)[0] + + gradients_norm = torch.unsqueeze((gradients.norm(2, dim=1) ** 2), dim=1) + return gradients, gradients_norm diff --git a/pytorch_pretrained_gans/StudioGAN/utils/make_hdf5.py 
b/pytorch_pretrained_gans/StudioGAN/utils/make_hdf5.py new file mode 100644 index 0000000..1113996 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/make_hdf5.py @@ -0,0 +1,93 @@ +""" +this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch + +MIT License + +Copyright (c) 2019 Andy Brock +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + + +import os +import sys +import h5py as h5 +import numpy as np +import PIL +from argparse import ArgumentParser +from tqdm import tqdm, trange + +from data_utils.load_dataset import LoadDataset + +import torch +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + + + +def make_hdf5(model_config, train_config, mode): + if 'hdf5' in model_config['dataset_name']: + raise ValueError('Reading from an HDF5 file which you will probably be ' + 'about to overwrite! Override this error only if you know ' + 'what you''re doing!') + + file_name = '{dataset_name}_{size}_{mode}.hdf5'.format(dataset_name=model_config['dataset_name'], size=model_config['img_size'], mode=mode) + file_path = os.path.join(model_config['data_path'], file_name) + train = True if mode == "train" else False + + if os.path.isfile(file_path): + print("{file_name} exist!\nThe file are located in the {file_path}".format(file_name=file_name, file_path=file_path)) + else: + dataset = LoadDataset(model_config['dataset_name'], model_config['data_path'], train=train, download=True, resize_size=model_config['img_size'], + hdf5_path=None, random_flip=False) + + loader = DataLoader(dataset, + batch_size=model_config['batch_size4prcsing'], + shuffle=False, + pin_memory=False, + num_workers=train_config['num_workers'], + drop_last=False) + + print('Starting to load %s into an HDF5 file with chunk size %i and compression %s...' % (model_config['dataset_name'], + model_config['chunk_size'], + model_config['compression'])) + # Loop over loader + for i,(x,y) in enumerate(tqdm(loader)): + # Numpyify x, y + x = (255 * ((x + 1) / 2.0)).byte().numpy() + y = y.numpy() + # If we're on the first batch, prepare the hdf5 + if i==0: + with h5.File(file_path, 'w') as f: + print('Producing dataset of len %d' % len(loader.dataset)) + imgs_dset = f.create_dataset('imgs', x.shape, dtype='uint8', maxshape=(len(loader.dataset), 3, + model_config['img_size'], model_config['img_size']), + chunks=(model_config['chunk_size'], 3, model_config['img_size'], model_config['img_size']), compression=model_config['compression']) + print('Image chunks chosen as ' + str(imgs_dset.chunks)) + imgs_dset[...] 
= x + + labels_dset = f.create_dataset('labels', y.shape, dtype='int64', maxshape=(len(loader.dataset),), + chunks=(model_config['chunk_size'],), compression=model_config['compression']) + print('Label chunks chosen as ' + str(labels_dset.chunks)) + labels_dset[...] = y + # Else append to the hdf5 + else: + with h5.File(file_path, 'a') as f: + f['imgs'].resize(f['imgs'].shape[0] + x.shape[0], axis=0) + f['imgs'][-x.shape[0]:] = x + f['labels'].resize(f['labels'].shape[0] + y.shape[0], axis=0) + f['labels'][-y.shape[0]:] = y + return file_path diff --git a/pytorch_pretrained_gans/StudioGAN/utils/misc.py b/pytorch_pretrained_gans/StudioGAN/utils/misc.py new file mode 100644 index 0000000..551a464 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/misc.py @@ -0,0 +1,601 @@ +# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN +# The MIT License (MIT) +# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details + +# src/utils/misc.py + + +import numpy as np +import random +import math +import os +import sys +import shutil +import warnings +# import seaborn as sns +# import matplotlib.pyplot as plt +from os.path import dirname, abspath, exists, join +from scipy import linalg +from datetime import datetime +from tqdm import tqdm +from itertools import chain +from collections import defaultdict + +# from metrics.FID import generate_images +from .sample import sample_latents +from .losses import latent_optimise + +import torch +import torch.nn.functional as F +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn import DataParallel +from torch.nn.parallel import DistributedDataParallel +from torchvision.utils import save_image + + +class dummy_context_mgr(): + def __enter__(self): + return None + + def __exit__(self, exc_type, exc_value, traceback): + return False + + +class Adaptive_Augment(object): + def __init__(self, prev_ada_p, ada_target, ada_length, batch_size, rank): + self.prev_ada_p = prev_ada_p + self.ada_target = ada_target + self.ada_length = ada_length + self.batch_size = batch_size + self.rank = rank + + self.ada_aug_step = self.ada_target / self.ada_length + + def initialize(self): + self.ada_augment = torch.tensor([0.0, 0.0], device=self.rank) + if self.prev_ada_p is not None: + self.ada_aug_p = self.prev_ada_p + else: + self.ada_aug_p = 0.0 + return self.ada_aug_p + + def update(self, logits): + ada_aug_data = torch.tensor((torch.sign(logits).sum().item(), logits.shape[0]), device=self.rank) + self.ada_augment += ada_aug_data + if self.ada_augment[1] > (self.batch_size * 4 - 1): + authen_out_signs, num_outputs = self.ada_augment.tolist() + r_t_stat = authen_out_signs / num_outputs + sign = 1 if r_t_stat > self.ada_target else -1 + self.ada_aug_p += sign * self.ada_aug_step * num_outputs + self.ada_aug_p = min(1.0, max(0.0, self.ada_aug_p)) + self.ada_augment.mul_(0.0) + return self.ada_aug_p + + +def flatten_dict(init_dict): + res_dict = {} + if type(init_dict) is not dict: + return res_dict + + for k, v in init_dict.items(): + if type(v) == dict: + res_dict.update(flatten_dict(v)) + else: + res_dict[k] = v + return res_dict + + +def setattr_cls_from_kwargs(cls, kwargs): + kwargs = flatten_dict(kwargs) + for key in kwargs.keys(): + value = kwargs[key] + setattr(cls, key, value) + + +def dict2clsattr(train_configs, model_configs): + cfgs = {} + for k, v in chain(train_configs.items(), model_configs.items()): + cfgs[k] = v + + class cfg_container: + pass + cfg_container.train_configs = 
train_configs + cfg_container.model_configs = model_configs + setattr_cls_from_kwargs(cfg_container, cfgs) + return cfg_container + + +# fix python, numpy, torch seed +def fix_all_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.cuda.manual_seed(seed) + + +def setup(rank, world_size, backend="nccl"): + if sys.platform == 'win32': + # Distributed package only covers collective communications with Gloo + # backend and FileStore on Windows platform. Set init_method parameter + # in init_process_group to a local file. + # Example init_method="file:///f:/libtmp/some_file" + init_method = "file:///{your local file path}" + + # initialize the process group + dist.init_process_group( + backend, + init_method=init_method, + rank=rank, + world_size=world_size + ) + else: + # initialize the process group + dist.init_process_group(backend, + init_method="tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']), + rank=rank, + world_size=world_size) + + +def cleanup(): + dist.destroy_process_group() + + +def count_parameters(module): + return 'Number of parameters: {}'.format(sum([p.data.nelement() for p in module.parameters()])) + + +def define_sampler(dataset_name, conditional_strategy): + if conditional_strategy != "no": + if dataset_name == "cifar10": + sampler = "class_order_all" + else: + sampler = "class_order_some" + else: + sampler = "default" + return sampler + + +def check_flag_0(batch_size, n_gpus, freeze_layers, checkpoint_folder, architecture, img_size): + assert batch_size % n_gpus == 0, "Batch_size should be divided by the number of gpus." + + if architecture == "dcgan": + assert img_size == 32, "Sry,\ + StudioGAN does not support dcgan models for generation of images larger than 32 resolution." + + if freeze_layers > -1: + assert checkpoint_folder is not None, "Freezing discriminator needs a pre-trained model." + + +def check_flag_1(tempering_type, pos_collected_numerator, conditional_strategy, diff_aug, ada, mixed_precision, + gradient_penalty_for_dis, deep_regret_analysis_for_dis, cr, bcr, zcr, + distributed_data_parallel, synchronized_bn): + assert int(diff_aug) * int(ada) == 0, \ + "You can't simultaneously apply Differentiable Augmentation (DiffAug) and Adaptive Discriminator Augmentation (ADA)." + + assert int(mixed_precision) * int(gradient_penalty_for_dis) == 0, \ + "You can't simultaneously apply mixed precision training (mpc) and Gradient Penalty for WGAN-GP." + + assert int(mixed_precision) * int(deep_regret_analysis_for_dis) == 0, \ + "You can't simultaneously apply mixed precision training (mpc) and Deep Regret Analysis for DRAGAN." + + assert int(cr) * int(bcr) == 0 and int(cr) * int(zcr) == 0, \ + "You can't simultaneously turn on Consistency Reg. (CR) and Improved Consistency Reg. (ICR)." + + assert int(gradient_penalty_for_dis) * int(deep_regret_analysis_for_dis) == 0, \ + "You can't simultaneously apply Gradient Penalty (GP) and Deep Regret Analysis (DRA)." + + if conditional_strategy == "ContraGAN": + assert tempering_type == "constant" or tempering_type == "continuous" or tempering_type == "discrete", \ + "Tempering_type should be one of constant, continuous, or discrete." + + if pos_collected_numerator: + assert conditional_strategy == "ContraGAN", "Pos_collected_numerator option is not appliable except for ContraGAN." + + if distributed_data_parallel: + msg = 'Evaluation results of the image generation with DDP are not exact. 
' + \ + 'Please use a single GPU training mode or DataParallel for exact evluation.' + warnings.warn(msg) + + +# Convenience utility to switch off requires_grad +def toggle_grad(model, on, freeze_layers=-1): + try: + if isinstance(model, DataParallel) or isinstance(model, DistributedDataParallel): + num_blocks = len(model.module.in_dims) + else: + num_blocks = len(model.in_dims) + + assert freeze_layers < num_blocks,\ + "can't not freeze the {fl}th block > total {nb} blocks.".format(fl=freeze_layers, nb=num_blocks) + + if freeze_layers == -1: + for name, param in model.named_parameters(): + param.requires_grad = on + else: + for name, param in model.named_parameters(): + param.requires_grad = on + for layer in range(freeze_layers): + block = "blocks.{layer}".format(layer=layer) + if block in name: + param.requires_grad = False + except: + for name, param in model.named_parameters(): + param.requires_grad = on + + +def set_bn_train(m): + if isinstance(m, torch.nn.modules.batchnorm._BatchNorm): + m.train() + + +def untrack_bn_statistics(m): + if isinstance(m, torch.nn.modules.batchnorm._BatchNorm): + m.track_running_stats = False + + +def track_bn_statistics(m): + if isinstance(m, torch.nn.modules.batchnorm._BatchNorm): + m.track_running_stats = True + + +def set_deterministic_op_train(m): + if isinstance(m, torch.nn.modules.conv.Conv2d): + m.train() + + if isinstance(m, torch.nn.modules.conv.ConvTranspose2d): + m.train() + + if isinstance(m, torch.nn.modules.linear.Linear): + m.train() + + if isinstance(m, torch.nn.modules.Embedding): + m.train() + + +def reset_bn_stat(m): + if isinstance(m, torch.nn.modules.batchnorm._BatchNorm): + m.reset_running_stats() + + +def elapsed_time(start_time): + now = datetime.now() + elapsed = now - start_time + return str(elapsed).split('.')[0] # remove milliseconds + + +def reshape_weight_to_matrix(weight): + weight_mat = weight + dim = 0 + if dim != 0: + # permute dim to front + weight_mat = weight_mat.permute(dim, *[d for d in range(weight_mat.dim()) if d != dim]) + height = weight_mat.size(0) + return weight_mat.reshape(height, -1) + + +def find_string(list_, string): + for i, s in enumerate(list_): + if string == s: + return i + + +def find_and_remove(path): + if os.path.isfile(path): + os.remove(path) + + +def calculate_all_sn(model): + sigmas = {} + with torch.no_grad(): + for name, param in model.named_parameters(): + if "weight" in name and "bn" not in name and "shared" not in name and "deconv" not in name: + if "blocks" in name: + splited_name = name.split('.') + idx = find_string(splited_name, 'blocks') + block_idx = int(splited_name[int(idx + 1)]) + module_idx = int(splited_name[int(idx + 2)]) + operation_name = splited_name[idx + 3] + if isinstance(model, DataParallel) or isinstance(model, DistributedDataParallel): + operations = model.module.blocks[block_idx][module_idx] + else: + operations = model.blocks[block_idx][module_idx] + operation = getattr(operations, operation_name) + else: + splited_name = name.split('.') + idx = find_string(splited_name, 'module') if isinstance( + model, DataParallel) or isinstance(model, DistributedDataParallel) else -1 + operation_name = splited_name[idx + 1] + if isinstance(model, DataParallel) or isinstance(model, DistributedDataParallel): + operation = getattr(model.module, operation_name) + else: + operation = getattr(model, operation_name) + + weight_orig = reshape_weight_to_matrix(operation.weight_orig) + weight_u = operation.weight_u + weight_v = operation.weight_v + sigmas[name] = torch.dot(weight_u, 
torch.mv(weight_orig, weight_v)) + return sigmas + + +def apply_accumulate_stat(generator, acml_step, prior, batch_size, z_dim, num_classes, device): + generator.train() + generator.apply(reset_bn_stat) + for i in range(acml_step): + new_batch_size = random.randint(1, batch_size) + z, fake_labels = sample_latents(prior, new_batch_size, z_dim, 1, num_classes, None, device) + generated_images = generator(z, fake_labels) + generator.eval() + + +def change_generator_mode(gen, gen_copy, bn_stat_OnTheFly, standing_statistics, standing_step, + prior, batch_size, z_dim, num_classes, device, training, counter): + gen_tmp = gen if gen_copy is None else gen_copy + + if training: + gen.train() + gen_tmp.train() + gen_tmp.apply(track_bn_statistics) + return gen_tmp + + if standing_statistics: + if counter > 1: + gen_tmp.eval() + gen_tmp.apply(set_deterministic_op_train) + else: + gen_tmp.train() + apply_accumulate_stat(gen_tmp, standing_step, prior, batch_size, z_dim, num_classes, device) + gen_tmp.eval() + gen_tmp.apply(set_deterministic_op_train) + else: + gen_tmp.eval() + if bn_stat_OnTheFly: + gen_tmp.apply(set_bn_train) + gen_tmp.apply(untrack_bn_statistics) + gen_tmp.apply(set_deterministic_op_train) + return gen_tmp + + +def plot_img_canvas(images, save_path, logger, nrow, logging=True): + directory = dirname(save_path) + + if not exists(abspath(directory)): + os.makedirs(directory) + + save_image(images, save_path, padding=0, nrow=nrow) + if logging: + logger.info("Saved image to {}".format(save_path)) + + +def plot_pr_curve(precision, recall, run_name, logger, log=False): + directory = join('./figures', run_name) + + if not exists(abspath(directory)): + os.makedirs(directory) + + save_path = join(directory, "pr_curve.png") + + fig, ax = plt.subplots() + ax.plot([0, 1], [0, 1], linestyle='--') + ax.plot(recall, precision) + ax.grid(True) + ax.set_xlabel('Recall (Higher is better)', fontsize=15) + ax.set_ylabel('Precision (Higher is better)', fontsize=15) + fig.tight_layout() + fig.savefig(save_path) + if log: + logger.info("Save image to {}".format(save_path)) + return fig + + +def plot_spectrum_image(real_spectrum, fake_spectrum, run_name, logger, log=False): + directory = join('./figures', run_name) + + if not exists(abspath(directory)): + os.makedirs(directory) + + save_path = join(directory, "dfft_spectrum.png") + + fig = plt.figure() + ax1 = fig.add_subplot(121) + ax2 = fig.add_subplot(122) + + ax1.imshow(real_spectrum) + ax1.set_title("Spectrum of real images") + + ax2.imshow(fake_spectrum) + ax2.set_title("Spectrum of fake images") + fig.savefig(save_path) + if log: + logger.info("Save image to {}".format(save_path)) + + +def plot_tsne_scatter_plot(df, tsne_results, flag, run_name, logger): + directory = join('./figures', run_name, flag) + + if not exists(abspath(directory)): + os.makedirs(directory) + + save_path = join(directory, "tsne_scatter.png") + + df['tsne-2d-one'] = tsne_results[:, 0] + df['tsne-2d-two'] = tsne_results[:, 1] + plt.figure(figsize=(16, 10)) + sns.scatterplot( + x="tsne-2d-one", y="tsne-2d-two", + hue="labels", + palette=sns.color_palette("hls", 10), + data=df, + legend="full", + alpha=0.5 + ).legend(fontsize=15, loc='upper right') + plt.title("TSNE result of {flag} images".format(flag=flag), fontsize=25) + plt.xlabel('', fontsize=7) + plt.ylabel('', fontsize=7) + plt.savefig(save_path) + logger.info("Save image to {}".format(save_path)) + + +def plot_sim_heatmap(similarity, xlabels, ylabels, run_name, logger, log=False): + directory = join('./figures', run_name) + 
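+    # Draws a heatmap of pairwise cosine similarities between class proxies,
+    # masking out the upper triangle so only the lower triangle is annotated.
+    # Note: the seaborn/matplotlib imports at the top of this file are commented
+    # out, so this helper (like the other plot_* helpers) assumes both libraries
+    # are importable at call time; np.bool is also deprecated in recent NumPy.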
+ if not exists(abspath(directory)): + os.makedirs(directory) + + save_path = join(directory, "sim_heatmap.png") + + sns.set(style="white") + fig, ax = plt.subplots(figsize=(18, 18)) + cmap = sns.diverging_palette(220, 20, as_cmap=True) + # Generate a mask for the upper triangle + mask = np.zeros_like(similarity, dtype=np.bool) + mask[np.triu_indices_from(mask, k=1)] = True + + # Draw the heatmap with the mask and correct aspect ratio + sns.heatmap(similarity, mask=mask, cmap=cmap, center=0.5, + xticklabels=xlabels, yticklabels=ylabels, + square=True, linewidths=.5, fmt='.2f', + annot=True, cbar_kws={"shrink": .5}, vmax=1) + + ax.set_title("Heatmap of cosine similarity scores").set_fontsize(15) + ax.set_xlabel("") + ax.set_ylabel("") + + fig.savefig(save_path) + if log: + logger.info("Save image to {}".format(save_path)) + return fig + + +def save_images_npz(run_name, data_loader, num_samples, num_classes, generator, discriminator, is_generate, + truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device): + if is_generate is True: + batch_size = data_loader.batch_size + n_batches = math.ceil(float(num_samples) / float(batch_size)) + else: + batch_size = data_loader.batch_size + total_instance = len(data_loader.dataset) + n_batches = math.ceil(float(num_samples) / float(batch_size)) + data_iter = iter(data_loader) + + data_iter = iter(data_loader) + type = "fake" if is_generate is True else "real" + print("Save {num_samples} {type} images in npz format....".format(num_samples=num_samples, type=type)) + + directory = join('./samples', run_name, type, "npz") + if exists(abspath(directory)): + shutil.rmtree(abspath(directory)) + os.makedirs(directory) + + x = [] + y = [] + with torch.no_grad() if latent_op is False else dummy_context_mgr() as mpc: + for i in tqdm(range(0, n_batches), disable=False): + start = i * batch_size + end = start + batch_size + if is_generate: + images, labels = generate_images(batch_size, generator, discriminator, truncated_factor, prior, latent_op, + latent_op_step, latent_op_alpha, latent_op_beta, device) + else: + try: + images, labels = next(data_iter) + except StopIteration: + break + + x += [np.uint8(255 * (images.detach().cpu().numpy() + 1) / 2.)] + y += [labels.detach().cpu().numpy()] + x = np.concatenate(x, 0)[:num_samples] + y = np.concatenate(y, 0)[:num_samples] + print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape)) + npz_filename = join(directory, "samples.npz") + print('Saving npz to %s' % npz_filename) + np.savez(npz_filename, **{'x': x, 'y': y}) + + +def save_images_png(run_name, data_loader, num_samples, num_classes, generator, discriminator, is_generate, + truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device): + if is_generate is True: + batch_size = data_loader.batch_size + n_batches = math.ceil(float(num_samples) / float(batch_size)) + else: + batch_size = data_loader.batch_size + total_instance = len(data_loader.dataset) + n_batches = math.ceil(float(num_samples) / float(batch_size)) + data_iter = iter(data_loader) + + data_iter = iter(data_loader) + type = "fake" if is_generate is True else "real" + print("Save {num_samples} {type} images in png format....".format(num_samples=num_samples, type=type)) + + directory = join('./samples', run_name, type, "png") + if exists(abspath(directory)): + shutil.rmtree(abspath(directory)) + os.makedirs(directory) + for f in range(num_classes): + os.makedirs(join(directory, str(f))) + + with torch.no_grad() if latent_op is False 
else dummy_context_mgr() as mpc: + for i in tqdm(range(0, n_batches), disable=False): + start = i * batch_size + end = start + batch_size + if is_generate: + images, labels = generate_images(batch_size, generator, discriminator, truncated_factor, prior, latent_op, + latent_op_step, latent_op_alpha, latent_op_beta, device) + else: + try: + images, labels = next(data_iter) + except StopIteration: + break + + for idx, img in enumerate(images.detach()): + if batch_size * i + idx < num_samples: + save_image((img + 1) / 2, join(directory, + str(labels[idx].item()), '{idx}.png'.format(idx=batch_size * i + idx))) + else: + pass + print('Save png to ./generated_images/%s' % run_name) + + +def generate_images_for_KNN(batch_size, real_label, gen_model, dis_model, truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device): + if isinstance(gen_model, DataParallel) or isinstance(gen_model, DistributedDataParallel): + z_dim = gen_model.module.z_dim + num_classes = gen_model.module.num_classes + conditional_strategy = dis_model.module.conditional_strategy + else: + z_dim = gen_model.z_dim + num_classes = gen_model.num_classes + conditional_strategy = dis_model.conditional_strategy + + zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor, num_classes, None, device, real_label) + + if latent_op: + zs = latent_optimise(zs, fake_labels, gen_model, dis_model, conditional_strategy, latent_op_step, 1.0, + latent_op_alpha, latent_op_beta, False, device) + + with torch.no_grad(): + batch_images = gen_model(zs, fake_labels, evaluation=True) + + return batch_images, list(fake_labels.detach().cpu().numpy()) + + +class SaveOutput: + def __init__(self): + self.outputs = [] + + def __call__(self, module, module_input): + # def __call__(self, module, module_in, module_out): + self.outputs.append(module_input) + + def clear(self): + self.outputs = [] + + +def calculate_ortho_reg(m, rank): + with torch.enable_grad(): + reg = 1e-6 + param_flat = m.view(m.shape[0], -1) + sym = torch.mm(param_flat, torch.t(param_flat)) + sym -= torch.eye(param_flat.shape[0]).to(rank) + ortho_loss = reg * sym.abs().sum() + return ortho_loss diff --git a/pytorch_pretrained_gans/StudioGAN/utils/model_ops.py b/pytorch_pretrained_gans/StudioGAN/utils/model_ops.py new file mode 100644 index 0000000..0a65e97 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/model_ops.py @@ -0,0 +1,170 @@ +# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN +# The MIT License (MIT) +# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details + +# src/utils/model_ops.py + + +import torch +import torch.nn as nn +from torch.nn.utils import spectral_norm +from torch.nn import init + + + +def init_weights(modules, initialize): + for module in modules(): + if (isinstance(module, nn.Conv2d) + or isinstance(module, nn.ConvTranspose2d) + or isinstance(module, nn.Linear)): + if initialize == 'ortho': + init.orthogonal_(module.weight) + if module.bias is not None: + module.bias.data.fill_(0.) + elif initialize == 'N02': + init.normal_(module.weight, 0, 0.02) + if module.bias is not None: + module.bias.data.fill_(0.) + elif initialize in ['glorot', 'xavier']: + init.xavier_uniform_(module.weight) + if module.bias is not None: + module.bias.data.fill_(0.) 
+ else: + print('Init style not recognized...') + elif isinstance(module, nn.Embedding): + if initialize == 'ortho': + init.orthogonal_(module.weight) + elif initialize == 'N02': + init.normal_(module.weight, 0, 0.02) + elif initialize in ['glorot', 'xavier']: + init.xavier_uniform_(module.weight) + else: + print('Init style not recognized...') + else: + pass + + +def conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): + return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + +def deconv2d(in_channels, out_channels, kernel_size, stride=2, padding=0, dilation=1, groups=1, bias=True): + return nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + +def linear(in_features, out_features, bias=True): + return nn.Linear(in_features=in_features, out_features=out_features, bias=bias) + +def embedding(num_embeddings, embedding_dim): + return nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + +def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): + return spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias), eps=1e-6) + +def sndeconv2d(in_channels, out_channels, kernel_size, stride=2, padding=0, dilation=1, groups=1, bias=True): + return spectral_norm(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias), eps=1e-6) + +def snlinear(in_features, out_features, bias=True): + return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features, bias=bias), eps=1e-6) + +def sn_embedding(num_embeddings, embedding_dim): + return spectral_norm(nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim), eps=1e-6) + +def batchnorm_2d(in_features, eps=1e-4, momentum=0.1, affine=True): + return nn.BatchNorm2d(in_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=True) + + +class ConditionalBatchNorm2d(nn.Module): + # https://github.com/voletiv/self-attention-GAN-pytorch + def __init__(self, num_features, num_classes, spectral_norm): + super().__init__() + self.num_features = num_features + self.bn = batchnorm_2d(num_features, eps=1e-4, momentum=0.1, affine=False) + + if spectral_norm: + self.embed0 = sn_embedding(num_classes, num_features) + self.embed1 = sn_embedding(num_classes, num_features) + else: + self.embed0 = embedding(num_classes, num_features) + self.embed1 = embedding(num_classes, num_features) + + def forward(self, x, y): + gain = (1 + self.embed0(y)).view(-1, self.num_features, 1, 1) + bias = self.embed1(y).view(-1, self.num_features, 1, 1) + out = self.bn(x) + return out * gain + bias + + +class ConditionalBatchNorm2d_for_skip_and_shared(nn.Module): + # https://github.com/voletiv/self-attention-GAN-pytorch + def __init__(self, num_features, z_dims_after_concat, spectral_norm): + super().__init__() + self.num_features = num_features + self.bn = batchnorm_2d(num_features, eps=1e-4, momentum=0.1, affine=False) + + if spectral_norm: + self.gain = snlinear(z_dims_after_concat, num_features, bias=False) + self.bias = snlinear(z_dims_after_concat, 
num_features, bias=False) + else: + self.gain = linear(z_dims_after_concat, num_features, bias=False) + self.bias = linear(z_dims_after_concat, num_features, bias=False) + + def forward(self, x, y): + gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1) + bias = self.bias(y).view(y.size(0), -1, 1, 1) + out = self.bn(x) + return out * gain + bias + + +class Self_Attn(nn.Module): + # https://github.com/voletiv/self-attention-GAN-pytorch + def __init__(self, in_channels, spectral_norm): + super(Self_Attn, self).__init__() + self.in_channels = in_channels + + if spectral_norm: + self.conv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0, bias=False) + self.conv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0, bias=False) + self.conv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0, bias=False) + self.conv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0, bias=False) + else: + self.conv1x1_theta = conv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0, bias=False) + self.conv1x1_phi = conv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0, bias=False) + self.conv1x1_g = conv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0, bias=False) + self.conv1x1_attn = conv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0, bias=False) + + self.maxpool = nn.MaxPool2d(2, stride=2, padding=0) + self.softmax = nn.Softmax(dim=-1) + self.sigma = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + """ + inputs : + x : input feature maps(B X C X H X W) + returns : + out : self attention value + input feature + attention: B X N X N (N is Width*Height) + """ + _, ch, h, w = x.size() + # Theta path + theta = self.conv1x1_theta(x) + theta = theta.view(-1, ch//8, h*w) + # Phi path + phi = self.conv1x1_phi(x) + phi = self.maxpool(phi) + phi = phi.view(-1, ch//8, h*w//4) + # Attn map + attn = torch.bmm(theta.permute(0, 2, 1), phi) + attn = self.softmax(attn) + # g path + g = self.conv1x1_g(x) + g = self.maxpool(g) + g = g.view(-1, ch//2, h*w//4) + # Attn_g + attn_g = torch.bmm(g, attn.permute(0, 2, 1)) + attn_g = attn_g.view(-1, ch//2, h, w) + attn_g = self.conv1x1_attn(attn_g) + return x + self.sigma*attn_g + diff --git a/pytorch_pretrained_gans/StudioGAN/utils/sample.py b/pytorch_pretrained_gans/StudioGAN/utils/sample.py new file mode 100644 index 0000000..4938117 --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/utils/sample.py @@ -0,0 +1,114 @@ +# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN +# The MIT License (MIT) +# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details + +# src/utils/sample.py + + +import numpy as np +import random +from numpy import linalg +from math import sin, cos, sqrt + +from .losses import latent_optimise + +import torch +import torch.nn.functional as F +from torch.nn import DataParallel + + +def sample_latents(dist, batch_size, dim, truncated_factor=1, num_classes=None, perturb=None, device=torch.device("cpu"), sampler="default"): + if num_classes: + if sampler == "default": + y_fake = torch.randint(low=0, high=num_classes, size=(batch_size,), dtype=torch.long, device=device) + elif sampler == "class_order_some": + assert 
batch_size % 8 == 0, "The size of the batches should be a multiple of 8." + num_classes_plot = batch_size // 8 + indices = np.random.permutation(num_classes)[:num_classes_plot] + elif sampler == "class_order_all": + batch_size = num_classes * 8 + indices = [c for c in range(num_classes)] + elif isinstance(sampler, int): + y_fake = torch.tensor([sampler] * batch_size, dtype=torch.long).to(device) + else: + raise NotImplementedError + + if sampler in ["class_order_some", "class_order_all"]: + y_fake = [] + for idx in indices: + y_fake += [idx] * 8 + y_fake = torch.tensor(y_fake, dtype=torch.long).to(device) + else: + y_fake = None + + if isinstance(perturb, float) and perturb > 0.0: + if dist == "gaussian": + latents = torch.randn(batch_size, dim, device=device) / truncated_factor + eps = perturb * torch.randn(batch_size, dim, device=device) + latents_eps = latents + eps + elif dist == "uniform": + latents = torch.FloatTensor(batch_size, dim).uniform_(-1.0, 1.0).to(device) + eps = perturb * torch.FloatTensor(batch_size, dim).uniform_(-1.0, 1.0).to(device) + latents_eps = latents + eps + elif dist == "hyper_sphere": + latents, latents_eps = random_ball(batch_size, dim, perturb=perturb) + latents, latents_eps = torch.FloatTensor(latents).to(device), torch.FloatTensor(latents_eps).to(device) + return latents, y_fake, latents_eps + else: + if dist == "gaussian": + latents = torch.randn(batch_size, dim, device=device) / truncated_factor + elif dist == "uniform": + latents = torch.FloatTensor(batch_size, dim).uniform_(-1.0, 1.0).to(device) + elif dist == "hyper_sphere": + latents = random_ball(batch_size, dim, perturb=perturb).to(device) + return latents, y_fake + + +def random_ball(batch_size, z_dim, perturb=False): + if perturb: + normal = np.random.normal(size=(z_dim, batch_size)) + random_directions = normal / linalg.norm(normal, axis=0) + random_radii = random.random(batch_size) ** (1 / z_dim) + zs = 1.0 * (random_directions * random_radii).T + + normal_perturb = normal + 0.05 * np.random.normal(size=(z_dim, batch_size)) + perturb_random_directions = normal_perturb / linalg.norm(normal_perturb, axis=0) + perturb_random_radii = random.random(batch_size) ** (1 / z_dim) + zs_perturb = 1.0 * (perturb_random_directions * perturb_random_radii).T + return zs, zs_perturb + else: + normal = np.random.normal(size=(z_dim, batch_size)) + random_directions = normal / linalg.norm(normal, axis=0) + random_radii = random.random(batch_size) ** (1 / z_dim) + zs = 1.0 * (random_directions * random_radii).T + return zs + + +# Convenience function to sample an index, not actually a 1-hot +def sample_1hot(batch_size, num_classes, device='cuda'): + return torch.randint(low=0, high=num_classes, size=(batch_size,), + device=device, dtype=torch.int64, requires_grad=False) + + +def make_mask(labels, n_cls, device): + labels = labels.detach().cpu().numpy() + n_samples = labels.shape[0] + mask_multi = np.zeros([n_cls, n_samples]) + for c in range(n_cls): + c_indices = np.where(labels == c) + mask_multi[c, c_indices] = +1 + + mask_multi = torch.tensor(mask_multi).type(torch.long) + return mask_multi.to(device) + + +def target_class_sampler(dataset, target_class): + try: + targets = dataset.data.targets + except: + targets = dataset.labels + weights = [True if target == target_class else False for target in targets] + num_samples = sum(weights) + weights = torch.DoubleTensor(weights) + sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights), replacement=False) + return num_samples, sampler diff --git 
a/pytorch_pretrained_gans/StudioGAN/worker.py b/pytorch_pretrained_gans/StudioGAN/worker.py new file mode 100644 index 0000000..4faf59d --- /dev/null +++ b/pytorch_pretrained_gans/StudioGAN/worker.py @@ -0,0 +1,995 @@ +# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN +# The MIT License (MIT) +# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details + +# src/worker.py + + +import numpy as np +import sys +import glob +import random +from scipy import ndimage +from os.path import join +from PIL import Image +from tqdm import tqdm +from datetime import datetime + +from metrics.IS import calculate_incep_score +from metrics.FID import calculate_fid_score +from metrics.F_beta import calculate_f_beta_score +from metrics.Accuracy import calculate_accuracy +from .utils.ada import augment +from .utils.biggan_utils import interp +from .utils.sample import sample_latents, sample_1hot, make_mask, target_class_sampler +from .utils.misc import * +from .utils.losses import calc_derv4gp, calc_derv4dra, calc_derv, latent_optimise, set_temperature +from .utils.losses import Conditional_Contrastive_loss, Proxy_NCA_loss, NT_Xent_loss +from .utils.diff_aug import DiffAugment +from .utils.cr_diff_aug import CR_DiffAug + +import torch +import torch.nn as nn +from torch.nn import DataParallel +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel +import torch.nn.functional as F +import torchvision +from torchvision import transforms + + +SAVE_FORMAT = 'step={step:0>3}-Inception_mean={Inception_mean:<.4}-Inception_std={Inception_std:<.4}-FID={FID:<.5}.pth' + +LOG_FORMAT = ( + "Step: {step:>7} " + "Progress: {progress:<.1%} " + "Elapsed: {elapsed} " + "temperature: {temperature:<.6} " + "ada_p: {ada_p:<.6} " + "Discriminator_loss: {dis_loss:<.6} " + "Generator_loss: {gen_loss:<.6} " +) + + +class make_worker(object): + def __init__(self, cfgs, run_name, best_step, logger, writer, n_gpus, gen_model, dis_model, inception_model, Gen_copy, + Gen_ema, train_dataset, eval_dataset, train_dataloader, eval_dataloader, G_optimizer, D_optimizer, G_loss, + D_loss, prev_ada_p, global_rank, local_rank, bn_stat_OnTheFly, checkpoint_dir, mu, sigma, best_fid, + best_fid_checkpoint_path): + + self.cfgs = cfgs + self.run_name = run_name + self.best_step = best_step + self.seed = cfgs.seed + self.dataset_name = cfgs.dataset_name + self.eval_type = cfgs.eval_type + self.logger = logger + self.writer = writer + self.num_workers = cfgs.num_workers + self.n_gpus = n_gpus + + self.gen_model = gen_model + self.dis_model = dis_model + self.inception_model = inception_model + self.Gen_copy = Gen_copy + self.Gen_ema = Gen_ema + + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.train_dataloader = train_dataloader + self.eval_dataloader = eval_dataloader + + self.freeze_layers = cfgs.freeze_layers + + self.conditional_strategy = cfgs.conditional_strategy + self.pos_collected_numerator = cfgs.pos_collected_numerator + self.z_dim = cfgs.z_dim + self.num_classes = cfgs.num_classes + self.hypersphere_dim = cfgs.hypersphere_dim + self.d_spectral_norm = cfgs.d_spectral_norm + self.g_spectral_norm = cfgs.g_spectral_norm + + self.G_optimizer = G_optimizer + self.D_optimizer = D_optimizer + self.batch_size = cfgs.batch_size + self.g_steps_per_iter = cfgs.g_steps_per_iter + self.d_steps_per_iter = cfgs.d_steps_per_iter + self.accumulation_steps = cfgs.accumulation_steps + self.total_step = cfgs.total_step + + self.G_loss = G_loss + 
self.D_loss = D_loss + self.contrastive_lambda = cfgs.contrastive_lambda + self.margin = cfgs.margin + self.tempering_type = cfgs.tempering_type + self.tempering_step = cfgs.tempering_step + self.start_temperature = cfgs.start_temperature + self.end_temperature = cfgs.end_temperature + self.weight_clipping_for_dis = cfgs.weight_clipping_for_dis + self.weight_clipping_bound = cfgs.weight_clipping_bound + self.gradient_penalty_for_dis = cfgs.gradient_penalty_for_dis + self.gradient_penalty_lambda = cfgs.gradient_penalty_lambda + self.deep_regret_analysis_for_dis = cfgs.deep_regret_analysis_for_dis + self.regret_penalty_lambda = cfgs.regret_penalty_lambda + self.cr = cfgs.cr + self.cr_lambda = cfgs.cr_lambda + self.bcr = cfgs.bcr + self.real_lambda = cfgs.real_lambda + self.fake_lambda = cfgs.fake_lambda + self.zcr = cfgs.zcr + self.gen_lambda = cfgs.gen_lambda + self.dis_lambda = cfgs.dis_lambda + self.sigma_noise = cfgs.sigma_noise + + self.diff_aug = cfgs.diff_aug + self.ada = cfgs.ada + self.prev_ada_p = prev_ada_p + self.ada_target = cfgs.ada_target + self.ada_length = cfgs.ada_length + self.prior = cfgs.prior + self.truncated_factor = cfgs.truncated_factor + self.ema = cfgs.ema + self.latent_op = cfgs.latent_op + self.latent_op_rate = cfgs.latent_op_rate + self.latent_op_step = cfgs.latent_op_step + self.latent_op_step4eval = cfgs.latent_op_step4eval + self.latent_op_alpha = cfgs.latent_op_alpha + self.latent_op_beta = cfgs.latent_op_beta + self.latent_norm_reg_weight = cfgs.latent_norm_reg_weight + + self.global_rank = global_rank + self.local_rank = local_rank + self.bn_stat_OnTheFly = bn_stat_OnTheFly + self.print_every = cfgs.print_every + self.save_every = cfgs.save_every + self.checkpoint_dir = checkpoint_dir + self.evaluate = cfgs.eval + self.mu = mu + self.sigma = sigma + self.best_fid = best_fid + self.best_fid_checkpoint_path = best_fid_checkpoint_path + self.distributed_data_parallel = cfgs.distributed_data_parallel + self.mixed_precision = cfgs.mixed_precision + self.synchronized_bn = cfgs.synchronized_bn + + self.start_time = datetime.now() + self.l2_loss = torch.nn.MSELoss() + self.ce_loss = torch.nn.CrossEntropyLoss() + self.cosine_similarity = torch.nn.CosineSimilarity(dim=-1) + self.policy = "color,translation,cutout" + self.counter = 0 + + self.sampler = define_sampler(self.dataset_name, self.conditional_strategy) + + if self.distributed_data_parallel: + self.group = dist.new_group([n for n in range(self.n_gpus)]) + + check_flag_1(self.tempering_type, self.pos_collected_numerator, self.conditional_strategy, self.diff_aug, self.ada, + self.mixed_precision, self.gradient_penalty_for_dis, self.deep_regret_analysis_for_dis, self.cr, self.bcr, + self.zcr, self.distributed_data_parallel, self.synchronized_bn) + + if self.ada: + self.adtv_aug = Adaptive_Augment(self.prev_ada_p, self.ada_target, + self.ada_length, self.batch_size, self.local_rank) + + if self.conditional_strategy in ['ProjGAN', 'ContraGAN', 'Proxy_NCA_GAN']: + if isinstance(self.dis_model, DataParallel) or isinstance(self.dis_model, DistributedDataParallel): + self.embedding_layer = self.dis_model.module.embedding + else: + self.embedding_layer = self.dis_model.embedding + + if self.conditional_strategy == 'ContraGAN': + self.contrastive_criterion = Conditional_Contrastive_loss( + self.local_rank, self.batch_size, self.pos_collected_numerator) + elif self.conditional_strategy == 'Proxy_NCA_GAN': + self.NCA_criterion = Proxy_NCA_loss(self.local_rank, self.embedding_layer, + self.num_classes, self.batch_size) + 
elif self.conditional_strategy == 'NT_Xent_GAN': + self.NT_Xent_criterion = NT_Xent_loss(self.local_rank, self.batch_size) + else: + pass + + if self.mixed_precision: + self.scaler = torch.cuda.amp.GradScaler() + + if self.dataset_name == "imagenet": + self.num_eval = {'train': 50000, 'valid': 50000} + elif self.dataset_name == "tiny_imagenet": + self.num_eval = {'train': 50000, 'valid': 10000} + elif self.dataset_name == "cifar10": + self.num_eval = {'train': 50000, 'test': 10000} + elif self.dataset_name == "custom": + num_train_images, num_eval_images = len(self.train_dataset.data), len(self.eval_dataset.data) + self.num_eval = {'train': num_train_images, 'valid': num_eval_images} + else: + raise NotImplementedError + + ################################################################################################################################ + + def train(self, current_step, total_step): + self.dis_model.train() + self.gen_model.train() + if self.Gen_copy is not None: + self.Gen_copy.train() + + if self.global_rank == 0: + self.logger.info('Start training....') + step_count = current_step + train_iter = iter(self.train_dataloader) + + self.ada_aug_p = self.adtv_aug.initialize() if self.ada else 'No' + while step_count <= total_step: + # ================== TRAIN D ================== # + toggle_grad(self.dis_model, on=True, freeze_layers=self.freeze_layers) + toggle_grad(self.gen_model, on=False, freeze_layers=-1) + t = set_temperature(self.conditional_strategy, self.tempering_type, self.start_temperature, + self.end_temperature, step_count, self.tempering_step, total_step) + for step_index in range(self.d_steps_per_iter): + self.D_optimizer.zero_grad() + for acml_index in range(self.accumulation_steps): + try: + real_images, real_labels = next(train_iter) + except StopIteration: + train_iter = iter(self.train_dataloader) + real_images, real_labels = next(train_iter) + + real_images, real_labels = real_images.to(self.local_rank), real_labels.to(self.local_rank) + with torch.cuda.amp.autocast() if self.mixed_precision else dummy_context_mgr() as mpc: + if self.diff_aug: + real_images = DiffAugment(real_images, policy=self.policy) + if self.ada: + real_images, _ = augment(real_images, self.ada_aug_p) + + if self.zcr: + zs, fake_labels, zs_t = sample_latents(self.prior, self.batch_size, self.z_dim, 1, self.num_classes, + self.sigma_noise, self.local_rank) + else: + zs, fake_labels = sample_latents(self.prior, self.batch_size, self.z_dim, 1, self.num_classes, + None, self.local_rank) + if self.latent_op: + zs = latent_optimise(zs, fake_labels, self.gen_model, self.dis_model, self.conditional_strategy, + self.latent_op_step, self.latent_op_rate, self.latent_op_alpha, self.latent_op_beta, + False, self.local_rank) + + fake_images = self.gen_model(zs, fake_labels) + if self.diff_aug: + fake_images = DiffAugment(fake_images, policy=self.policy) + if self.ada: + fake_images, _ = augment(fake_images, self.ada_aug_p) + + if self.conditional_strategy == "ACGAN": + cls_out_real, dis_out_real = self.dis_model(real_images, real_labels) + cls_out_fake, dis_out_fake = self.dis_model(fake_images, fake_labels) + elif self.conditional_strategy == "ProjGAN" or self.conditional_strategy == "no": + dis_out_real = self.dis_model(real_images, real_labels) + dis_out_fake = self.dis_model(fake_images, fake_labels) + elif self.conditional_strategy in ["NT_Xent_GAN", "Proxy_NCA_GAN", "ContraGAN"]: + cls_proxies_real, cls_embed_real, dis_out_real = self.dis_model(real_images, real_labels) + cls_proxies_fake, 
cls_embed_fake, dis_out_fake = self.dis_model(fake_images, fake_labels) + else: + raise NotImplementedError + + dis_acml_loss = self.D_loss(dis_out_real, dis_out_fake) + if self.conditional_strategy == "ACGAN": + dis_acml_loss += (self.ce_loss(cls_out_real, real_labels) + + self.ce_loss(cls_out_fake, fake_labels)) + elif self.conditional_strategy == "NT_Xent_GAN": + real_images_aug = CR_DiffAug(real_images) + _, cls_embed_real_aug, dis_out_real_aug = self.dis_model(real_images_aug, real_labels) + dis_acml_loss += self.contrastive_lambda * self.NT_Xent_criterion( + cls_embed_real, cls_embed_real_aug, t) + elif self.conditional_strategy == "Proxy_NCA_GAN": + dis_acml_loss += self.contrastive_lambda * self.NCA_criterion( + cls_embed_real, cls_proxies_real, real_labels) + elif self.conditional_strategy == "ContraGAN": + real_cls_mask = make_mask(real_labels, self.num_classes, self.local_rank) + dis_acml_loss += self.contrastive_lambda * self.contrastive_criterion(cls_embed_real, cls_proxies_real, + real_cls_mask, real_labels, t, self.margin) + else: + pass + + if self.cr: + real_images_aug = CR_DiffAug(real_images) + if self.conditional_strategy == "ACGAN": + cls_out_real_aug, dis_out_real_aug = self.dis_model(real_images_aug, real_labels) + cls_consistency_loss = self.l2_loss(cls_out_real, cls_out_real_aug) + elif self.conditional_strategy == "ProjGAN" or self.conditional_strategy == "no": + dis_out_real_aug = self.dis_model(real_images_aug, real_labels) + elif self.conditional_strategy in ["NT_Xent_GAN", "Proxy_NCA_GAN", "ContraGAN"]: + _, cls_embed_real_aug, dis_out_real_aug = self.dis_model(real_images_aug, real_labels) + cls_consistency_loss = self.l2_loss(cls_embed_real, cls_embed_real_aug) + else: + raise NotImplementedError + + consistency_loss = self.l2_loss(dis_out_real, dis_out_real_aug) + if self.conditional_strategy in ["ACGAN", "NT_Xent_GAN", "Proxy_NCA_GAN", "ContraGAN"]: + consistency_loss += cls_consistency_loss + dis_acml_loss += self.cr_lambda * consistency_loss + + if self.bcr: + real_images_aug = CR_DiffAug(real_images) + fake_images_aug = CR_DiffAug(fake_images) + if self.conditional_strategy == "ACGAN": + cls_out_real_aug, dis_out_real_aug = self.dis_model(real_images_aug, real_labels) + cls_out_fake_aug, dis_out_fake_aug = self.dis_model(fake_images_aug, fake_labels) + cls_bcr_real_loss = self.l2_loss(cls_out_real, cls_out_real_aug) + cls_bcr_fake_loss = self.l2_loss(cls_out_fake, cls_out_fake_aug) + elif self.conditional_strategy == "ProjGAN" or self.conditional_strategy == "no": + dis_out_real_aug = self.dis_model(real_images_aug, real_labels) + dis_out_fake_aug = self.dis_model(fake_images_aug, fake_labels) + elif self.conditional_strategy in ["ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN"]: + cls_proxies_real_aug, cls_embed_real_aug, dis_out_real_aug = self.dis_model( + real_images_aug, real_labels) + cls_proxies_fake_aug, cls_embed_fake_aug, dis_out_fake_aug = self.dis_model( + fake_images_aug, fake_labels) + cls_bcr_real_loss = self.l2_loss(cls_embed_real, cls_embed_real_aug) + cls_bcr_fake_loss = self.l2_loss(cls_embed_fake, cls_embed_fake_aug) + else: + raise NotImplementedError + + bcr_real_loss = self.l2_loss(dis_out_real, dis_out_real_aug) + bcr_fake_loss = self.l2_loss(dis_out_fake, dis_out_fake_aug) + if self.conditional_strategy in ["ACGAN", "NT_Xent_GAN", "Proxy_NCA_GAN", "ContraGAN"]: + bcr_real_loss += cls_bcr_real_loss + bcr_fake_loss += cls_bcr_fake_loss + dis_acml_loss += self.real_lambda * bcr_real_loss + self.fake_lambda * bcr_fake_loss + + if 
self.zcr: + fake_images_zaug = self.gen_model(zs_t, fake_labels) + if self.conditional_strategy == "ACGAN": + cls_out_fake_zaug, dis_out_fake_zaug = self.dis_model(fake_images_zaug, fake_labels) + cls_zcr_dis_loss = self.l2_loss(cls_out_fake, cls_out_fake_zaug) + elif self.conditional_strategy == "ProjGAN" or self.conditional_strategy == "no": + dis_out_fake_zaug = self.dis_model(fake_images_zaug, fake_labels) + elif self.conditional_strategy in ["ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN"]: + cls_proxies_fake_zaug, cls_embed_fake_zaug, dis_out_fake_zaug = self.dis_model( + fake_images_zaug, fake_labels) + cls_zcr_dis_loss = self.l2_loss(cls_embed_fake, cls_embed_fake_zaug) + else: + raise NotImplementedError + + zcr_dis_loss = self.l2_loss(dis_out_fake, dis_out_fake_zaug) + if self.conditional_strategy in ["ACGAN", "NT_Xent_GAN", "Proxy_NCA_GAN", "ContraGAN"]: + zcr_dis_loss += cls_zcr_dis_loss + dis_acml_loss += self.dis_lambda * zcr_dis_loss + + if self.gradient_penalty_for_dis: + dis_acml_loss += self.gradient_penalty_lambda * calc_derv4gp(self.dis_model, self.conditional_strategy, real_images, + fake_images, real_labels, self.local_rank) + if self.deep_regret_analysis_for_dis: + dis_acml_loss += self.regret_penalty_lambda * calc_derv4dra(self.dis_model, self.conditional_strategy, real_images, + real_labels, self.local_rank) + if self.ada: + self.ada_aug_p = self.adtv_aug.update(dis_out_real) + + dis_acml_loss = dis_acml_loss / self.accumulation_steps + + if self.mixed_precision: + self.scaler.scale(dis_acml_loss).backward() + else: + dis_acml_loss.backward() + + if self.mixed_precision: + self.scaler.step(self.D_optimizer) + self.scaler.update() + else: + self.D_optimizer.step() + + if self.weight_clipping_for_dis: + for p in self.dis_model.parameters(): + p.data.clamp_(-self.weight_clipping_bound, self.weight_clipping_bound) + + if step_count % self.print_every == 0 and step_count != 0 and self.global_rank == 0: + if self.d_spectral_norm: + dis_sigmas = calculate_all_sn(self.dis_model) + self.writer.add_scalars('SN_of_dis', dis_sigmas, step_count) + + # ================== TRAIN G ================== # + toggle_grad(self.dis_model, False, freeze_layers=-1) + toggle_grad(self.gen_model, True, freeze_layers=-1) + for step_index in range(self.g_steps_per_iter): + self.G_optimizer.zero_grad() + for acml_step in range(self.accumulation_steps): + with torch.cuda.amp.autocast() if self.mixed_precision else dummy_context_mgr() as mpc: + if self.zcr: + zs, fake_labels, zs_t = sample_latents(self.prior, self.batch_size, self.z_dim, 1, self.num_classes, + self.sigma_noise, self.local_rank) + else: + zs, fake_labels = sample_latents(self.prior, self.batch_size, self.z_dim, 1, self.num_classes, + None, self.local_rank) + if self.latent_op: + zs, transport_cost = latent_optimise(zs, fake_labels, self.gen_model, self.dis_model, self.conditional_strategy, + self.latent_op_step, self.latent_op_rate, self.latent_op_alpha, + self.latent_op_beta, True, self.local_rank) + + fake_images = self.gen_model(zs, fake_labels) + if self.diff_aug: + fake_images = DiffAugment(fake_images, policy=self.policy) + if self.ada: + fake_images, _ = augment(fake_images, self.ada_aug_p) + + if self.conditional_strategy == "ACGAN": + cls_out_fake, dis_out_fake = self.dis_model(fake_images, fake_labels) + elif self.conditional_strategy == "ProjGAN" or self.conditional_strategy == "no": + dis_out_fake = self.dis_model(fake_images, fake_labels) + elif self.conditional_strategy in ["NT_Xent_GAN", "Proxy_NCA_GAN", "ContraGAN"]: + 
fake_cls_mask = make_mask(fake_labels, self.num_classes, self.local_rank) + cls_proxies_fake, cls_embed_fake, dis_out_fake = self.dis_model(fake_images, fake_labels) + else: + raise NotImplementedError + + gen_acml_loss = self.G_loss(dis_out_fake) + + if self.latent_op: + gen_acml_loss += transport_cost * self.latent_norm_reg_weight + + if self.zcr: + fake_images_zaug = self.gen_model(zs_t, fake_labels) + zcr_gen_loss = -1 * self.l2_loss(fake_images, fake_images_zaug) + gen_acml_loss += self.gen_lambda * zcr_gen_loss + + if self.conditional_strategy == "ACGAN": + gen_acml_loss += self.ce_loss(cls_out_fake, fake_labels) + elif self.conditional_strategy == "ContraGAN": + gen_acml_loss += self.contrastive_lambda * self.contrastive_criterion( + cls_embed_fake, cls_proxies_fake, fake_cls_mask, fake_labels, t, self.margin) + elif self.conditional_strategy == "Proxy_NCA_GAN": + gen_acml_loss += self.contrastive_lambda * self.NCA_criterion( + cls_embed_fake, cls_proxies_fake, fake_labels) + elif self.conditional_strategy == "NT_Xent_GAN": + fake_images_aug = CR_DiffAug(fake_images) + _, cls_embed_fake_aug, dis_out_fake_aug = self.dis_model(fake_images_aug, fake_labels) + gen_acml_loss += self.contrastive_lambda * self.NT_Xent_criterion( + cls_embed_fake, cls_embed_fake_aug, t) + else: + pass + + gen_acml_loss = gen_acml_loss / self.accumulation_steps + + if self.mixed_precision: + self.scaler.scale(gen_acml_loss).backward() + else: + gen_acml_loss.backward() + + if self.mixed_precision: + self.scaler.step(self.G_optimizer) + self.scaler.update() + else: + self.G_optimizer.step() + + # if ema is True: we update parameters of the Gen_copy in adaptive way. + if self.ema: + self.Gen_ema.update(step_count) + + step_count += 1 + + if step_count % self.print_every == 0 and self.global_rank == 0: + log_message = LOG_FORMAT.format(step=step_count, + progress=step_count / total_step, + elapsed=elapsed_time(self.start_time), + temperature=t, + ada_p=self.ada_aug_p, + dis_loss=dis_acml_loss.item(), + gen_loss=gen_acml_loss.item(), + ) + self.logger.info(log_message) + + if self.g_spectral_norm: + gen_sigmas = calculate_all_sn(self.gen_model) + self.writer.add_scalars('SN_of_gen', gen_sigmas, step_count) + + self.writer.add_scalars('Losses', {'discriminator': dis_acml_loss.item(), + 'generator': gen_acml_loss.item()}, step_count) + if self.ada: + self.writer.add_scalar('ada_p', self.ada_aug_p, step_count) + + if step_count % self.save_every == 0 or step_count == total_step: + if self.evaluate: + is_best = self.evaluation(step_count, False, "N/A") + if self.global_rank == 0: + self.save(step_count, is_best) + else: + if self.global_rank == 0: + self.save(step_count, False) + + if self.cfgs.distributed_data_parallel: + dist.barrier(self.group) + + return step_count - 1 + ################################################################################################################################ + + ################################################################################################################################ + + def save(self, step, is_best): + when = "best" if is_best is True else "current" + self.dis_model.eval() + self.gen_model.eval() + if self.Gen_copy is not None: + self.Gen_copy.eval() + + if isinstance(self.gen_model, DataParallel) or isinstance(self.gen_model, DistributedDataParallel): + gen, dis = self.gen_model.module, self.dis_model.module + if self.Gen_copy is not None: + gen_copy = self.Gen_copy.module + else: + gen, dis = self.gen_model, self.dis_model + if self.Gen_copy is not 
None: + gen_copy = self.Gen_copy + + g_states = {'seed': self.seed, 'run_name': self.run_name, 'step': step, 'best_step': self.best_step, + 'state_dict': gen.state_dict(), 'optimizer': self.G_optimizer.state_dict(), 'ada_p': self.ada_aug_p} + + d_states = {'seed': self.seed, 'run_name': self.run_name, 'step': step, 'best_step': self.best_step, + 'state_dict': dis.state_dict(), 'optimizer': self.D_optimizer.state_dict(), 'ada_p': self.ada_aug_p, + 'best_fid': self.best_fid, 'best_fid_checkpoint_path': self.checkpoint_dir} + + if len(glob.glob(join(self.checkpoint_dir, "model=G-{when}-weights-step*.pth".format(when=when)))) >= 1: + find_and_remove( + glob.glob(join(self.checkpoint_dir, "model=G-{when}-weights-step*.pth".format(when=when)))[0]) + find_and_remove( + glob.glob(join(self.checkpoint_dir, "model=D-{when}-weights-step*.pth".format(when=when)))[0]) + + g_checkpoint_output_path = join( + self.checkpoint_dir, "model=G-{when}-weights-step={step}.pth".format(when=when, step=str(step))) + d_checkpoint_output_path = join( + self.checkpoint_dir, "model=D-{when}-weights-step={step}.pth".format(when=when, step=str(step))) + + torch.save(g_states, g_checkpoint_output_path) + torch.save(d_states, d_checkpoint_output_path) + + if when == "best": + if len(glob.glob(join(self.checkpoint_dir, "model=G-current-weights-step*.pth"))) >= 1: + find_and_remove(glob.glob(join(self.checkpoint_dir, "model=G-current-weights-step*.pth"))[0]) + find_and_remove(glob.glob(join(self.checkpoint_dir, "model=D-current-weights-step*.pth"))[0]) + + g_checkpoint_output_path_ = join( + self.checkpoint_dir, "model=G-current-weights-step={step}.pth".format(step=str(step))) + d_checkpoint_output_path_ = join( + self.checkpoint_dir, "model=D-current-weights-step={step}.pth".format(step=str(step))) + + torch.save(g_states, g_checkpoint_output_path_) + torch.save(d_states, d_checkpoint_output_path_) + + if self.Gen_copy is not None: + g_ema_states = {'state_dict': gen_copy.state_dict()} + if len(glob.glob(join(self.checkpoint_dir, "model=G_ema-{when}-weights-step*.pth".format(when=when)))) >= 1: + find_and_remove( + glob.glob(join(self.checkpoint_dir, "model=G_ema-{when}-weights-step*.pth".format(when=when)))[0]) + + g_ema_checkpoint_output_path = join( + self.checkpoint_dir, "model=G_ema-{when}-weights-step={step}.pth".format(when=when, step=str(step))) + + torch.save(g_ema_states, g_ema_checkpoint_output_path) + + if when == "best": + if len(glob.glob(join(self.checkpoint_dir, "model=G_ema-current-weights-step*.pth".format(when=when)))) >= 1: + find_and_remove( + glob.glob(join(self.checkpoint_dir, "model=G_ema-current-weights-step*.pth".format(when=when)))[0]) + + g_ema_checkpoint_output_path_ = join( + self.checkpoint_dir, "model=G_ema-current-weights-step={step}.pth".format(when=when, step=str(step))) + + torch.save(g_ema_states, g_ema_checkpoint_output_path_) + + if self.logger: + if self.global_rank == 0: + self.logger.info("Save model to {}".format(self.checkpoint_dir)) + + self.dis_model.train() + self.gen_model.train() + if self.Gen_copy is not None: + self.Gen_copy.train() + ################################################################################################################################ + + ################################################################################################################################ + + def evaluation(self, step, standing_statistics, standing_step): + if standing_statistics: + self.counter += 1 + with torch.no_grad() if self.latent_op is False else 
dummy_context_mgr() as mpc: + if self.global_rank == 0: + self.logger.info("Start Evaluation ({step} Step): {run_name}".format(step=step, run_name=self.run_name)) + is_best = False + num_split, num_run4PR, num_cluster4PR, beta4PR = 1, 10, 20, 8 + + self.dis_model.eval() + generator = change_generator_mode(self.gen_model, self.Gen_copy, self.bn_stat_OnTheFly, standing_statistics, standing_step, + self.prior, self.batch_size, self.z_dim, self.num_classes, self.local_rank, training=False, counter=self.counter) + + fid_score, self.m1, self.s1 = calculate_fid_score(self.eval_dataloader, generator, self.dis_model, self.inception_model, self.num_eval[self.eval_type], + self.truncated_factor, self.prior, self.latent_op, self.latent_op_step4eval, self.latent_op_alpha, + self.latent_op_beta, self.local_rank, self.logger, self.mu, self.sigma, self.run_name) + + kl_score, kl_std = calculate_incep_score(self.eval_dataloader, generator, self.dis_model, self.inception_model, self.num_eval[self.eval_type], + self.truncated_factor, self.prior, self.latent_op, self.latent_op_step4eval, self.latent_op_alpha, + self.latent_op_beta, num_split, self.local_rank, self.logger) + + precision, recall, f_beta, f_beta_inv = calculate_f_beta_score(self.eval_dataloader, generator, self.dis_model, self.inception_model, self.num_eval[self.eval_type], + num_run4PR, num_cluster4PR, beta4PR, self.truncated_factor, self.prior, self.latent_op, + self.latent_op_step4eval, self.latent_op_alpha, self.latent_op_beta, self.local_rank, self.logger) + PR_Curve = plot_pr_curve(precision, recall, self.run_name, self.logger) + + if self.conditional_strategy in ['ProjGAN', 'ContraGAN', 'Proxy_NCA_GAN']: + if self.dataset_name == "cifar10": + classes = torch.tensor([c for c in range(self.num_classes)], dtype=torch.long).to(self.local_rank) + labels = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"] + else: + if self.num_classes > 10: + classes = torch.tensor(random.sample(range(0, self.num_classes), 10), + dtype=torch.long).to(self.local_rank) + else: + classes = torch.tensor([c for c in range(self.num_classes)], + dtype=torch.long).to(self.local_rank) + labels = classes.detach().cpu().numpy() + proxies = self.embedding_layer(classes) + sim_p = self.cosine_similarity(proxies.unsqueeze(1), proxies.unsqueeze(0)) + sim_heatmap = plot_sim_heatmap(sim_p.detach().cpu().numpy(), labels, labels, self.run_name, self.logger) + + if self.D_loss.__name__ != "loss_wgan_dis": + real_train_acc, fake_acc = calculate_accuracy(self.train_dataloader, generator, self.dis_model, self.D_loss, self.num_eval[self.eval_type], + self.truncated_factor, self.prior, self.latent_op, self.latent_op_step, self.latent_op_alpha, + self.latent_op_beta, self.local_rank, cr=self.cr, logger=self.logger, eval_generated_sample=True) + + if self.eval_type == 'train': + acc_dict = {'real_train': real_train_acc, 'fake': fake_acc} + else: + real_eval_acc = calculate_accuracy(self.eval_dataloader, generator, self.dis_model, self.D_loss, self.num_eval[self.eval_type], + self.truncated_factor, self.prior, self.latent_op, self.latent_op_step, self.latent_op_alpha, + self. 
latent_op_beta, self.local_rank, cr=self.cr, logger=self.logger, eval_generated_sample=False) + acc_dict = {'real_train': real_train_acc, 'real_valid': real_eval_acc, 'fake': fake_acc} + + if self.global_rank == 0: + self.writer.add_scalars('Accuracy', acc_dict, step) + + if self.best_fid is None: + self.best_fid, self.best_step, is_best, f_beta_best, f_beta_inv_best = fid_score, step, True, f_beta, f_beta_inv + else: + if fid_score <= self.best_fid: + self.best_fid, self.best_step, is_best, f_beta_best, f_beta_inv_best = fid_score, step, True, f_beta, f_beta_inv + + if self.global_rank == 0: + self.writer.add_scalars( + 'FID score', {'using {type} moments'.format(type=self.eval_type): fid_score}, step) + self.writer.add_scalars('F_beta score', {'{num} generated images'.format( + num=str(self.num_eval[self.eval_type])): f_beta}, step) + self.writer.add_scalars('F_beta_inv score', {'{num} generated images'.format( + num=str(self.num_eval[self.eval_type])): f_beta_inv}, step) + self.writer.add_scalars('IS score', {'{num} generated images'.format( + num=str(self.num_eval[self.eval_type])): kl_score}, step) + self.writer.add_figure('PR_Curve', PR_Curve, global_step=step) + if self.conditional_strategy in ['ProjGAN', 'ContraGAN', 'Proxy_NCA_GAN']: + self.writer.add_figure('Similarity_heatmap', sim_heatmap, global_step=step) + self.logger.info('F_{beta} score (Step: {step}, Using {type} images): {F_beta}'.format( + beta=beta4PR, step=step, type=self.eval_type, F_beta=f_beta)) + self.logger.info('F_1/{beta} score (Step: {step}, Using {type} images): {F_beta_inv}'.format( + beta=beta4PR, step=step, type=self.eval_type, F_beta_inv=f_beta_inv)) + self.logger.info('FID score (Step: {step}, Using {type} moments): {FID}'.format( + step=step, type=self.eval_type, FID=fid_score)) + self.logger.info('Inception score (Step: {step}, {num} generated images): {IS}'.format( + step=step, num=str(self.num_eval[self.eval_type]), IS=kl_score)) + if self.train: + self.logger.info('Best FID score (Step: {step}, Using {type} moments): {FID}'.format( + step=self.best_step, type=self.eval_type, FID=self.best_fid)) + + self.dis_model.train() + generator = change_generator_mode(self.gen_model, self.Gen_copy, self.bn_stat_OnTheFly, standing_statistics, standing_step, + self.prior, self.batch_size, self.z_dim, self.num_classes, self.local_rank, training=True, counter=self.counter) + + return is_best + ################################################################################################################################ + + ################################################################################################################################ + + def save_images(self, is_generate, standing_statistics, standing_step, png=True, npz=True): + if self.global_rank == 0: + self.logger.info('Start save images....') + if standing_statistics: + self.counter += 1 + with torch.no_grad() if self.latent_op is False else dummy_context_mgr() as mpc: + self.dis_model.eval() + generator = change_generator_mode(self.gen_model, self.Gen_copy, self.bn_stat_OnTheFly, standing_statistics, standing_step, + self.prior, self.batch_size, self.z_dim, self.num_classes, self.local_rank, training=False, counter=self.counter) + + if png: + save_images_png(self.run_name, self.eval_dataloader, self.num_eval[self.eval_type], self.num_classes, generator, + self.dis_model, is_generate, self.truncated_factor, self.prior, self.latent_op, self.latent_op_step, + self.latent_op_alpha, self.latent_op_beta, self.local_rank) + if npz: + 
save_images_npz(self.run_name, self.eval_dataloader, self.num_eval[self.eval_type], self.num_classes, generator, + self.dis_model, is_generate, self.truncated_factor, self.prior, self.latent_op, self.latent_op_step, + self.latent_op_alpha, self.latent_op_beta, self.local_rank) + + generator = change_generator_mode(self.gen_model, self.Gen_copy, self.bn_stat_OnTheFly, standing_statistics, standing_step, + self.prior, self.batch_size, self.z_dim, self.num_classes, self.local_rank, training=True, counter=self.counter) + ################################################################################################################################ + + ################################################################################################################################ + + def run_image_visualization(self, nrow, ncol, standing_statistics, standing_step): + if self.global_rank == 0: + self.logger.info('Start visualize images....') + if standing_statistics: + self.counter += 1 + assert self.batch_size % 8 == 0, "batch size should be devided by 8!" + with torch.no_grad() if self.latent_op is False else dummy_context_mgr() as mpc: + generator = change_generator_mode(self.gen_model, self.Gen_copy, self.bn_stat_OnTheFly, standing_statistics, standing_step, + self.prior, self.batch_size, self.z_dim, self.num_classes, self.local_rank, training=False, counter=self.counter) + + if self.zcr: + zs, fake_labels, zs_t = sample_latents(self.prior, self.batch_size, self.z_dim, 1, self.num_classes, + self.sigma_noise, self.local_rank, sampler=self.sampler) + else: + zs, fake_labels = sample_latents(self.prior, self.batch_size, self.z_dim, 1, self.num_classes, None, + self.local_rank, sampler=self.sampler) + + if self.latent_op: + zs = latent_optimise(zs, fake_labels, self.gen_model, self.dis_model, self.conditional_strategy, + self.latent_op_step, self.latent_op_rate, self.latent_op_alpha, self.latent_op_beta, + False, self.local_rank) + + generated_images = generator(zs, fake_labels, evaluation=True) + + plot_img_canvas((generated_images.detach().cpu() + 1) / 2, "./figures/{run_name}/generated_canvas.png". + format(run_name=self.run_name), self.logger, ncol) + + generator = change_generator_mode(self.gen_model, self.Gen_copy, self.bn_stat_OnTheFly, standing_statistics, standing_step, + self.prior, self.batch_size, self.z_dim, self.num_classes, self.local_rank, training=True, counter=self.counter) + ################################################################################################################################ + + ################################################################################################################################ + + def run_nearest_neighbor(self, nrow, ncol, standing_statistics, standing_step): + if self.global_rank == 0: + self.logger.info('Start nearest neighbor analysis....') + if standing_statistics: + self.counter += 1 + assert self.batch_size % 8 == 0, "batch size should be devided by 8!" 
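+ # Nearest-neighbour check: embed one generated anchor image per class and the real images of that class with an ImageNet-pretrained ResNet-50, then plot the (ncol - 1) closest real samples beside each fake anchor.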
+ with torch.no_grad() if self.latent_op is False else dummy_context_mgr() as mpc: + generator = change_generator_mode(self.gen_model, self.Gen_copy, self.bn_stat_OnTheFly, standing_statistics, standing_step, + self.prior, self.batch_size, self.z_dim, self.num_classes, self.local_rank, training=False, counter=self.counter) + + resnet50_model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True) + resnet50_conv = nn.Sequential(*list(resnet50_model.children())[:-1]).to(self.local_rank) + if self.n_gpus > 1: + resnet50_conv = DataParallel(resnet50_conv, output_device=self.local_rank) + resnet50_conv.eval() + + for c in tqdm(range(self.num_classes)): + fake_images, fake_labels = generate_images_for_KNN(self.batch_size, c, generator, self.dis_model, self.truncated_factor, self.prior, self.latent_op, + self.latent_op_step, self.latent_op_alpha, self.latent_op_beta, self.local_rank) + fake_image = torch.unsqueeze(fake_images[0], dim=0) + fake_anchor_embedding = torch.squeeze(resnet50_conv((fake_image + 1) / 2)) + + num_samples, target_sampler = target_class_sampler(self.train_dataset, c) + train_dataloader = torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=False, sampler=target_sampler, + num_workers=self.num_workers, pin_memory=True) + train_iter = iter(train_dataloader) + for batch_idx in range(num_samples // self.batch_size): + real_images, real_labels = next(train_iter) + real_images = real_images.to(self.local_rank) + real_embeddings = torch.squeeze(resnet50_conv((real_images + 1) / 2)) + if batch_idx == 0: + distances = torch.square( + real_embeddings - fake_anchor_embedding).mean(dim=1).detach().cpu().numpy() + holder = real_images.detach().cpu().numpy() + else: + distances = np.concatenate([distances, torch.square( + real_embeddings - fake_anchor_embedding).mean(dim=1).detach().cpu().numpy()], axis=0) + holder = np.concatenate([holder, real_images.detach().cpu().numpy()], axis=0) + + nearest_indices = (-distances).argsort()[-(ncol - 1):][::-1] + if c % nrow == 0: + canvas = np.concatenate([fake_image.detach().cpu().numpy(), holder[nearest_indices]], axis=0) + elif c % nrow == nrow - 1: + row_images = np.concatenate([fake_image.detach().cpu().numpy(), holder[nearest_indices]], axis=0) + canvas = np.concatenate((canvas, row_images), axis=0) + plot_img_canvas((torch.from_numpy(canvas) + 1) / 2, "./figures/{run_name}/Fake_anchor_{ncol}NN_{cls}_classes.png". + format(run_name=self.run_name, ncol=ncol, cls=c + 1), self.logger, ncol, logging=False) + else: + row_images = np.concatenate([fake_image.detach().cpu().numpy(), holder[nearest_indices]], axis=0) + canvas = np.concatenate((canvas, row_images), axis=0) + + generator = change_generator_mode(self.gen_model, self.Gen_copy, self.bn_stat_OnTheFly, standing_statistics, standing_step, + self.prior, self.batch_size, self.z_dim, self.num_classes, self.local_rank, training=True, counter=self.counter) + ################################################################################################################################ + + ################################################################################################################################ + + def run_linear_interpolation(self, nrow, ncol, fix_z, fix_y, standing_statistics, standing_step, num_images=100): + if self.global_rank == 0: + self.logger.info('Start linear interpolation analysis....') + if standing_statistics: + self.counter += 1 + assert self.batch_size % 8 == 0, "batch size should be devided by 8!" 
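+ # fix_z keeps z fixed per row and interpolates the shared class embedding across the columns; fix_y keeps the class embedding fixed and interpolates z. One canvas per sample is saved under ./figures/{run_name}/.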
+ with torch.no_grad() if self.latent_op is False else dummy_context_mgr() as mpc: + generator = change_generator_mode(self.gen_model, self.Gen_copy, self.bn_stat_OnTheFly, standing_statistics, standing_step, + self.prior, self.batch_size, self.z_dim, self.num_classes, self.local_rank, training=False, counter=self.counter) + shared = generator.module.shared if isinstance(generator, DataParallel) or isinstance( + generator, DistributedDataParallel) else generator.shared + assert int(fix_z) * int(fix_y) != 1, "unable to switch fix_z and fix_y on together!" + + for num in tqdm(range(num_images)): + if fix_z: + zs = torch.randn(nrow, 1, self.z_dim, device=self.local_rank) + zs = zs.repeat(1, ncol, 1).view(-1, self.z_dim) + name = "fix_z" + else: + zs = interp(torch.randn(nrow, 1, self.z_dim, device=self.local_rank), + torch.randn(nrow, 1, self.z_dim, device=self.local_rank), + ncol - 2).view(-1, self.z_dim) + + if fix_y: + ys = sample_1hot(nrow, self.num_classes, device=self.local_rank) + ys = shared(ys).view(nrow, 1, -1) + ys = ys.repeat(1, ncol, 1).view(nrow * (ncol), -1) + name = "fix_y" + else: + ys = interp(shared(sample_1hot(nrow, self.num_classes)).view(nrow, 1, -1), + shared(sample_1hot(nrow, self.num_classes)).view(nrow, 1, -1), + ncol - 2).view(nrow * (ncol), -1) + + interpolated_images = generator(zs, None, shared_label=ys, evaluation=True) + + plot_img_canvas((interpolated_images.detach().cpu() + 1) / 2, "./figures/{run_name}/{num}_Interpolated_images_{fix_flag}.png". + format(num=num, run_name=self.run_name, fix_flag=name), self.logger, ncol, logging=False) + + generator = change_generator_mode(self.gen_model, self.Gen_copy, self.bn_stat_OnTheFly, standing_statistics, standing_step, + self.prior, self.batch_size, self.z_dim, self.num_classes, self.local_rank, training=True, counter=self.counter) + ################################################################################################################################ + + ################################################################################################################################ + + def run_frequency_analysis(self, num_images, standing_statistics, standing_step): + if self.global_rank == 0: + self.logger.info('Start frequency analysis....') + if standing_statistics: + self.counter += 1 + with torch.no_grad() if self.latent_op is False else dummy_context_mgr() as mpc: + generator = change_generator_mode(self.gen_model, self.Gen_copy, self.bn_stat_OnTheFly, standing_statistics, standing_step, + self.prior, self.batch_size, self.z_dim, self.num_classes, self.local_rank, training=False, counter=self.counter) + + train_iter = iter(self.train_dataloader) + num_batches = num_images // self.batch_size + for i in range(num_batches): + if self.zcr: + zs, fake_labels, zs_t = sample_latents(self.prior, self.batch_size, self.z_dim, 1, self.num_classes, + self.sigma_noise, self.local_rank) + else: + zs, fake_labels = sample_latents(self.prior, self.batch_size, self.z_dim, 1, self.num_classes, + None, self.local_rank) + + if self.latent_op: + zs = latent_optimise(zs, fake_labels, self.gen_model, self.dis_model, self.conditional_strategy, + self.latent_op_step, self.latent_op_rate, self.latent_op_alpha, self.latent_op_beta, + False, self.local_rank) + + real_images, real_labels = next(train_iter) + fake_images = generator(zs, fake_labels, evaluation=True).detach().cpu().numpy() + + real_images = np.asarray((real_images + 1) * 127.5, np.uint8) + fake_images = np.asarray((fake_images + 1) * 127.5, np.uint8) + + if i == 0: 
+ real_array = real_images + fake_array = fake_images + else: + real_array = np.concatenate([real_array, real_images], axis=0) + fake_array = np.concatenate([fake_array, fake_images], axis=0) + + N, C, H, W = np.shape(real_array) + real_r, real_g, real_b = real_array[:, 0, :, :], real_array[:, 1, :, :], real_array[:, 2, :, :] + real_gray = 0.2989 * real_r + 0.5870 * real_g + 0.1140 * real_b + fake_r, fake_g, fake_b = fake_array[:, 0, :, :], fake_array[:, 1, :, :], fake_array[:, 2, :, :] + fake_gray = 0.2989 * fake_r + 0.5870 * fake_g + 0.1140 * fake_b + for j in tqdm(range(N)): + real_gray_f = np.fft.fft2(real_gray[j] - ndimage.median_filter(real_gray[j], size=H // 8)) + fake_gray_f = np.fft.fft2(fake_gray[j] - ndimage.median_filter(fake_gray[j], size=H // 8)) + + real_gray_f_shifted = np.fft.fftshift(real_gray_f) + fake_gray_f_shifted = np.fft.fftshift(fake_gray_f) + + if j == 0: + real_gray_spectrum = 20 * np.log(np.abs(real_gray_f_shifted)) / N + fake_gray_spectrum = 20 * np.log(np.abs(fake_gray_f_shifted)) / N + else: + real_gray_spectrum += 20 * np.log(np.abs(real_gray_f_shifted)) / N + fake_gray_spectrum += 20 * np.log(np.abs(fake_gray_f_shifted)) / N + + plot_spectrum_image(real_gray_spectrum, fake_gray_spectrum, self.run_name, self.logger) + + generator = change_generator_mode(self.gen_model, self.Gen_copy, self.bn_stat_OnTheFly, standing_statistics, standing_step, + self.prior, self.batch_size, self.z_dim, self.num_classes, self.local_rank, training=True, counter=self.counter) + ################################################################################################################################ + + ################################################################################################################################ + + def run_tsne(self, dataloader, standing_statistics, standing_step): + if self.global_rank == 0: + self.logger.info('Start tsne analysis....') + if standing_statistics: + self.counter += 1 + with torch.no_grad() if self.latent_op is False else dummy_context_mgr() as mpc: + generator = change_generator_mode(self.gen_model, self.Gen_copy, self.bn_stat_OnTheFly, standing_statistics, standing_step, + self.prior, self.batch_size, self.z_dim, self.num_classes, self.local_rank, training=False, counter=self.counter) + if isinstance(self.gen_model, DataParallel) or isinstance(self.gen_model, DistributedDataParallel): + dis_model = self.dis_model.module + else: + dis_model = self.dis_model + + save_output = SaveOutput() + hook_handles = [] + real, fake = {}, {} + tsne_iter = iter(dataloader) + num_batches = len(dataloader.dataset) // self.batch_size + for name, layer in dis_model.named_children(): + if name == "linear1": + handle = layer.register_forward_pre_hook(save_output) + hook_handles.append(handle) + + for i in range(num_batches): + if self.zcr: + zs, fake_labels, zs_t = sample_latents(self.prior, self.batch_size, self.z_dim, 1, self.num_classes, + self.sigma_noise, self.local_rank) + else: + zs, fake_labels = sample_latents(self.prior, self.batch_size, self.z_dim, 1, self.num_classes, + None, self.local_rank) + + if self.latent_op: + zs = latent_optimise(zs, fake_labels, self.gen_model, self.dis_model, self.conditional_strategy, + self.latent_op_step, self.latent_op_rate, self.latent_op_alpha, self.latent_op_beta, + False, self.local_rank) + + real_images, real_labels = next(tsne_iter) + real_images, real_labels = real_images.to(self.local_rank), real_labels.to(self.local_rank) + fake_images = generator(zs, fake_labels, evaluation=True) + + if 
self.conditional_strategy == "ACGAN": + cls_out_real, dis_out_real = self.dis_model(real_images, real_labels) + elif self.conditional_strategy == "ProjGAN" or self.conditional_strategy == "no": + dis_out_real = self.dis_model(real_images, real_labels) + elif self.conditional_strategy in ["NT_Xent_GAN", "Proxy_NCA_GAN", "ContraGAN"]: + cls_proxies_real, cls_embed_real, dis_out_real = self.dis_model(real_images, real_labels) + else: + raise NotImplementedError + + if i == 0: + real["embeds"] = save_output.outputs[0][0].detach().cpu().numpy() + real["labels"] = real_labels.detach().cpu().numpy() + else: + real["embeds"] = np.concatenate( + [real["embeds"], save_output.outputs[0][0].cpu().detach().numpy()], axis=0) + real["labels"] = np.concatenate([real["labels"], real_labels.detach().cpu().numpy()]) + + save_output.clear() + + if self.conditional_strategy == "ACGAN": + cls_out_fake, dis_out_fake = self.dis_model(fake_images, fake_labels) + elif self.conditional_strategy == "ProjGAN" or self.conditional_strategy == "no": + dis_out_fake = self.dis_model(fake_images, fake_labels) + elif self.conditional_strategy in ["NT_Xent_GAN", "Proxy_NCA_GAN", "ContraGAN"]: + cls_proxies_fake, cls_embed_fake, dis_out_fake = self.dis_model(fake_images, fake_labels) + else: + raise NotImplementedError + + if i == 0: + fake["embeds"] = save_output.outputs[0][0].detach().cpu().numpy() + fake["labels"] = fake_labels.detach().cpu().numpy() + else: + fake["embeds"] = np.concatenate( + [fake["embeds"], save_output.outputs[0][0].cpu().detach().numpy()], axis=0) + fake["labels"] = np.concatenate([fake["labels"], fake_labels.detach().cpu().numpy()]) + + save_output.clear() + + # t-SNE + from sklearn.manifold import TSNE + tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300) + real_tsne_results = tsne.fit_transform(real["embeds"]) + plot_tsne_scatter_plot(real, real_tsne_results, "real", self.run_name, self.logger) + + fake_tsne_results = tsne.fit_transform(fake["embeds"]) + plot_tsne_scatter_plot(fake, fake_tsne_results, "fake", self.run_name, self.logger) + + generator = change_generator_mode(self.gen_model, self.Gen_copy, self.bn_stat_OnTheFly, standing_statistics, standing_step, + self.prior, self.batch_size, self.z_dim, self.num_classes, self.local_rank, training=True, counter=self.counter) + ################################################################################################################################ diff --git a/pytorch_pretrained_gans/__init__.py b/pytorch_pretrained_gans/__init__.py new file mode 100644 index 0000000..e40385e --- /dev/null +++ b/pytorch_pretrained_gans/__init__.py @@ -0,0 +1,25 @@ +from .BigBiGAN.gan_load import make_bigbigan +from .self_conditioned import make_selfcond_gan +from .stylegan2_ada_pytorch import make_stylegan2 +from .StudioGAN import make_studiogan +from .CIPS import make_cips +from .BigGAN import make_biggan + + +def make_gan(*, gan_type, **kwargs): + t = gan_type.lower() + if t == 'bigbigan': + G = make_bigbigan(**kwargs) + elif t == 'selfconditionedgan': + G = make_selfcond_gan(**kwargs) + elif t == 'studiogan': + G = make_studiogan(**kwargs) + elif t == 'stylegan2': + G = make_stylegan2(**kwargs) + elif t == 'cips': + G = make_cips(**kwargs) + elif t == 'biggan': + G = make_biggan(**kwargs) + else: + raise NotImplementedError(f'Unrecognized GAN type: {gan_type}') + return G diff --git a/pytorch_pretrained_gans/self_conditioned/__init__.py b/pytorch_pretrained_gans/self_conditioned/__init__.py new file mode 100644 index 0000000..3d0c348 --- 
/dev/null +++ b/pytorch_pretrained_gans/self_conditioned/__init__.py @@ -0,0 +1,123 @@ +import torch +from torch.utils import model_zoo +from .gan_training.models import generator_dict + +# Config, adapted from: +# - https://github.com/stevliu/self-conditioned-gan/blob/master/configs/imagenet/default.yaml +# - https://github.com/stevliu/self-conditioned-gan/blob/master/configs/imagenet/unconditional.yaml +# - https://github.com/stevliu/self-conditioned-gan/blob/master/configs/imagenet/selfcondgan.yaml +configs = { + 'unconditional': { + 'generator': { + 'name': 'resnet2', + 'kwargs': {}, + # Unconditional + 'nlabels': 1, + 'conditioning': 'unconditional', + }, + 'z_dist': { + 'dim': 256 + }, + 'data': { + 'img_size': 128 + }, + 'pretrained': { + 'model': 'http://selfcondgan.csail.mit.edu/weights/uncondgan_i_model.pt' + } + }, + 'self_conditioned': { + 'generator': { + 'name': 'resnet2', + 'kwargs': {}, + # Self-conditional + 'nlabels': 100, + 'conditioning': 'embedding', + }, + 'z_dist': { + 'dim': 256 + }, + 'data': { + 'img_size': 128 + }, + 'pretrained': { + 'model': 'http://selfcondgan.csail.mit.edu/weights/selfcondgan_i_model.pt' + } + } +} + + +class GeneratorWrapper(torch.nn.Module): + """ A wrapper to put the GAN in a standard format and add metadata (dim_z) """ + + def __init__(self, generator, dim_z, nlabels): + super().__init__() + self.G = generator + self.dim_z = dim_z + self.conditional = True + self.num_classes = nlabels + + def forward(self, z, y=None, return_y=False): + if y is None: + y = self.sample_class(batch_size=z.shape[0], device=z.device) + else: + y = y.to(z.device) + x = self.G(z, y) + return (x, y) if return_y else x + + def sample_latent(self, batch_size=None, device='cpu'): + z = torch.randn(size=(batch_size, self.dim_z), device=device) + # z = truncated_noise_sample(truncation=self.truncation, batch_size=batch_size) + # z = torch.from_numpy(z).to(device) + return z + + def sample_class(self, batch_size=None, device='cpu'): + y = torch.randint(low=0, high=self.num_classes, size=(batch_size,), device=device) + return y + + +def make_selfcond_gan(model_name='self_conditioned'): + """ A helper function for loading a (pretrained) GAN """ + + # Get generator configuration + assert model_name in {'self_conditioned', 'unconditional'} + config = configs[model_name] + + # Create GAN + Generator = generator_dict[config['generator']['name']] + generator = Generator( + z_dim=config['z_dist']['dim'], + nlabels=config['generator']['nlabels'], + size=config['data']['img_size'], + conditioning=config['generator']['conditioning'], + **config['generator']['kwargs'] + ) + + # Load checkpoint + checkpoint = model_zoo.load_url(config['pretrained']['model'], map_location='cpu') + generator.load_state_dict(checkpoint['generator']) + print(f"Loaded pretrained GAN weights (iteration: {checkpoint['it']})") + + # Wrap GAN + G = GeneratorWrapper( + generator=generator, + dim_z=config['z_dist']['dim'], + nlabels=config['generator']['nlabels'] + ).eval() + + return G + + +if __name__ == "__main__": + + # Load model + G = make_selfcond_gan('self_conditioned') + print(f'Parameters: {sum(p.numel() for p in G.parameters()) / 10**6} million') + + # Example usage + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + G.to(device) + with torch.no_grad(): + z = torch.randn(7, G.dim_z, requires_grad=False, device=device) + x = G(z) + print(f'Input shape: {z.shape}') + print(f'Output shape: {x.shape}') diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/__init__.py 
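Taken together, the make_gan factory in pytorch_pretrained_gans/__init__.py and the GeneratorWrapper above define the package's loading interface. The following is a minimal usage sketch assembled only from the code shown in this patch; the tensor shapes in the comments are assumptions that follow from the 'self_conditioned' config (z_dim 256, nlabels 100, img_size 128), and the pretrained checkpoint is fetched through torch.utils.model_zoo on first use.

```python
import torch

from pytorch_pretrained_gans import make_gan

# gan_type is lower-cased inside make_gan, and the remaining kwargs are
# forwarded to the model-specific loader (here make_selfcond_gan).
G = make_gan(gan_type='selfconditionedgan', model_name='self_conditioned')

with torch.no_grad():
    z = G.sample_latent(batch_size=4, device='cpu')  # shape (4, 256) == (4, G.dim_z)
    y = G.sample_class(batch_size=4, device='cpu')   # random labels in [0, G.num_classes)
    x = G(z, y)                                      # expected shape (4, 3, 128, 128)
    x2, y2 = G(z, return_y=True)                     # y is sampled internally when omitted
print(x.shape, y2.shape)
```

The same calling pattern applies to the other gan_type branches handled by make_gan ('biggan', 'bigbigan', 'stylegan2', 'studiogan', 'cips'), though their loader-specific keyword arguments differ.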
b/pytorch_pretrained_gans/self_conditioned/gan_training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/checkpoints.py b/pytorch_pretrained_gans/self_conditioned/gan_training/checkpoints.py new file mode 100644 index 0000000..7f9f462 --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/checkpoints.py @@ -0,0 +1,163 @@ +import os, pickle +import urllib +import torch +import numpy as np +from torch.utils import model_zoo + + +class CheckpointIO(object): + ''' CheckpointIO class. + + It handles saving and loading checkpoints. + + Args: + checkpoint_dir (str): path where checkpoints are saved + ''' + + def __init__(self, checkpoint_dir='./chkpts', **kwargs): + self.module_dict = kwargs + self.checkpoint_dir = checkpoint_dir + if not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir) + + def register_modules(self, **kwargs): + ''' Registers modules in current module dictionary. + ''' + self.module_dict.update(kwargs) + + def save(self, filename, **kwargs): + ''' Saves the current module dictionary. + + Args: + filename (str): name of output file + ''' + if not os.path.isabs(filename): + filename = os.path.join(self.checkpoint_dir, filename) + + outdict = kwargs + for k, v in self.module_dict.items(): + outdict[k] = v.state_dict() + torch.save(outdict, filename) + + def load(self, filename, pretrained={}): + '''Loads a module dictionary from local file or url. + + Args: + filename (str): name of saved module dictionary + ''' + if 'model' in pretrained: + filename = pretrained['model'] + if is_url(filename): + return self.load_url(filename) + else: + return self.load_file(filename) + + def load_file(self, filename): + '''Loads a module dictionary from file. + + Args: + filename (str): name of saved module dictionary + ''' + + if not os.path.isabs(filename): + filename = os.path.join(self.checkpoint_dir, filename) + + if os.path.exists(filename): + print('=> Loading checkpoint from local file...', filename) + state_dict = torch.load(filename) + scalars = self.parse_state_dict(state_dict) + return scalars + else: + print('File not found', filename) + raise FileNotFoundError + + def load_url(self, url): + '''Load a module dictionary from url. + + Args: + url (str): url to saved model + ''' + print('=> Loading checkpoint from url...', url) + state_dict = model_zoo.load_url(url, model_dir=self.checkpoint_dir, progress=True) + scalars = self.parse_state_dict(state_dict) + return scalars + + def parse_state_dict(self, state_dict): + '''Parse state_dict of model and return scalars. + + Args: + state_dict (dict): State dict of model + ''' + for k, v in self.module_dict.items(): + if k in state_dict: + v.load_state_dict(state_dict[k]) + else: + print('Warning: Could not find %s in checkpoint!' 
% k) + scalars = { + k: v + for k, v in state_dict.items() if k not in self.module_dict + } + return scalars + + def load_clusterer(self, it, load_samples, pretrained={}): + if 'clusterer' in pretrained: + pretrained_file = os.path.join(self.checkpoint_dir, 'pretrained_clusterer.pkl') + if not os.path.exists(pretrained_file): + import cloudpickle as cp + from urllib.request import urlopen + print('Loading pretrained clusterer from', pretrained['clusterer']) + clusterer = cp.load(urlopen(pretrained['clusterer'])) + print('Saving pretrained clusterer to', pretrained_file) + with open(pretrained_file, 'wb') as f: + f.write(pickle.dumps(clusterer)) + else: + with open(pretrained_file, 'rb') as f: + clusterer = pickle.load(f) + return clusterer + else: + print('Loading clusterer:') + with open(os.path.join(self.checkpoint_dir, f'clusterer{it}.pkl'), 'rb') as f: + clusterer = pickle.load(f) + + if load_samples: + print('Loading cluster samples:') + with np.load(os.path.join(self.checkpoint_dir, 'cluster_samples.npz')) as f: + x = f['x'] + clusterer.x = torch.from_numpy(x) + return clusterer + + def load_models(self, it, pretrained={}, load_samples=False): + try: + load_dict = self.load('model_%08d.pt' % it, pretrained) + epoch_idx = load_dict.get('epoch_idx', -1) + except Exception as e: #models are not dataparallel modules + print('Trying again to load w/o data parallel modules') + try: + for name, module in self.module_dict.items(): + if isinstance(module, torch.nn.DataParallel): + self.module_dict[name] = module.module + load_dict = self.load('model_%08d.pt' % it, pretrained) + epoch_idx = load_dict.get('epoch_idx', -1) + except FileNotFoundError as e: + print(e) + print("Models not found") + it = epoch_idx = -1 + + try: + clusterer = self.load_clusterer(it, load_samples, pretrained) + except FileNotFoundError as e: + clusterer = None + + return it, epoch_idx, clusterer + + def save_clusterer(self, clusterer, it): + with open(os.path.join(self.checkpoint_dir, f'clusterer{it}.pkl'), 'wb') as f: + #hack: only save changing data + x = clusterer.x + clusterer.x = None + pickle.dump(clusterer, f) + clusterer.x = x + +def is_url(url): + scheme = urllib.parse.urlparse(url).scheme + return scheme in ('http', 'https') diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/config.py b/pytorch_pretrained_gans/self_conditioned/gan_training/config.py new file mode 100644 index 0000000..17342c4 --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/config.py @@ -0,0 +1,116 @@ +import yaml +from torch import optim +from os import path +from gan_training.models import generator_dict, discriminator_dict +from gan_training.train import toggle_grad +from clusterers import clusterer_dict + + +# General config +def load_config(path, default_path): + ''' Loads config file. 
+ + Args: + path (str): path to config file + default_path (bool): whether to use default path + ''' + # Load configuration from file itself + with open(path, 'r') as f: + cfg_special = yaml.load(f) + + # Check if we should inherit from a config + inherit_from = cfg_special.get('inherit_from') + + # If yes, load this config first as default + # If no, use the default_path + if inherit_from is not None: + cfg = load_config(inherit_from, default_path) + elif default_path is not None: + with open(default_path, 'r') as f: + cfg = yaml.load(f) + else: + cfg = dict() + + # Include main configuration + update_recursive(cfg, cfg_special) + + return cfg + + +def update_recursive(dict1, dict2): + ''' Update two config dictionaries recursively. + + Args: + dict1 (dict): first dictionary to be updated + dict2 (dict): second dictionary which entries should be used + + ''' + for k, v in dict2.items(): + # Add item if not yet in dict1 + if k not in dict1: + dict1[k] = None + # Update + if isinstance(dict1[k], dict): + update_recursive(dict1[k], v) + else: + dict1[k] = v + + +def get_clusterer(config): + return clusterer_dict[config['clusterer']['name']] + + +def build_models(config): + # Get classes + Generator = generator_dict[config['generator']['name']] + Discriminator = discriminator_dict[config['discriminator']['name']] + + # Build models + generator = Generator(z_dim=config['z_dist']['dim'], + nlabels=config['generator']['nlabels'], + size=config['data']['img_size'], + conditioning=config['generator']['conditioning'], + **config['generator']['kwargs']) + discriminator = Discriminator( + nlabels=config['discriminator']['nlabels'], + conditioning=config['discriminator']['conditioning'], + size=config['data']['img_size'], + **config['discriminator']['kwargs']) + + return generator, discriminator + + +def build_optimizers(generator, discriminator, config): + optimizer = config['training']['optimizer'] + lr_g = config['training']['lr_g'] + lr_d = config['training']['lr_d'] + + + toggle_grad(generator, True) + toggle_grad(discriminator, True) + + g_params = generator.parameters() + d_params = discriminator.parameters() + + if optimizer == 'rmsprop': + g_optimizer = optim.RMSprop(g_params, lr=lr_g, alpha=0.99, eps=1e-8) + d_optimizer = optim.RMSprop(d_params, lr=lr_d, alpha=0.99, eps=1e-8) + elif optimizer == 'adam': + beta1 = config['training']['beta1'] + beta2 = config['training']['beta2'] + g_optimizer = optim.Adam(g_params, lr=lr_g, betas=(beta1, beta2), eps=1e-8) + d_optimizer = optim.Adam(d_params, lr=lr_d, betas=(beta1, beta2), eps=1e-8) + elif optimizer == 'sgd': + g_optimizer = optim.SGD(g_params, lr=lr_g, momentum=0.) + d_optimizer = optim.SGD(d_params, lr=lr_d, momentum=0.) + + return g_optimizer, d_optimizer + + +# Some utility functions +def get_parameter_groups(parameters, gradient_scales, base_lr): + param_groups = [] + for p in parameters: + c = gradient_scales.get(p, 1.) 
+ param_groups.append({'params': [p], 'lr': c * base_lr}) + return param_groups diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/distributions.py b/pytorch_pretrained_gans/self_conditioned/gan_training/distributions.py new file mode 100644 index 0000000..bba8026 --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/distributions.py @@ -0,0 +1,43 @@ +import torch +from torch import distributions + + +def get_zdist(dist_name, dim, device=None): + # Get distribution + if dist_name == 'uniform': + low = -torch.ones(dim, device=device) + high = torch.ones(dim, device=device) + zdist = distributions.Uniform(low, high) + elif dist_name == 'gauss': + mu = torch.zeros(dim, device=device) + scale = torch.ones(dim, device=device) + zdist = distributions.Normal(mu, scale) + else: + raise NotImplementedError + + # Add dim attribute + zdist.dim = dim + + return zdist + + +def get_ydist(nlabels, device=None): + logits = torch.zeros(nlabels, device=device) + ydist = distributions.categorical.Categorical(logits=logits) + + # Add nlabels attribute + ydist.nlabels = nlabels + + return ydist + + +def interpolate_sphere(z1, z2, t): + p = (z1 * z2).sum(dim=-1, keepdim=True) + p = p / z1.pow(2).sum(dim=-1, keepdim=True).sqrt() + p = p / z2.pow(2).sum(dim=-1, keepdim=True).sqrt() + omega = torch.acos(p) + s1 = torch.sin((1-t)*omega)/torch.sin(omega) + s2 = torch.sin(t*omega)/torch.sin(omega) + z = s1 * z1 + s2 * z2 + + return z diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/eval.py b/pytorch_pretrained_gans/self_conditioned/gan_training/eval.py new file mode 100644 index 0000000..3ba29c7 --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/eval.py @@ -0,0 +1,80 @@ +import numpy as np +import torch +from torch.nn import functional as F + +from gan_training.metrics import inception_score + +class Evaluator(object): + def __init__(self, + generator, + zdist, + ydist, + train_loader, + clusterer, + batch_size=64, + inception_nsamples=10000, + device=None): + self.generator = generator + self.clusterer = clusterer + self.train_loader = train_loader + self.zdist = zdist + self.ydist = ydist + self.inception_nsamples = inception_nsamples + self.batch_size = batch_size + self.device = device + + def sample_z(self, batch_size): + return self.zdist.sample((batch_size, )).to(self.device) + + def get_y(self, x, y): + return self.clusterer.get_labels(x, y).to(self.device) + + def get_fake_real_samples(self, N): + ''' returns N fake images and N real images in pytorch form''' + with torch.no_grad(): + self.generator.eval() + fake_imgs = [] + real_imgs = [] + while len(fake_imgs) < N: + for x_real, y_gt in self.train_loader: + x_real = x_real.cuda() + z = self.sample_z(x_real.size(0)) + y = self.get_y(x_real, y_gt) + samples = self.generator(z, y) + samples = [s.data.cpu() for s in samples] + fake_imgs.extend(samples) + real_batch = [img.data.cpu() for img in x_real] + real_imgs.extend(real_batch) + assert (len(real_imgs) == len(fake_imgs)) + if len(fake_imgs) >= N: + fake_imgs = fake_imgs[:N] + real_imgs = real_imgs[:N] + return fake_imgs, real_imgs + + def compute_inception_score(self): + imgs, _ = self.get_fake_real_samples(self.inception_nsamples) + imgs = [img.numpy() for img in imgs] + score, score_std = inception_score(imgs, + device=self.device, + resize=True, + splits=1) + + return score, score_std + + def create_samples(self, z, y=None): + self.generator.eval() + batch_size = z.size(0) + # Parse y + if y is None: + raise 
NotImplementedError() + elif isinstance(y, int): + y = torch.full((batch_size, ), + y, + device=self.device, + dtype=torch.int64) + # Sample x + with torch.no_grad(): + x = self.generator(z, y) + return x + + diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/inputs.py b/pytorch_pretrained_gans/self_conditioned/gan_training/inputs.py new file mode 100644 index 0000000..2c24e69 --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/inputs.py @@ -0,0 +1,217 @@ +import torch +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import numpy as np + +import os +import torch.utils.data as data +from torchvision.datasets.folder import default_loader +from PIL import Image +import random + +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True + +def get_dataset(name, + data_dir, + size=64, + lsun_categories=None, + deterministic=False, + transform=None): + + transform = transforms.Compose([ + t for t in [ + transforms.Resize(size), + transforms.CenterCrop(size), + (not deterministic) and transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + (not deterministic) and + transforms.Lambda(lambda x: x + 1. / 128 * torch.rand(x.size())), + ] if t is not False + ]) if transform == None else transform + + if name == 'image': + print('Using image labels') + dataset = datasets.ImageFolder(data_dir, transform) + nlabels = len(dataset.classes) + elif name == 'webp': + print('Using no labels from webp') + dataset = CachedImageFolder(data_dir, transform) + nlabels = len(dataset.classes) + elif name == 'npy': + # Only support normalization for now + dataset = datasets.DatasetFolder(data_dir, npy_loader, ['npy']) + nlabels = len(dataset.classes) + elif name == 'cifar10': + dataset = datasets.CIFAR10(root=data_dir, + train=True, + download=True, + transform=transform) + nlabels = 10 + elif name == 'stacked_mnist': + dataset = StackedMNIST(data_dir, + transform=transforms.Compose([ + transforms.Resize(size), + transforms.CenterCrop(size), + transforms.ToTensor(), + transforms.Normalize((0.5, ), (0.5, )) + ])) + nlabels = 1000 + elif name == 'lsun': + if lsun_categories is None: + lsun_categories = 'train' + dataset = datasets.LSUN(data_dir, lsun_categories, transform) + nlabels = len(dataset.classes) + elif name == 'lsun_class': + dataset = datasets.LSUNClass(data_dir, + transform, + target_transform=(lambda t: 0)) + nlabels = 1 + else: + raise NotImplemented + return dataset, nlabels + +class CachedImageFolder(data.Dataset): + """ + A version of torchvision.dataset.ImageFolder that takes advantage + of cached filename lists. 
+ photo/park/004234.jpg + photo/park/004236.jpg + photo/park/004237.jpg + """ + + def __init__(self, root, transform=None, loader=default_loader): + classes, class_to_idx = find_classes(root) + self.imgs = make_class_dataset(root, class_to_idx) + if len(self.imgs) == 0: + raise RuntimeError("Found 0 images within: %s" % root) + self.root = root + self.classes = classes + self.class_to_idx = class_to_idx + self.transform = transform + self.loader = loader + + def __getitem__(self, index): + path, classidx = self.imgs[index] + source = self.loader(path) + if self.transform is not None: + source = self.transform(source) + return source, classidx + + def __len__(self): + return len(self.imgs) + +class StackedMNIST(data.Dataset): + def __init__(self, data_dir, transform, batch_size=100000): + super().__init__() + self.channel1 = datasets.MNIST(data_dir, + transform=transform, + train=True, + download=True) + self.channel2 = datasets.MNIST(data_dir, + transform=transform, + train=True, + download=True) + self.channel3 = datasets.MNIST(data_dir, + transform=transform, + train=True, + download=True) + self.indices = { + k: (random.randint(0, + len(self.channel1) - 1), + random.randint(0, + len(self.channel1) - 1), + random.randint(0, + len(self.channel1) - 1)) + for k in range(batch_size) + } + + def __getitem__(self, index): + index1, index2, index3 = self.indices[index] + x1, y1 = self.channel1[index1] + x2, y2 = self.channel2[index2] + x3, y3 = self.channel3[index3] + return torch.cat([x1, x2, x3], dim=0), y1 * 100 + y2 * 10 + y3 + + def __len__(self): + return len(self.indices) + + +def is_npy_file(path): + return path.endswith('.npy') or path.endswith('.NPY') + + +def walk_image_files(rootdir): + print(rootdir) + if os.path.isfile('%s.txt' % rootdir): + print('Loading file list from %s.txt instead of scanning dir' % + rootdir) + basedir = os.path.dirname(rootdir) + with open('%s.txt' % rootdir) as f: + result = sorted([ + os.path.join(basedir, line.strip()) for line in f.readlines() + ]) + import random + random.Random(1).shuffle(result) + return result + result = [] + + IMG_EXTENSIONS = [ + '.jpg', + '.JPG', + '.jpeg', + '.JPEG', + '.png', + '.PNG', + '.ppm', + '.PPM', + '.bmp', + '.BMP', + ] + + for dirname, _, fnames in sorted(os.walk(rootdir)): + for fname in sorted(fnames): + if any(fname.endswith(extension) + for extension in IMG_EXTENSIONS) or is_npy_file(fname): + result.append(os.path.join(dirname, fname)) + return result + + +def find_classes(dir): + classes = [ + d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) + ] + classes.sort() + class_to_idx = {classes[i]: i for i in range(len(classes))} + return classes, class_to_idx + + +def make_class_dataset(source_root, class_to_idx): + """ + Returns (source, classnum, feature) + """ + imagepairs = [] + source_root = os.path.expanduser(source_root) + for path in walk_image_files(source_root): + classname = os.path.basename(os.path.dirname(path)) + imagepairs.append((path, 0)) + return imagepairs + + +def npy_loader(path): + img = np.load(path) + + if img.dtype == np.uint8: + img = img.astype(np.float32) + img = img / 127.5 - 1. + elif img.dtype == np.float32: + img = img * 2 - 1. 
+ else: + raise NotImplementedError + + img = torch.Tensor(img) + if len(img.size()) == 4: + img.squeeze_(0) + + return img diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/logger.py b/pytorch_pretrained_gans/self_conditioned/gan_training/logger.py new file mode 100644 index 0000000..be67cd3 --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/logger.py @@ -0,0 +1,96 @@ +import pickle +import os +import torchvision +import copy + + +class Logger(object): + def __init__(self, + log_dir='./logs', + img_dir='./imgs', + monitoring=None, + monitoring_dir=None): + self.stats = dict() + self.log_dir = log_dir + self.img_dir = img_dir + + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + if not os.path.exists(img_dir): + os.makedirs(img_dir) + + if not (monitoring is None or monitoring == 'none'): + self.setup_monitoring(monitoring, monitoring_dir) + else: + self.monitoring = None + self.monitoring_dir = None + + def setup_monitoring(self, monitoring, monitoring_dir=None): + self.monitoring = monitoring + self.monitoring_dir = monitoring_dir + + if monitoring == 'telemetry': + import telemetry + self.tm = telemetry.ApplicationTelemetry() + if self.tm.get_status() == 0: + print('Telemetry successfully connected.') + elif monitoring == 'tensorboard': + import tensorboardX + self.tb = tensorboardX.SummaryWriter(monitoring_dir) + else: + raise NotImplementedError('Monitoring tool "%s" not supported!' % + monitoring) + + def add(self, category, k, v, it): + if category not in self.stats: + self.stats[category] = {} + + if k not in self.stats[category]: + self.stats[category][k] = [] + + self.stats[category][k].append((it, v)) + + k_name = '%s/%s' % (category, k) + if self.monitoring == 'telemetry': + self.tm.metric_push_async({'metric': k_name, 'value': v, 'it': it}) + elif self.monitoring == 'tensorboard': + self.tb.add_scalar(k_name, v, it) + + def add_imgs(self, imgs, class_name, it): + outdir = os.path.join(self.img_dir, class_name) + if not os.path.exists(outdir): + os.makedirs(outdir) + outfile = os.path.join(outdir, '%08d.png' % it) + + imgs = imgs / 2 + 0.5 + imgs = torchvision.utils.make_grid(imgs) + torchvision.utils.save_image(copy.deepcopy(imgs), outfile, nrow=8) + + if self.monitoring == 'tensorboard': + self.tb.add_image(class_name, copy.deepcopy(imgs), it) + + def get_last(self, category, k, default=0.): + if category not in self.stats: + return default + elif k not in self.stats[category]: + return default + else: + return self.stats[category][k][-1][1] + + def save_stats(self, filename): + filename = os.path.join(self.log_dir, filename) + with open(filename, 'wb') as f: + pickle.dump(self.stats, f) + + def load_stats(self, filename): + filename = os.path.join(self.log_dir, filename) + if not os.path.exists(filename): + print('Warning: file "%s" does not exist!' 
% filename) + return + + try: + with open(filename, 'rb') as f: + self.stats = pickle.load(f) + except EOFError: + print('Warning: log file corrupted!') diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/__init__.py b/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/__init__.py new file mode 100644 index 0000000..dd7fd53 --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/__init__.py @@ -0,0 +1,5 @@ +from gan_training.metrics.inception_score import inception_score + +__all__ = [ + inception_score +] diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/clustering_metrics.py b/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/clustering_metrics.py new file mode 100644 index 0000000..fbbe336 --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/clustering_metrics.py @@ -0,0 +1,41 @@ +def warn(*args, **kwargs): + pass + + +import warnings +warnings.warn = warn + +from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score, homogeneity_score +from sklearn import metrics + +import numpy as np + + +def nmi(inferred, gt): + return normalized_mutual_info_score(inferred, gt) + + +def acc(inferred, gt): + gt = gt.astype(np.int64) + assert inferred.size == gt.size + D = max(inferred.max(), gt.max()) + 1 + w = np.zeros((D, D), dtype=np.int64) + for i in range(inferred.size): + w[inferred[i], gt[i]] += 1 + from sklearn.utils.linear_assignment_ import linear_assignment + ind = linear_assignment(w.max() - w) + return sum([w[i, j] for i, j in ind]) * 1.0 / inferred.size + + +def purity_score(y_true, y_pred): + contingency_matrix = metrics.cluster.contingency_matrix(y_true, y_pred) + return np.sum(np.amax(contingency_matrix, + axis=0)) / np.sum(contingency_matrix) + + +def ari(inferred, gt): + return adjusted_rand_score(gt, inferred) + + +def homogeneity(inferred, gt): + return homogeneity_score(gt, inferred) diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/fid.py b/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/fid.py new file mode 100644 index 0000000..319cd15 --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/fid.py @@ -0,0 +1,304 @@ +from __future__ import absolute_import, division, print_function +import numpy as np +import os +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' +import tensorflow as tf +from scipy import linalg +import pathlib +import urllib +from tqdm import tqdm +import warnings + + +def check_or_download_inception(inception_path): + ''' Checks if the path to the inception file is valid, or downloads + the file if it is not present. ''' + INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' + if inception_path is None: + inception_path = '/tmp' + inception_path = pathlib.Path(inception_path) + model_file = inception_path / 'classify_image_graph_def.pb' + if not model_file.exists(): + print("Downloading Inception model") + from urllib import request + import tarfile + fn, _ = request.urlretrieve(INCEPTION_URL) + with tarfile.open(fn, mode='r') as f: + f.extract('classify_image_graph_def.pb', str(model_file.parent)) + return str(model_file) + + +def create_inception_graph(pth): + """Creates a graph from saved GraphDef file.""" + # Creates graph from saved graph_def.pb. 
+ with tf.io.gfile.GFile(pth, 'rb') as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + _ = tf.import_graph_def(graph_def, name='FID_Inception_Net') + + +def calculate_activation_statistics(images, + sess, + batch_size=200, + verbose=False): + """Calculation of the statistics used by the FID. + Params: + -- images : Numpy array of dimension (n_images, hi, wi, 3). The values + must lie between 0 and 255. + -- sess : current session + -- batch_size : the images numpy array is split into batches with batch size + batch_size. A reasonable batch size depends on the available hardware. + -- verbose : If set to True and parameter out_step is given, the number of calculated + batches is reported. + Returns: + -- mu : The mean over samples of the activations of the pool_3 layer of + the incption model. + -- sigma : The covariance matrix of the activations of the pool_3 layer of + the incption model. + """ + act = get_activations(images, sess, batch_size, verbose) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +# code for handling inception net derived from +# https://github.com/openai/improved-gan/blob/master/inception_score/model.py +def _get_inception_layer(sess): + """Prepares inception net for batched usage and returns pool_3 layer. """ + layername = 'FID_Inception_Net/pool_3:0' + pool3 = sess.graph.get_tensor_by_name(layername) + ops = pool3.graph.get_operations() + for op_idx, op in enumerate(ops): + for o in op.outputs: + shape = o.get_shape() + if shape._dims != []: + shape = [s.value for s in shape] + new_shape = [] + for j, s in enumerate(shape): + if s == 1 and j == 0: + new_shape.append(None) + else: + new_shape.append(s) + o.__dict__['_shape_val'] = tf.TensorShape(new_shape) + return pool3 + + +#------------------------------------------------------------------------------- + + +def get_activations(images, sess, batch_size=200, verbose=False): + """Calculates the activations of the pool_3 layer for all images. + Params: + -- images : Numpy array of dimension (n_images, hi, wi, 3). The values + must lie between 0 and 256. + -- sess : current session + -- batch_size : the images numpy array is split into batches with batch size + batch_size. A reasonable batch size depends on the disposable hardware. + -- verbose : If set to True and parameter out_step is given, the number of calculated + batches is reported. + Returns: + -- A numpy array of dimension (num images, 2048) that contains the + activations of the given tensor when feeding inception with the query tensor. + """ + inception_layer = _get_inception_layer(sess) + n_images = images.shape[0] + if batch_size > n_images: + print( + "warning: batch size is bigger than the data size. 
setting batch size to data size" + ) + batch_size = n_images + n_batches = n_images // batch_size + pred_arr = np.empty((n_images, 2048)) + for i in tqdm(range(n_batches)): + if verbose: + print("\rPropagating batch %d/%d" % (i + 1, n_batches), + end="", + flush=True) + start = i * batch_size + + if start + batch_size < n_images: + end = start + batch_size + else: + end = n_images + + batch = images[start:end] + pred = sess.run(inception_layer, + {'FID_Inception_Net/ExpandDims:0': batch}) + pred_arr[start:end] = pred.reshape(batch_size, -1) + if verbose: + print(" done") + return pred_arr + + +#------------------------------------------------------------------------------- + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of the pool_3 layer of the + inception net ( like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations of the pool_3 layer, precalcualted + on an representive data set. + -- sigma1: The covariance matrix over activations of the pool_3 layer for + generated samples. + -- sigma2: The covariance matrix over activations of the pool_3 layer, + precalcualted on an representive data set. + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths" + assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + + # product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps + warnings.warn(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace( + sigma2) - 2 * tr_covmean + + +def compute_fid_from_npz(path): + print(path) + with np.load(path) as data: + fake_imgs = data['fake'] + + name = None + for name in ['imagenet', 'cifar', 'places']: + if name in path: + real_imgs = name + break + print('Inferred name', name) + if name is None: + real_imgs = data['real'] + + if fake_imgs.shape[0] < 1000: return 0 + + inception_path = check_or_download_inception(None) + create_inception_graph(inception_path) + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + m1, s1 = calculate_activation_statistics(fake_imgs, sess) + if isinstance(real_imgs, str): + print(f'using cached image stats for {real_imgs}') + with np.load(precomputed_stats[real_imgs]) as data: + m2, s2 = data['m'], data['s'] + else: + print('computing real images stats from scratch') + m2, s2 = calculate_activation_statistics(real_imgs, sess) + + return calculate_frechet_distance(m1, s1, 
m2, s2) + +precomputed_stats = { + 'places': + 'output/places_gt_stats.npz', + 'imagenet': + 'output/imagenet_gt_stats.npz', + 'cifar': + 'output/cifar_gt_stats.npz' +} + + +def compute_fid_from_imgs(fake_imgs, real_imgs): + inception_path = check_or_download_inception(None) + create_inception_graph(inception_path) + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + m1, s1 = calculate_activation_statistics(fake_imgs, sess) + if isinstance(real_imgs, str): + with np.load(precomputed_stats[real_imgs]) as data: + m2, s2 = data['m'], data['s'] + else: + m2, s2 = calculate_activation_statistics(real_imgs, sess) + return calculate_frechet_distance(m1, s1, m2, s2) + +def compute_stats(exp_path): + #TODO: a bit hacky + if 'places' in exp_path and not os.path.exists(precomputed_stats['places']): + with np.load('output/places_gt_imgs.npz') as data_real: + real_imgs = data_real['real'] + print('loaded real places images', real_imgs.shape) + inception_path = check_or_download_inception(None) + create_inception_graph(inception_path) + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + m, s = calculate_activation_statistics(real_imgs, sess) + np.savez(precomputed_stats['places'], m=m, s=s) + + if 'imagenet' in exp_path and not os.path.exists(precomputed_stats['imagenet']): + with np.load('output/imagenet_gt_imgs.npz') as data_real: + real_imgs = data_real['real'] + print('loaded real imagenet images', real_imgs.shape) + inception_path = check_or_download_inception(None) + create_inception_graph(inception_path) + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + m, s = calculate_activation_statistics(real_imgs, sess) + np.savez(precomputed_stats['imagenet'], m=m, s=s) + + if 'cifar' in exp_path and not os.path.exists(precomputed_stats['cifar']): + with np.load('output/cifar_gt_imgs.npz') as data_real: + real_imgs = data_real['real'] + print('loaded real cifar images', real_imgs.shape) + inception_path = check_or_download_inception(None) + create_inception_graph(inception_path) + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + m, s = calculate_activation_statistics(real_imgs, sess) + np.savez(precomputed_stats['cifar'], m=m, s=s) + +if __name__ == '__main__': + import argparse + import json + + parser = argparse.ArgumentParser('compute TF FID') + parser.add_argument('--samples', help='path to samples') + parser.add_argument('--it', type=str, help='path to samples') + parser.add_argument('--results_dir', help='path to results_dir') + args = parser.parse_args() + + it = args.it + results_dir = args.results_dir + + compute_stats(args.samples) + mean = compute_fid_from_npz(args.samples) + print(f'FID: {mean}') + + if args.results_dir is not None: + with open(os.path.join(args.results_dir, 'fid_results.json')) as f: + fid_results = json.load(f) + + fid_results[it] = mean + print(f'{results_dir} iteration {it} FID: {mean}') + + with open(os.path.join(args.results_dir, 'fid_results.json'), 'w') as f: + f.write(json.dumps(fid_results)) \ No newline at end of file diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/inception_score.py b/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/inception_score.py new file mode 100644 index 0000000..a2a593e --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/inception_score.py @@ -0,0 +1,66 @@ +import torch +from torch import nn +from torch.nn import functional as F +import torch.utils.data + +from 
torchvision.models.inception import inception_v3 + +import numpy as np +from scipy.stats import entropy + + +def inception_score(imgs, device=None, batch_size=32, resize=False, splits=1): + """Computes the inception score of the generated images imgs + + Args: + imgs: Torch dataset of (3xHxW) numpy images normalized in the + range [-1, 1] + cuda: whether or not to run on GPU + batch_size: batch size for feeding into Inception v3 + splits: number of splits + """ + N = len(imgs) + + assert batch_size > 0 + assert N > batch_size + + # Set up dataloader + dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size) + + # Load inception model + inception_model = inception_v3(pretrained=True, transform_input=False) + inception_model = inception_model.to(device) + inception_model.eval() + up = nn.Upsample(size=(299, 299), mode='bilinear').to(device) + + def get_pred(x): + with torch.no_grad(): + if resize: + x = up(x) + x = inception_model(x) + out = F.softmax(x, dim=-1) + out = out.cpu().numpy() + return out + + # Get predictions + preds = np.zeros((N, 1000)) + + for i, batch in enumerate(dataloader, 0): + batchv = batch.to(device) + batch_size_i = batch.size()[0] + + preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv) + + # Now compute the mean kl-div + split_scores = [] + + for k in range(splits): + part = preds[k * (N // splits):(k + 1) * (N // splits), :] + py = np.mean(part, axis=0) + scores = [] + for i in range(part.shape[0]): + pyx = part[i, :] + scores.append(entropy(pyx, py)) + split_scores.append(np.exp(np.mean(scores))) + + return np.mean(split_scores), np.std(split_scores) diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/tf_is/LICENSE b/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/tf_is/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/tf_is/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/tf_is/README.md b/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/tf_is/README.md new file mode 100644 index 0000000..07c276f --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/tf_is/README.md @@ -0,0 +1,23 @@ +Inception Score +===================================== + +A new Tensorflow implementation of the "Inception Score" (IS) for the evaluation of generative models, with a bug raised in [https://github.com/openai/improved-gan/issues/29](https://github.com/openai/improved-gan/issues/29) fixed. 
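+
+A minimal usage sketch (not from the original repository; it assumes `inception_score.py` is importable from the working directory and follows the "Usage" section below):
+
+```python
+import numpy as np
+from inception_score import get_inception_score
+
+# 100 random uint8 images in NCHW format with values in [0, 255]
+images = np.random.randint(0, 256, size=(100, 3, 64, 64), dtype=np.uint8)
+mean, std = get_inception_score(images, splits=10)
+print(f'IS: {mean:.2f} +/- {std:.2f}')
+```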
+ +## Major Dependency +- `tensorflow >= 1.14` + +## Features +- Fast, easy-to-use and memory-efficient, written in a way that is similar to the original implementation +- No prior knowledge about Tensorflow is necessary if your are using CPU or GPU +- Makes use of [TF-GAN](https://github.com/tensorflow/gan) +- Downloads InceptionV1 automatically +- Compatible with both Python 2 and Python 3 + +## Usage +- If you are working with GPU, use `inception_score.py`; if you are working with TPU, use `inception_score_tpu.py` and pass a Tensorflow Session and a [TPUStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/TPUStrategy) as additional arguments. +- Call `get_inception_score(images, splits=10)`, where `images` is a numpy array with values ranging from 0 to 255 and shape in the form `[N, 3, HEIGHT, WIDTH]` where `N`, `HEIGHT` and `WIDTH` can be arbitrary. `dtype` of the images is recommended to be `np.uint8` to save CPU memory. +- A smaller `BATCH_SIZE` reduces GPU/TPU memory usage, but at the cost of a slight slowdown. +- If you want to compute a general "Classifier Score" with probabilities `preds` from another classifier, call `preds2score(preds, splits=10)`. `preds` can be a numpy array of arbitrary shape `[N, num_classes]`. +## Links +- The Inception Score was proposed in the paper [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498) +- Code for the [Fréchet Inception Distance](https://github.com/tsc2017/Frechet-Inception-Distance) diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/tf_is/inception_score.py b/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/tf_is/inception_score.py new file mode 100644 index 0000000..9fac701 --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/metrics/tf_is/inception_score.py @@ -0,0 +1,116 @@ +''' +From https://github.com/tsc2017/Inception-Score +Code derived from https://github.com/openai/improved-gan/blob/master/inception_score/model.py and https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py + +Usage: + Call get_inception_score(images, splits=10) +Args: + images: A numpy array with values ranging from 0 to 255 and shape in the form [N, 3, HEIGHT, WIDTH] where N, HEIGHT and WIDTH can be arbitrary. A dtype of np.uint8 is recommended to save CPU memory. + splits: The number of splits of the images, default is 10. +Returns: + Mean and standard deviation of the Inception Score across the splits. +''' + +import os +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' +import tensorflow as tf +import functools +import numpy as np +import time +from tqdm import tqdm +from tensorflow.python.ops import array_ops +tfgan = tf.contrib.gan + +session=tf.compat.v1.InteractiveSession() + +# A smaller BATCH_SIZE reduces GPU memory usage, but at the cost of a slight slowdown +BATCH_SIZE = 64 +INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz' +INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb' + +# Run images through Inception. 
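+# The placeholder below, the logits graph, and the InteractiveSession above are all
+# built once at module import time; get_inception_probs() then feeds batches of
+# BATCH_SIZE images through this single graph via session.run.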
+inception_images = tf.compat.v1.placeholder(tf.float32, [None, 3, None, None]) +def inception_logits(images = inception_images, num_splits = 1): + images = tf.transpose(images, [0, 2, 3, 1]) + size = 299 + images = tf.compat.v1.image.resize_bilinear(images, [size, size]) + generated_images_list = array_ops.split(images, num_or_size_splits = num_splits) + logits = tf.map_fn( + fn = functools.partial( + tfgan.eval.run_inception, + default_graph_def_fn = functools.partial( + tfgan.eval.get_graph_def_from_url_tarball, + INCEPTION_URL, + INCEPTION_FROZEN_GRAPH, + os.path.basename(INCEPTION_URL)), + output_tensor = 'logits:0'), + elems = array_ops.stack(generated_images_list), + parallel_iterations = 8, + back_prop = False, + swap_memory = True, + name = 'RunClassifier') + logits = array_ops.concat(array_ops.unstack(logits), 0) + return logits + +logits=inception_logits() + +def get_inception_probs(inps): + n_batches = int(np.ceil(float(inps.shape[0]) / BATCH_SIZE)) + preds = np.zeros([inps.shape[0], 1000], dtype = np.float32) + for i in tqdm(range(n_batches)): + inp = inps[i * BATCH_SIZE:(i + 1) * BATCH_SIZE] / 255. * 2 - 1 + preds[i * BATCH_SIZE : i * BATCH_SIZE + min(BATCH_SIZE, inp.shape[0])] = session.run(logits,{inception_images: inp})[:, :1000] + preds = np.exp(preds) / np.sum(np.exp(preds), 1, keepdims=True) + return preds + +def preds2score(preds, splits=10): + scores = [] + for i in range(splits): + part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] + kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) + kl = np.mean(np.sum(kl, 1)) + scores.append(np.exp(kl)) + return np.mean(scores), np.std(scores) + +def get_inception_score(images, splits=10): + assert(type(images) == np.ndarray) + assert(len(images.shape) == 4) + assert(images.shape[1] == 3) + assert(np.min(images[0]) >= 0 and np.max(images[0]) > 10), 'Image values should be in the range [0, 255]' + print('Calculating Inception Score with %i images in %i splits' % (images.shape[0], splits)) + start_time=time.time() + preds = get_inception_probs(images) + mean, std = preds2score(preds, splits) + print('Inception Score calculation time: %f s' % (time.time() - start_time)) + return mean, std # Reference values: 11.38 for 50000 CIFAR-10 training set images, or mean=11.31, std=0.10 if in 10 splits. 
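+
+# Note: compute_is_from_npz (below) expects an .npz archive containing a 'fake'
+# array of generated images in NHWC layout; it transposes them to NCHW before scoring.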
+ +def compute_is_from_npz(path): + with np.load(path) as data: + fake_imgs = data['fake'] + fake_imgs = fake_imgs.transpose(0, 3, 1, 2) + print(fake_imgs.shape) + return get_inception_score(fake_imgs) + + +if __name__ == '__main__': + import argparse + import json + + parser = argparse.ArgumentParser('compute TF IS') + parser.add_argument('--samples', help='path to samples') + parser.add_argument('--it', type=str, help='path to samples') + parser.add_argument('--results_dir', help='path to results_dir') + args = parser.parse_args() + + it = args.it + results_dir = args.results_dir + mean, std = compute_is_from_npz(args.samples) + + with open(os.path.join(args.results_dir, 'is_results.json')) as f: + is_results = json.load(f) + + is_results[it] = float(mean) + print(f'{results_dir} iteration {it} IS: {mean}') + + with open(os.path.join(args.results_dir, 'is_results.json'), 'w') as f: + f.write(json.dumps(is_results)) \ No newline at end of file diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/models/__init__.py b/pytorch_pretrained_gans/self_conditioned/gan_training/models/__init__.py new file mode 100644 index 0000000..b16a68e --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/models/__init__.py @@ -0,0 +1,13 @@ +from . import (dcgan_deep, dcgan_shallow, resnet2) + +generator_dict = { + 'resnet2': resnet2.Generator, + 'dcgan_deep': dcgan_deep.Generator, + 'dcgan_shallow': dcgan_shallow.Generator +} + +discriminator_dict = { + 'resnet2': resnet2.Discriminator, + 'dcgan_deep': dcgan_deep.Discriminator, + 'dcgan_shallow': dcgan_shallow.Discriminator +} diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/models/blocks.py b/pytorch_pretrained_gans/self_conditioned/gan_training/models/blocks.py new file mode 100644 index 0000000..6464c7c --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/models/blocks.py @@ -0,0 +1,205 @@ +import torch +from torch import nn +from torch.autograd import Variable +from torch.nn import functional as F + + +class ResnetBlock(nn.Module): + def __init__(self, + fin, + fout, + bn, + nclasses, + fhidden=None, + is_bias=True): + super().__init__() + # Attributes + self.is_bias = is_bias + self.learned_shortcut = (fin != fout) + self.fin = fin + self.fout = fout + if fhidden is None: + self.fhidden = min(fin, fout) + else: + self.fhidden = fhidden + # Submodules + self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1) + self.conv_1 = nn.Conv2d(self.fhidden, + self.fout, + 3, + stride=1, + padding=1, + bias=is_bias) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(self.fin, + self.fout, + 1, + stride=1, + padding=0, + bias=False) + self.bn0 = bn(self.fin, nclasses) + self.bn1 = bn(self.fhidden, nclasses) + + def forward(self, x, y): + x_s = self._shortcut(x) + dx = self.conv_0(actvn(self.bn0(x, y))) + dx = self.conv_1(actvn(self.bn1(dx, y))) + out = x_s + 0.1 * dx + + return out + + def _shortcut(self, x): + if self.learned_shortcut: + x_s = self.conv_s(x) + else: + x_s = x + return x_s + + +def actvn(x): + out = F.leaky_relu(x, 2e-1) + return out + + +class LatentEmbeddingConcat(nn.Module): + ''' projects class embedding onto hypersphere and returns the concat of the latent and the class embedding ''' + + def __init__(self, nlabels, embed_dim): + super().__init__() + self.embedding = nn.Embedding(nlabels, embed_dim) + + def forward(self, z, y): + assert (y.size(0) == z.size(0)) + yembed = self.embedding(y) + yembed = yembed / torch.norm(yembed, p=2, dim=1, keepdim=True) + 
yz = torch.cat([z, yembed], dim=1) + return yz + + +class NormalizeLinear(nn.Module): + def __init__(self, act_dim, k_value): + super().__init__() + self.lin = nn.Linear(act_dim, k_value) + + def normalize(self): + self.lin.weight.data = F.normalize(self.lin.weight.data, p=2, dim=1) + + def forward(self, x): + self.normalize() + return self.lin(x) + + +class Identity(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, inp, *args, **kwargs): + return inp + + +class LinearConditionalMaskLogits(nn.Module): + ''' runs activated logits through fc and masks out the appropriate discriminator score according to class number''' + + def __init__(self, nc, nlabels): + super().__init__() + self.fc = nn.Linear(nc, nlabels) + + def forward(self, inp, y=None, take_best=False, get_features=False): + out = self.fc(inp) + if get_features: return out + + if not take_best: + y = y.view(-1) + index = Variable(torch.LongTensor(range(out.size(0)))) + if y.is_cuda: + index = index.cuda() + return out[index, y] + else: + # high activation means real, so take the highest activations + best_logits, _ = out.max(dim=1) + return best_logits + + +class ProjectionDiscriminatorLogits(nn.Module): + ''' takes in activated flattened logits before last linear layer and implements https://arxiv.org/pdf/1802.05637.pdf ''' + + def __init__(self, nc, nlabels): + super().__init__() + self.fc = nn.Linear(nc, 1) + self.embedding = nn.Embedding(nlabels, nc) + self.nlabels = nlabels + + def forward(self, x, y, take_best=False): + output = self.fc(x) + + if not take_best: + label_info = torch.sum(self.embedding(y) * x, dim=1, keepdim=True) + return (output + label_info).view(x.size(0)) + else: + #TODO: this may be computationally expensive, maybe we want to do the global pooling first to reduce x's size + index = torch.LongTensor(range(self.nlabels)).cuda() + labels = index.repeat((x.size(0), )) + x = x.repeat_interleave(self.nlabels, dim=0) + label_info = torch.sum(self.embedding(labels) * x, + dim=1, + keepdim=True).view(output.size(0), + self.nlabels) + # high activation means real, so take the highest activations + best_logits, _ = label_info.max(dim=1) + return output.view(output.size(0)) + best_logits + + +class LinearUnconditionalLogits(nn.Module): + ''' standard discriminator logit layer ''' + + def __init__(self, nc): + super().__init__() + self.fc = nn.Linear(nc, 1) + + def forward(self, inp, y, take_best=False): + assert (take_best == False) + + out = self.fc(inp) + return out.view(out.size(0)) + + +class Reshape(nn.Module): + def __init__(self, *shape): + super().__init__() + self.shape = shape + + def forward(self, x): + batch_size = x.shape[0] + return x.view(*((batch_size, ) + self.shape)) + + +class ConditionalBatchNorm2d(nn.Module): + ''' from https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775 ''' + + def __init__(self, num_features, num_classes): + super().__init__() + self.num_features = num_features + self.bn = nn.BatchNorm2d(num_features, affine=False) + self.embed = nn.Embedding(num_classes, num_features * 2) + self.embed.weight.data[:, :num_features].normal_( + 1, 0.02) # Initialize scale at N(1, 0.02) + self.embed.weight.data[:, num_features:].zero_( + ) # Initialize bias at 0 + + def forward(self, x, y): + out = self.bn(x) + gamma, beta = self.embed(y).chunk(2, 1) + out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view( + -1, self.num_features, 1, 1) + return out + + +class BatchNorm2d(nn.Module): + ''' identical to nn.BatchNorm2d but takes 
in y input that is ignored ''' + + def __init__(self, nc, nchannels, **kwargs): + super().__init__() + self.bn = nn.BatchNorm2d(nc) + + def forward(self, x, y): + return self.bn(x) diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/models/dcgan_deep.py b/pytorch_pretrained_gans/self_conditioned/gan_training/models/dcgan_deep.py new file mode 100644 index 0000000..017ec0a --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/models/dcgan_deep.py @@ -0,0 +1,139 @@ +import torch +from torch import nn +from torch.nn import functional as F +import torch.utils.data +import torch.utils.data.distributed +from . import blocks + + +class Generator(nn.Module): + def __init__(self, + nlabels, + conditioning, + z_dim=128, + nc=3, + ngf=64, + embed_dim=256, + **kwargs): + super(Generator, self).__init__() + + assert conditioning != 'unconditional' or nlabels == 1 + + if conditioning == 'embedding': + self.get_latent = blocks.LatentEmbeddingConcat(nlabels, embed_dim) + self.fc = nn.Linear(z_dim + embed_dim, 4 * 4 * ngf * 8) + elif conditioning == 'unconditional': + self.get_latent = blocks.Identity() + self.fc = nn.Linear(z_dim, 4 * 4 * ngf * 8) + else: + raise NotImplementedError( + f"{conditioning} not implemented for generator") + + bn = blocks.BatchNorm2d + + self.nlabels = nlabels + + self.conv1 = nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1) + self.bn1 = bn(ngf * 4, nlabels) + + self.conv2 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1) + self.bn2 = bn(ngf * 2, nlabels) + + self.conv3 = nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1) + self.bn3 = bn(ngf, nlabels) + + self.conv_out = nn.Sequential(nn.Conv2d(ngf, nc, 3, 1, 1), nn.Tanh()) + + def forward(self, input, y): + y = y.clamp(None, self.nlabels - 1) + out = self.get_latent(input, y) + + out = self.fc(out) + out = out.view(out.size(0), -1, 4, 4) + out = F.relu(self.bn1(self.conv1(out), y)) + out = F.relu(self.bn2(self.conv2(out), y)) + out = F.relu(self.bn3(self.conv3(out), y)) + return self.conv_out(out) + + +class Discriminator(nn.Module): + def __init__(self, + nlabels, + conditioning, + nc=3, + ndf=64, + pack_size=1, + features='penultimate', + **kwargs): + + super(Discriminator, self).__init__() + + assert conditioning != 'unconditional' or nlabels == 1 + + self.nlabels = nlabels + + self.conv1 = nn.Sequential(nn.Conv2d(nc * pack_size, ndf, 3, 1, 1), nn.LeakyReLU(0.1)) + self.conv2 = nn.Sequential(nn.Conv2d(ndf, ndf, 4, 2, 1), nn.LeakyReLU(0.1)) + self.conv3 = nn.Sequential(nn.Conv2d(ndf, ndf * 2, 3, 1, 1), nn.LeakyReLU(0.1)) + self.conv4 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 2, 4, 2, 1), nn.LeakyReLU(0.1)) + self.conv5 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 4, 3, 1, 1), nn.LeakyReLU(0.1)) + self.conv6 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 4, 4, 2, 1), nn.LeakyReLU(0.1)) + self.conv7 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 8, 3, 1, 1), nn.LeakyReLU(0.1)) + + if conditioning == 'mask': + self.fc_out = blocks.LinearConditionalMaskLogits( + ndf * 8 * 4 * 4, nlabels) + elif conditioning == 'unconditional': + self.fc_out = blocks.LinearUnconditionalLogits( + ndf * 8 * 4 * 4) + else: + raise NotImplementedError( + f"{conditioning} not implemented for discriminator") + + self.features = features + self.pack_size = pack_size + print(f'Getting features from {self.features}') + + def stack(self, x): + # pacgan + nc = self.pack_size + assert (x.size(0) % nc == 0) + if nc == 1: + return x + x_new = [] + for i in range(x.size(0) // nc): + imgs_to_stack = x[i * nc:(i + 1) * nc] + x_new.append(torch.cat([t for t in 
imgs_to_stack], dim=0)) + return torch.stack(x_new) + + def forward(self, input, y=None, get_features=False): + input = self.stack(input) + out = self.conv1(input) + out = self.conv2(out) + out = self.conv3(out) + out = self.conv4(out) + out = self.conv5(out) + out = self.conv6(out) + out = self.conv7(out) + + if get_features and self.features == "penultimate": + return out.view(out.size(0), -1) + if get_features and self.features == "summed": + return out.view(out.size(0), out.size(1), -1).sum(dim=2) + + out = out.view(out.size(0), -1) + y = y.clamp(None, self.nlabels - 1) + result = self.fc_out(out, y) + assert (len(result.shape) == 1) + return result + + +if __name__ == '__main__': + z = torch.zeros((1, 128)) + g = Generator() + x = torch.zeros((1, 3, 32, 32)) + d = Discriminator() + + g(z) + d(g(z)) + d(x) diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/models/dcgan_shallow.py b/pytorch_pretrained_gans/self_conditioned/gan_training/models/dcgan_shallow.py new file mode 100644 index 0000000..623ff58 --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/models/dcgan_shallow.py @@ -0,0 +1,134 @@ +import torch +from torch import nn +from torch.nn import functional as F +import torch.utils.data +import torch.utils.data.distributed +from . import blocks + + +class Generator(nn.Module): + def __init__(self, + nlabels, + conditioning, + z_dim=128, + nc=3, + ngf=64, + embed_dim=256, + **kwargs): + super(Generator, self).__init__() + + assert conditioning != 'unconditional' or nlabels == 1 + + if conditioning == 'embedding': + self.get_latent = blocks.LatentEmbeddingConcat(nlabels, embed_dim) + self.fc = nn.Linear(z_dim + embed_dim, 4 * 4 * ngf * 8) + elif conditioning == 'unconditional': + self.get_latent = blocks.Identity() + self.fc = nn.Linear(z_dim, 4 * 4 * ngf * 8) + else: + raise NotImplementedError( + f"{conditioning} not implemented for generator") + + bn = blocks.BatchNorm2d + + self.nlabels = nlabels + + self.conv1 = nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1) + self.bn1 = bn(ngf * 4, nlabels) + + self.conv2 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1) + self.bn2 = bn(ngf * 2, nlabels) + + self.conv3 = nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1) + self.bn3 = bn(ngf, nlabels) + + self.conv_out = nn.Sequential(nn.Conv2d(ngf, nc, 3, 1, 1), nn.Tanh()) + + def forward(self, input, y): + y = y.clamp(None, self.nlabels - 1) + + out = self.get_latent(input, y) + out = self.fc(out) + + out = out.view(out.size(0), -1, 4, 4) + out = F.relu(self.bn1(self.conv1(out), y)) + out = F.relu(self.bn2(self.conv2(out), y)) + out = F.relu(self.bn3(self.conv3(out), y)) + return self.conv_out(out) + + +class Discriminator(nn.Module): + def __init__(self, + nlabels, + conditioning, + features='penultimate', + pack_size=1, + nc=3, + ndf=64, + **kwargs): + super(Discriminator, self).__init__() + + assert conditioning != 'unconditional' or nlabels == 1 + + self.nlabels = nlabels + + self.conv1 = nn.Sequential(nn.Conv2d(nc * pack_size, ndf, 4, 2, 1), + nn.BatchNorm2d(ndf), + nn.LeakyReLU(0.2, inplace=True)) + self.conv2 = nn.Sequential(nn.Conv2d(ndf, ndf * 2, 4, 2, 1), + nn.BatchNorm2d(ndf * 2), + nn.LeakyReLU(0.2, inplace=True)) + self.conv3 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1), + nn.BatchNorm2d(ndf * 4), + nn.LeakyReLU(0.2, inplace=True)) + self.conv4 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1), + nn.BatchNorm2d(ndf * 8), + nn.LeakyReLU(0.2, inplace=True)) + + if conditioning == 'mask': + self.fc_out = blocks.LinearConditionalMaskLogits(ndf * 8 
* 4, nlabels) + elif conditioning == 'unconditional': + self.fc_out = blocks.LinearUnconditionalLogits(ndf * 8 * 4) + else: + raise NotImplementedError( + f"{conditioning} not implemented for discriminator") + + self.pack_size = pack_size + self.features = features + print(f'Getting features from {self.features}') + + def stack(self, x): + # pacgan + nc = self.pack_size + if nc == 1: + return x + x_new = [] + for i in range(x.size(0) // nc): + imgs_to_stack = x[i * nc:(i + 1) * nc] + x_new.append(torch.cat([t for t in imgs_to_stack], dim=0)) + return torch.stack(x_new) + + def forward(self, input, y=None, get_features=False): + input = self.stack(input) + out = self.conv1(input) + out = self.conv2(out) + out = self.conv3(out) + out = self.conv4(out) + out = out.view(out.size(0), -1) + if get_features: + return out.view(out.size(0), -1) + y = y.clamp(None, self.nlabels - 1) + result = self.fc_out(out, y) + assert (len(result.shape) == 1) + return result + + +if __name__ == '__main__': + z = torch.zeros((1, 128)) + g = Generator() + x = torch.zeros((1, 3, 32, 32)) + d = Discriminator() + + g(z) + d(g(z)) + d(x) diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/models/resnet2.py b/pytorch_pretrained_gans/self_conditioned/gan_training/models/resnet2.py new file mode 100644 index 0000000..ac2f98c --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/models/resnet2.py @@ -0,0 +1,187 @@ +import torch +from torch import nn +from torch.nn import functional as F +from torch.autograd import Variable +import torch.utils.data +import torch.utils.data.distributed + +from . import blocks +from .blocks import ResnetBlock +from torch.nn.utils.spectral_norm import spectral_norm + + +class Generator(nn.Module): + def __init__(self, + z_dim, + nlabels, + size, + conditioning, + embed_size=256, + nfilter=64, + **kwargs): + super().__init__() + s0 = self.s0 = size // 32 + nf = self.nf = nfilter + self.nlabels = nlabels + self.z_dim = z_dim + + assert conditioning != 'unconditional' or nlabels == 1 + + if conditioning == 'embedding': + self.get_latent = blocks.LatentEmbeddingConcat(nlabels, embed_size) + self.fc = nn.Linear(z_dim + embed_size, 16 * nf * s0 * s0) + elif conditioning == 'unconditional': + self.get_latent = blocks.Identity() + self.fc = nn.Linear(z_dim, 16 * nf * s0 * s0) + else: + raise NotImplementedError( + f"{conditioning} not implemented for generator") + + # either use conditional batch norm, or use no batch norm + bn = blocks.Identity + + self.resnet_0_0 = ResnetBlock(16 * nf, 16 * nf, bn, nlabels) + self.resnet_0_1 = ResnetBlock(16 * nf, 16 * nf, bn, nlabels) + + self.resnet_1_0 = ResnetBlock(16 * nf, 16 * nf, bn, nlabels) + self.resnet_1_1 = ResnetBlock(16 * nf, 16 * nf, bn, nlabels) + + self.resnet_2_0 = ResnetBlock(16 * nf, 8 * nf, bn, nlabels) + self.resnet_2_1 = ResnetBlock(8 * nf, 8 * nf, bn, nlabels) + + self.resnet_3_0 = ResnetBlock(8 * nf, 4 * nf, bn, nlabels) + self.resnet_3_1 = ResnetBlock(4 * nf, 4 * nf, bn, nlabels) + + self.resnet_4_0 = ResnetBlock(4 * nf, 2 * nf, bn, nlabels) + self.resnet_4_1 = ResnetBlock(2 * nf, 2 * nf, bn, nlabels) + + self.resnet_5_0 = ResnetBlock(2 * nf, 1 * nf, bn, nlabels) + self.resnet_5_1 = ResnetBlock(1 * nf, 1 * nf, bn, nlabels) + + self.conv_img = nn.Conv2d(nf, 3, 3, padding=1) + + def forward(self, z, y): + y = y.clamp(None, self.nlabels - 1) + out = self.get_latent(z, y) + + out = self.fc(out) + + out = out.view(z.size(0), 16 * self.nf, self.s0, self.s0) + + out = self.resnet_0_0(out, y) + out = 
self.resnet_0_1(out, y) + + out = F.interpolate(out, scale_factor=2) + out = self.resnet_1_0(out, y) + out = self.resnet_1_1(out, y) + + out = F.interpolate(out, scale_factor=2) + out = self.resnet_2_0(out, y) + out = self.resnet_2_1(out, y) + + out = F.interpolate(out, scale_factor=2) + out = self.resnet_3_0(out, y) + out = self.resnet_3_1(out, y) + + out = F.interpolate(out, scale_factor=2) + out = self.resnet_4_0(out, y) + out = self.resnet_4_1(out, y) + + out = F.interpolate(out, scale_factor=2) + out = self.resnet_5_0(out, y) + out = self.resnet_5_1(out, y) + + out = self.conv_img(actvn(out)) + out = torch.tanh(out) + + return out + + +class Discriminator(nn.Module): + def __init__(self, + nlabels, + size, + conditioning, + nfilter=64, + features='penultimate', + **kwargs): + super().__init__() + s0 = self.s0 = size // 32 + nf = self.nf = nfilter + self.nlabels = nlabels + + assert conditioning != 'unconditional' or nlabels == 1 + bn = blocks.Identity + + self.conv_img = nn.Conv2d(3, 1 * nf, 3, padding=1) + + self.resnet_0_0 = ResnetBlock(1 * nf, 1 * nf, bn, nlabels) + self.resnet_0_1 = ResnetBlock(1 * nf, 2 * nf, bn, nlabels) + + self.resnet_1_0 = ResnetBlock(2 * nf, 2 * nf, bn, nlabels) + self.resnet_1_1 = ResnetBlock(2 * nf, 4 * nf, bn, nlabels) + + self.resnet_2_0 = ResnetBlock(4 * nf, 4 * nf, bn, nlabels) + self.resnet_2_1 = ResnetBlock(4 * nf, 8 * nf, bn, nlabels) + + self.resnet_3_0 = ResnetBlock(8 * nf, 8 * nf, bn, nlabels) + self.resnet_3_1 = ResnetBlock(8 * nf, 16 * nf, bn, nlabels) + + self.resnet_4_0 = ResnetBlock(16 * nf, 16 * nf, bn, nlabels) + self.resnet_4_1 = ResnetBlock(16 * nf, 16 * nf, bn, nlabels) + + self.resnet_5_0 = ResnetBlock(16 * nf, 16 * nf, bn, nlabels) + self.resnet_5_1 = ResnetBlock(16 * nf, 16 * nf, bn, nlabels) + + if conditioning == 'mask': + self.fc_out = blocks.LinearConditionalMaskLogits( + 16 * nf * s0 * s0, nlabels) + elif conditioning == 'unconditional': + self.fc_out = blocks.LinearUnconditionalLogits(16 * nf * s0 * s0) + else: + raise NotImplementedError( + f"{conditioning} not implemented for discriminator") + + self.features = features + + def forward(self, x, y=None, get_features=False): + batch_size = x.size(0) + if y is not None: + y = y.clamp(None, self.nlabels - 1) + + out = self.conv_img(x) + + out = self.resnet_0_0(out, y) + out = self.resnet_0_1(out, y) + out = F.avg_pool2d(out, 3, stride=2, padding=1) + out = self.resnet_1_0(out, y) + out = self.resnet_1_1(out, y) + out = F.avg_pool2d(out, 3, stride=2, padding=1) + out = self.resnet_2_0(out, y) + out = self.resnet_2_1(out, y) + out = F.avg_pool2d(out, 3, stride=2, padding=1) + out = self.resnet_3_0(out, y) + out = self.resnet_3_1(out, y) + out = F.avg_pool2d(out, 3, stride=2, padding=1) + out = self.resnet_4_0(out, y) + out = self.resnet_4_1(out, y) + out = F.avg_pool2d(out, 3, stride=2, padding=1) + out = self.resnet_5_0(out, y) + out = self.resnet_5_1(out, y) + out = actvn(out) + + if get_features and self.features == 'summed': + return out.view(out.size(0), out.size(1), -1).sum(dim=2) + + out = out.view(batch_size, 16 * self.nf * self.s0 * self.s0) + + if get_features: + return out.view(batch_size, -1) + result = self.fc_out(out, y) + assert (len(result.shape) == 1) + return result + + +def actvn(x): + out = F.leaky_relu(x, 2e-1) + return out diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/models/resnet2s.py b/pytorch_pretrained_gans/self_conditioned/gan_training/models/resnet2s.py new file mode 100644 index 0000000..4599b82 --- /dev/null +++ 
b/pytorch_pretrained_gans/self_conditioned/gan_training/models/resnet2s.py @@ -0,0 +1,186 @@ +import torch +from torch import nn +from torch.nn import functional as F +from torch.autograd import Variable +import torch.utils.data +import torch.utils.data.distributed +from collections import OrderedDict + + +class Reshape(nn.Module): + def __init__(self, *shape): + super().__init__() + self.shape = shape + + def forward(self, x): + batch_size = x.shape[0] + return x.view(*((batch_size, ) + self.shape)) + + +class Generator(nn.Module): + ''' + Perfectly equivalent to resnet2.Generator (can load state dicts + from that class), but organizes layers as a sequence for more + automatic inversion. + ''' + + def __init__(self, + z_dim, + nlabels, + size, + embed_size=256, + nfilter=64, + use_class_labels=False, + **kwargs): + super().__init__() + s0 = self.s0 = size // 32 + nf = self.nf = nfilter + self.z_dim = z_dim + self.use_class_labels = use_class_labels + # Submodules + if use_class_labels: + self.condition = ConditionGen(z_dim, nlabels, embed_size) + latent_dim = self.condition.latent_dim + else: + latent_dim = z_dim + + self.layers = nn.Sequential( + OrderedDict([('fc', nn.Linear(latent_dim, 16 * nf * s0 * s0)), + ('reshape', Reshape(16 * self.nf, self.s0, self.s0)), + ('resnet_0_0', ResnetBlock(16 * nf, 16 * nf)), + ('resnet_0_1', ResnetBlock(16 * nf, 16 * nf)), + ('upsample_1', nn.Upsample(scale_factor=2)), + ('resnet_1_0', ResnetBlock(16 * nf, 16 * nf)), + ('resnet_1_1', ResnetBlock(16 * nf, 16 * nf)), + ('upsample_2', nn.Upsample(scale_factor=2)), + ('resnet_2_0', ResnetBlock(16 * nf, 8 * nf)), + ('resnet_2_1', ResnetBlock(8 * nf, 8 * nf)), + ('upsample_3', nn.Upsample(scale_factor=2)), + ('resnet_3_0', ResnetBlock(8 * nf, 4 * nf)), + ('resnet_3_1', ResnetBlock(4 * nf, 4 * nf)), + ('upsample_4', nn.Upsample(scale_factor=2)), + ('resnet_4_0', ResnetBlock(4 * nf, 2 * nf)), + ('resnet_4_1', ResnetBlock(2 * nf, 2 * nf)), + ('upsample_5', nn.Upsample(scale_factor=2)), + ('resnet_5_0', ResnetBlock(2 * nf, 1 * nf)), + ('resnet_5_1', ResnetBlock(1 * nf, 1 * nf)), + ('img_relu', nn.LeakyReLU(2e-1)), + ('conv_img', nn.Conv2d(nf, 3, 3, padding=1)), + ('tanh', nn.Tanh())])) + + def forward(self, z, y=None): + assert (y is None or z.size(0) == y.size(0)) + assert (not self.use_class_labels or y is not None) + batch_size = z.size(0) + if self.use_class_labels: + z = self.condition(z, y) + return self.layers(z) + + def load_v2_state_dict(self, state_dict): + converted = {} + for k, v in state_dict.items(): + if 'module.' in k: k = k.split('module.')[1] + if k.startswith('embedding'): + k = 'condition.' + k + elif k == 'get_latent.embedding.weight': + k = 'condition.embedding.weight' + else: + k = 'layers.' + k + converted[k] = v + self.load_state_dict(converted) + + +class ConditionGen(nn.Module): + def __init__(self, z_dim, nlabels, embed_size=256): + super().__init__() + self.embedding = nn.Embedding(nlabels, embed_size) + self.latent_dim = z_dim + embed_size + self.z_dim = z_dim + self.nlabels = nlabels + self.embed_size = embed_size + + def forward(self, z, y): + assert (z.size(0) == y.size(0)) + batch_size = z.size(0) + if y.dtype is torch.int64: + yembed = self.embedding(y) + else: + yembed = y + yembed = yembed / torch.norm(yembed, p=2, dim=1, keepdim=True) + return torch.cat([z, yembed], dim=1) + + +def convert_from_resnet2_generator(gen): + nlabels, embed_size = 0, 0 + use_class_labels = False + if hasattr(gen, 'embedding'): + # new version does not have gen.use_class_labels.. 
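+        # (this branch handles generators that expose the class embedding directly
+        # as gen.embedding; the next branch handles the get_latent wrapper)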
+ nlabels = gen.embedding.num_embeddings + embed_size = gen.embedding.embedding_dim + use_class_labels = True + if hasattr(gen, 'get_latent'): + # new version does not have gen.use_class_labels.. + nlabels = gen.get_latent.embedding.num_embeddings + embed_size = gen.get_latent.embedding.embedding_dim + use_class_labels = True + size = gen.s0 * 32 + newgen = Generator(gen.z_dim, nlabels, size, embed_size, gen.nf, + use_class_labels) + newgen.load_v2_state_dict(gen.state_dict()) + return newgen + + +class ResnetBlock(nn.Module): + def __init__(self, fin, fout, fhidden=None, is_bias=True): + super().__init__() + # Attributes + self.is_bias = is_bias + self.learned_shortcut = (fin != fout) + self.fin = fin + self.fout = fout + if fhidden is None: + self.fhidden = min(fin, fout) + else: + self.fhidden = fhidden + + # Submodules + self.conv_0 = nn.Conv2d(self.fin, + self.fhidden, + kernel_size=3, + stride=1, + padding=1) + self.conv_1 = nn.Conv2d(self.fhidden, + self.fout, + kernel_size=3, + stride=1, + padding=1, + bias=is_bias) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(self.fin, + self.fout, + kernel_size=1, + stride=1, + padding=0, + bias=False) + + def forward(self, x): + x_s = self._shortcut(x) + dx = self.conv_0(actvn(x)) + dx = self.conv_1(actvn(dx)) + out = x_s + 0.1 * dx + + return out + + def _shortcut(self, x): + if self.learned_shortcut: + x_s = self.conv_s(x) + else: + x_s = x + return x_s + + +def actvn(x): + out = F.leaky_relu(x, 2e-1) + return out + + diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/models/resnet3.py b/pytorch_pretrained_gans/self_conditioned/gan_training/models/resnet3.py new file mode 100644 index 0000000..20ba133 --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/models/resnet3.py @@ -0,0 +1,161 @@ +import torch +from torch import nn +from torch.nn import functional as F +from torch.autograd import Variable +import torch.utils.data +import torch.utils.data.distributed +from collections import OrderedDict + +class Generator(nn.Module): + ''' + Perfectly equivalent to resnet2.Generator (can load state dicts + from that class), but organizes layers as a sequence for more + automatic inversion. 
+ ''' + def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, + use_class_labels=False, **kwargs): + super().__init__() + s0 = self.s0 = size // 32 + nf = self.nf = nfilter + self.z_dim = z_dim + self.use_class_labels = use_class_labels + + # Submodules + if use_class_labels: + self.condition = ConditionGen(z_dim, nlabels, embed_size) + latent_dim = self.condition.latent_dim + else: + latent_dim = z_dim + + self.layers = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(latent_dim, 16*nf*s0*s0)), + ('reshape', Reshape(16*self.nf, self.s0, self.s0)), + ('resnet_0_0', ResnetBlock(16*nf, 16*nf)), + ('resnet_0_1', ResnetBlock(16*nf, 16*nf)), + ('upsample_1', nn.Upsample(scale_factor=2)), + ('resnet_1_0', ResnetBlock(16*nf, 16*nf)), + ('resnet_1_1', ResnetBlock(16*nf, 16*nf)), + ('upsample_2', nn.Upsample(scale_factor=2)), + ('resnet_2_0', ResnetBlock(16*nf, 8*nf)), + ('resnet_2_1', ResnetBlock(8*nf, 8*nf)), + ('upsample_3', nn.Upsample(scale_factor=2)), + ('resnet_3_0', ResnetBlock(8*nf, 4*nf)), + ('resnet_3_1', ResnetBlock(4*nf, 4*nf)), + ('upsample_4', nn.Upsample(scale_factor=2)), + ('resnet_4_0', ResnetBlock(4*nf, 2*nf)), + ('resnet_4_1', ResnetBlock(2*nf, 2*nf)), + ('upsample_5', nn.Upsample(scale_factor=2)), + ('resnet_5_0', ResnetBlock(2*nf, 1*nf)), + ('resnet_5_1', ResnetBlock(1*nf, 1*nf)), + ('img_relu', nn.LeakyReLU(2e-1)), + ('conv_img', nn.Conv2d(nf, 3, 3, padding=1)), + ('tanh', nn.Tanh()) + ])) + + def forward(self, z, y=None): + assert(y is None or z.size(0) == y.size(0)) + assert(not self.use_class_labels or y is not None) + batch_size = z.size(0) + if self.use_class_labels: + z = self.condition(z, y) + return self.layers(z) + + def load_v2_state_dict(self, state_dict): + converted = {} + for k, v in state_dict.items(): + if k.startswith('embedding'): + k = 'condition.' + k + elif k == 'get_latent.embedding.weight': + k = 'condition.embedding.weight' + else: + k = 'layers.' + k + converted[k] = v + self.load_state_dict(converted) + +class Reshape(nn.Module): + def __init__(self, *shape): + super().__init__() + self.shape = shape + def forward(self, x): + batch_size = x.shape[0] + return x.view(*((batch_size,) + self.shape)) + +class ConditionGen(nn.Module): + def __init__(self, z_dim, nlabels, embed_size=256): + super().__init__() + self.embedding = nn.Embedding(nlabels, embed_size) + self.latent_dim = z_dim + embed_size + self.z_dim = z_dim + self.nlabels = nlabels + self.embed_size = embed_size + + def forward(self, z, y): + assert(z.size(0) == y.size(0)) + batch_size = z.size(0) + if y.dtype is torch.int64: + yembed = self.embedding(y) + else: + yembed = y + yembed = yembed / torch.norm(yembed, p=2, dim=1, keepdim=True) + return torch.cat([z, yembed], dim=1) + +def convert_from_resnet2_generator(gen): + nlabels, embed_size = 0, 0 + + if hasattr(gen, 'get_latent'): + # new version does not have gen.use_class_labels.. 
+ nlabels = gen.get_latent.embedding.num_embeddings + embed_size = gen.get_latent.embedding.embedding_dim + use_class_labels = True + elif gen.use_class_labels: + nlabels = gen.embedding.num_embeddings + embed_size = gen.embedding.embedding_dim + use_class_labels = True + + size = gen.s0 * 32 + newgen = Generator(gen.z_dim, nlabels, size, embed_size, gen.nf, use_class_labels) + newgen.load_v2_state_dict(gen.state_dict()) + return newgen + + +class ResnetBlock(nn.Module): + def __init__(self, fin, fout, fhidden=None, is_bias=True): + super().__init__() + # Attributes + self.is_bias = is_bias + self.learned_shortcut = (fin != fout) + self.fin = fin + self.fout = fout + if fhidden is None: + self.fhidden = min(fin, fout) + else: + self.fhidden = fhidden + + # Submodules + self.conv_0 = nn.Conv2d(self.fin, self.fhidden, + kernel_size=3, stride=1, padding=1) + self.conv_1 = nn.Conv2d(self.fhidden, self.fout, + kernel_size=3, stride=1, padding=1, bias=is_bias) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(self.fin, self.fout, + kernel_size=1, stride=1, padding=0, bias=False) + + def forward(self, x): + x_s = self._shortcut(x) + dx = self.conv_0(actvn(x)) + dx = self.conv_1(actvn(dx)) + out = x_s + 0.1*dx + + return out + + def _shortcut(self, x): + if self.learned_shortcut: + x_s = self.conv_s(x) + else: + x_s = x + return x_s + + +def actvn(x): + out = F.leaky_relu(x, 2e-1) + return out \ No newline at end of file diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/train.py b/pytorch_pretrained_gans/self_conditioned/gan_training/train.py new file mode 100644 index 0000000..c2f1a2d --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/train.py @@ -0,0 +1,152 @@ +# coding: utf-8 +import torch +from torch.nn import functional as F +import torch.utils.data +import torch.utils.data.distributed +from torch import autograd +import numpy as np + + +class Trainer(object): + def __init__(self, + generator, + discriminator, + g_optimizer, + d_optimizer, + gan_type, + reg_type, + reg_param): + + self.generator = generator + self.discriminator = discriminator + self.g_optimizer = g_optimizer + self.d_optimizer = d_optimizer + self.gan_type = gan_type + self.reg_type = reg_type + self.reg_param = reg_param + + print(f'D reg gamma: {self.reg_param}') + + def generator_trainstep(self, y, z): + assert (y.size(0) == z.size(0)) + toggle_grad(self.generator, True) + toggle_grad(self.discriminator, False) + + self.generator.train() + self.discriminator.train() + self.g_optimizer.zero_grad() + + x_fake = self.generator(z, y) + d_fake = self.discriminator(x_fake, y) + gloss = self.compute_loss(d_fake, 1) + gloss.backward() + + self.g_optimizer.step() + + return gloss.item() + + def discriminator_trainstep(self, x_real, y, z): + toggle_grad(self.generator, False) + toggle_grad(self.discriminator, True) + self.generator.train() + self.discriminator.train() + self.d_optimizer.zero_grad() + + # On real data + x_real.requires_grad_() + + d_real = self.discriminator(x_real, y) + dloss_real = self.compute_loss(d_real, 1) + + if self.reg_type == 'real' or self.reg_type == 'real_fake': + dloss_real.backward(retain_graph=True) + reg = self.reg_param * compute_grad2(d_real, x_real).mean() + reg.backward() + else: + dloss_real.backward() + + # On fake data + with torch.no_grad(): + x_fake = self.generator(z, y) + + x_fake.requires_grad_() + d_fake = self.discriminator(x_fake, y) + dloss_fake = self.compute_loss(d_fake, 0) + + if self.reg_type == 'fake' or self.reg_type == 'real_fake': + 
dloss_fake.backward(retain_graph=True) + reg = self.reg_param * compute_grad2(d_fake, x_fake).mean() + reg.backward() + else: + dloss_fake.backward() + + if self.reg_type == 'wgangp': + reg = self.reg_param * self.wgan_gp_reg(x_real, x_fake, y) + reg.backward() + elif self.reg_type == 'wgangp0': + reg = self.reg_param * self.wgan_gp_reg( + x_real, x_fake, y, center=0.) + reg.backward() + + self.d_optimizer.step() + + dloss = (dloss_real + dloss_fake) + if self.reg_type == 'none': + reg = torch.tensor(0.) + + return dloss.item(), reg.item() + + def compute_loss(self, d_out, target): + targets = d_out.new_full(size=d_out.size(), fill_value=target) + + if self.gan_type == 'standard': + loss = F.binary_cross_entropy_with_logits(d_out, targets) + elif self.gan_type == 'wgan': + loss = (2 * target - 1) * d_out.mean() + else: + raise NotImplementedError + + return loss + + def wgan_gp_reg(self, x_real, x_fake, y, center=1.): + batch_size = y.size(0) + eps = torch.rand(batch_size, device=y.device).view(batch_size, 1, 1, 1) + x_interp = (1 - eps) * x_real + eps * x_fake + x_interp = x_interp.detach() + x_interp.requires_grad_() + d_out = self.discriminator(x_interp, y) + + reg = (compute_grad2(d_out, x_interp).sqrt() - center).pow(2).mean() + + return reg + + +# Utility functions +def toggle_grad(model, requires_grad): + for p in model.parameters(): + p.requires_grad_(requires_grad) + + +def compute_grad2(d_out, x_in): + batch_size = x_in.size(0) + grad_dout = autograd.grad(outputs=d_out.sum(), + inputs=x_in, + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + grad_dout2 = grad_dout.pow(2) + assert (grad_dout2.size() == x_in.size()) + reg = grad_dout2.view(batch_size, -1).sum(1) + return reg + + +def update_average(model_tgt, model_src, beta): + toggle_grad(model_src, False) + toggle_grad(model_tgt, False) + + param_dict_src = dict(model_src.named_parameters()) + + for p_name, p_tgt in model_tgt.named_parameters(): + p_src = param_dict_src[p_name] + assert (p_src is not p_tgt) + p_tgt.copy_(beta * p_tgt + (1. - beta) * p_src) diff --git a/pytorch_pretrained_gans/self_conditioned/gan_training/utils.py b/pytorch_pretrained_gans/self_conditioned/gan_training/utils.py new file mode 100644 index 0000000..8009a92 --- /dev/null +++ b/pytorch_pretrained_gans/self_conditioned/gan_training/utils.py @@ -0,0 +1,52 @@ +import torch +import torch.utils.data +import torch.utils.data.distributed +import torchvision + +import os + + +def save_images(imgs, outfile, nrow=8): + imgs = imgs / 2 + 0.5 # unnormalize + torchvision.utils.save_image(imgs, outfile, nrow=nrow) + + +def get_nsamples(data_loader, N): + x = [] + y = [] + n = 0 + for x_next, y_next in data_loader: + x.append(x_next) + y.append(y_next) + n += x_next.size(0) + if n > N: + break + x = torch.cat(x, dim=0)[:N] + y = torch.cat(y, dim=0)[:N] + return x, y + + +def update_average(model_tgt, model_src, beta): + param_dict_src = dict(model_src.named_parameters()) + + for p_name, p_tgt in model_tgt.named_parameters(): + p_src = param_dict_src[p_name] + assert (p_src is not p_tgt) + p_tgt.copy_(beta * p_tgt + (1. 
- beta) * p_src) + + +def get_most_recent(d, ext): + if not os.path.exists(d): + print(f'Directory {d} does not exist') + return -1 + its = [] + for f in os.listdir(d): + try: + it = int(f.split(ext + "_")[1].split('.pt')[0]) + its.append(it) + except Exception as e: + pass + if len(its) == 0: + print('Found no files with extension \"%s\" under %s' % (ext, d)) + return -1 + return max(its) diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/__init__.py b/pytorch_pretrained_gans/stylegan2_ada_pytorch/__init__.py new file mode 100644 index 0000000..f130262 --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/__init__.py @@ -0,0 +1,86 @@ +import os +import sys +import torch +from torch.hub import urlparse, get_dir, download_url_to_file +import pickle + + +MODELS = { + 'ffhq': ('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl', None), + 'afhqwild': ('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqwild.pkl', None), +} + + +def download_url(url, download_dir=None, filename=None): + parts = urlparse(url) + if download_dir is None: + hub_dir = get_dir() + download_dir = os.path.join(hub_dir, 'checkpoints') + if filename is None: + filename = os.path.basename(parts.path) + cached_file = os.path.join(download_dir, filename) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + download_url_to_file(url, cached_file) + return cached_file + + +class GeneratorWrapper(torch.nn.Module): + """ A wrapper to put the GAN in a standard format. This wrapper takes + w as input, rather than (z, c) """ + + def __init__(self, G, num_classes=None): + super().__init__() + self.G = G # NOTE! This takes in w, rather than z + self.dim_z = G.synthesis.w_dim + self.conditional = (num_classes is not None) + self.num_classes = num_classes + + self.num_ws = G.synthesis.num_ws + self.truncation_psi = 0.5 + self.truncation_cutoff = 8 + + def forward(self, z): + if len(z.shape) == 2: # expand to 18 layers + z = z.unsqueeze(1).repeat(1, self.num_ws, 1) + return self.G.synthesis(z) + + def sample_latent(self, batch_size, device='cuda'): + z = torch.randn([batch_size, self.dim_z], device=device) + c = None if self.conditional else None # not implemented for conditional models + w = self.G.mapping(z, c, truncation_psi=self.truncation_psi, truncation_cutoff=self.truncation_cutoff) + return w + + +def add_utils_to_path(): + import sys + from pathlib import Path + util_path = str(Path(__file__).parent) + if util_path not in sys.path: + sys.path.append(util_path) + print(f'Added {util_path} to path') + + +def make_stylegan2(model_name='ffhq') -> torch.nn.Module: + """G takes as input an image in NCHW format with dtype float32, normalized + to the range [-1, +1]. 
Some models also take a conditioning class label, + which is passed as img = G(z, c)""" + add_utils_to_path() # we need dnnlib and torch_utils in the path + url, num_classes = MODELS[model_name] + cached_file = download_url(url) + assert cached_file.endswith('.pkl') + with open(cached_file, 'rb') as f: + G = pickle.load(f)['G_ema'] + G = GeneratorWrapper(G, num_classes=num_classes) + return G.eval() + + +if __name__ == '__main__': + # Testing + G = make_stylegan2().cuda() + print('Created G') + print(f'Params: {sum(p.numel() for p in G.parameters()):_}') + z = torch.randn([1, G.dim_z]).cuda() + print(f'z.shape: {z.shape}') + x = G(z) + print(f'x.shape: {x.shape}') diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/dnnlib/__init__.py b/pytorch_pretrained_gans/stylegan2_ada_pytorch/dnnlib/__init__.py new file mode 100755 index 0000000..2f08cf3 --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/dnnlib/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from .util import EasyDict, make_cache_dir_path diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/dnnlib/util.py b/pytorch_pretrained_gans/stylegan2_ada_pytorch/dnnlib/util.py new file mode 100755 index 0000000..7672533 --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/dnnlib/util.py @@ -0,0 +1,477 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +"""Miscellaneous utility classes and functions.""" + +import ctypes +import fnmatch +import importlib +import inspect +import numpy as np +import os +import shutil +import sys +import types +import io +import pickle +import re +import requests +import html +import hashlib +import glob +import tempfile +import urllib +import urllib.request +import uuid + +from distutils.util import strtobool +from typing import Any, List, Tuple, Union + + +# Util classes +# ------------------------------------------------------------------------------------------ + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +class Logger(object): + """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" + + def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): + self.file = None + + if file_name is not None: + self.file = open(file_name, file_mode) + + self.should_flush = should_flush + self.stdout = sys.stdout + self.stderr = sys.stderr + + sys.stdout = self + sys.stderr = self + + def __enter__(self) -> "Logger": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + + def write(self, text: Union[str, bytes]) -> None: + """Write text to stdout (and a file) and optionally flush.""" + if isinstance(text, bytes): + text = text.decode() + if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash + return + + if self.file is not None: + self.file.write(text) + + self.stdout.write(text) + + if self.should_flush: + self.flush() + + def flush(self) -> None: + """Flush written text to both stdout and a file, if open.""" + if self.file is not None: + self.file.flush() + + self.stdout.flush() + + def close(self) -> None: + """Flush, close possible files, and remove stdout/stderr mirroring.""" + self.flush() + + # if using multiple loggers, prevent closing in wrong order + if sys.stdout is self: + sys.stdout = self.stdout + if sys.stderr is self: + sys.stderr = self.stderr + + if self.file is not None: + self.file.close() + self.file = None + + +# Cache directories +# ------------------------------------------------------------------------------------------ + +_dnnlib_cache_dir = None + +def set_cache_dir(path: str) -> None: + global _dnnlib_cache_dir + _dnnlib_cache_dir = path + +def make_cache_dir_path(*paths: str) -> str: + if _dnnlib_cache_dir is not None: + return os.path.join(_dnnlib_cache_dir, *paths) + if 'DNNLIB_CACHE_DIR' in os.environ: + return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) + if 'HOME' in os.environ: + return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) + if 'USERPROFILE' in os.environ: + return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) + return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) + +# Small util functions +# ------------------------------------------------------------------------------------------ + + +def format_time(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + 
if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) + else: + return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) + + +def ask_yes_no(question: str) -> bool: + """Ask the user the question until the user inputs a valid answer.""" + while True: + try: + print("{0} [y/n]".format(question)) + return strtobool(input().lower()) + except ValueError: + pass + + +def tuple_product(t: Tuple) -> Any: + """Calculate the product of the tuple elements.""" + result = 1 + + for v in t: + result *= v + + return result + + +_str_to_ctype = { + "uint8": ctypes.c_ubyte, + "uint16": ctypes.c_uint16, + "uint32": ctypes.c_uint32, + "uint64": ctypes.c_uint64, + "int8": ctypes.c_byte, + "int16": ctypes.c_int16, + "int32": ctypes.c_int32, + "int64": ctypes.c_int64, + "float32": ctypes.c_float, + "float64": ctypes.c_double +} + + +def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: + """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" + type_str = None + + if isinstance(type_obj, str): + type_str = type_obj + elif hasattr(type_obj, "__name__"): + type_str = type_obj.__name__ + elif hasattr(type_obj, "name"): + type_str = type_obj.name + else: + raise RuntimeError("Cannot infer type name from input") + + assert type_str in _str_to_ctype.keys() + + my_dtype = np.dtype(type_str) + my_ctype = _str_to_ctype[type_str] + + assert my_dtype.itemsize == ctypes.sizeof(my_ctype) + + return my_dtype, my_ctype + + +def is_pickleable(obj: Any) -> bool: + try: + with io.BytesIO() as stream: + pickle.dump(obj, stream) + return True + except: + return False + + +# Functionality to import modules/objects by name, and call functions by name +# ------------------------------------------------------------------------------------------ + +def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: + """Searches for the underlying module behind the name to some python object. + Returns the module and the object name (original name with module part removed).""" + + # allow convenience shorthands, substitute them by full names + obj_name = re.sub("^np.", "numpy.", obj_name) + obj_name = re.sub("^tf.", "tensorflow.", obj_name) + + # list alternatives for (module_name, local_obj_name) + parts = obj_name.split(".") + name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] + + # try each alternative in turn + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + return module, local_obj_name + except: + pass + + # maybe some of the modules themselves contain errors? + for module_name, _local_obj_name in name_pairs: + try: + importlib.import_module(module_name) # may raise ImportError + except ImportError: + if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): + raise + + # maybe the requested attribute is missing? 
+ for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + except ImportError: + pass + + # we are out of luck, but we have no idea why + raise ImportError(obj_name) + + +def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: + """Traverses the object name and returns the last (rightmost) python object.""" + if obj_name == '': + return module + obj = module + for part in obj_name.split("."): + obj = getattr(obj, part) + return obj + + +def get_obj_by_name(name: str) -> Any: + """Finds the python object with the given name.""" + module, obj_name = get_module_from_obj_name(name) + return get_obj_from_module(module, obj_name) + + +def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: + """Finds the python object with the given name and calls it as a function.""" + assert func_name is not None + func_obj = get_obj_by_name(func_name) + assert callable(func_obj) + return func_obj(*args, **kwargs) + + +def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: + """Finds the python class with the given name and constructs it with the given arguments.""" + return call_func_by_name(*args, func_name=class_name, **kwargs) + + +def get_module_dir_by_obj_name(obj_name: str) -> str: + """Get the directory path of the module containing the given object name.""" + module, _ = get_module_from_obj_name(obj_name) + return os.path.dirname(inspect.getfile(module)) + + +def is_top_level_function(obj: Any) -> bool: + """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" + return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ + + +def get_top_level_function_name(obj: Any) -> str: + """Return the fully-qualified name of a top-level function.""" + assert is_top_level_function(obj) + module = obj.__module__ + if module == '__main__': + module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] + return module + "." + obj.__name__ + + +# File system helpers +# ------------------------------------------------------------------------------------------ + +def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: + """List all files recursively in a given directory while ignoring given file and directory names. + Returns list of tuples containing both absolute and relative paths.""" + assert os.path.isdir(dir_path) + base_name = os.path.basename(os.path.normpath(dir_path)) + + if ignores is None: + ignores = [] + + result = [] + + for root, dirs, files in os.walk(dir_path, topdown=True): + for ignore_ in ignores: + dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] + + # dirs need to be edited in-place + for d in dirs_to_remove: + dirs.remove(d) + + files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] + + absolute_paths = [os.path.join(root, f) for f in files] + relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] + + if add_base_to_relative: + relative_paths = [os.path.join(base_name, p) for p in relative_paths] + + assert len(absolute_paths) == len(relative_paths) + result += zip(absolute_paths, relative_paths) + + return result + + +def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: + """Takes in a list of tuples of (src, dst) paths and copies files. 
+ Will create all necessary directories.""" + for file in files: + target_dir_name = os.path.dirname(file[1]) + + # will create all intermediate-level directories + if not os.path.exists(target_dir_name): + os.makedirs(target_dir_name) + + shutil.copyfile(file[0], file[1]) + + +# URL helpers +# ------------------------------------------------------------------------------------------ + +def is_url(obj: Any, allow_file_urls: bool = False) -> bool: + """Determine whether the given object is a valid URL string.""" + if not isinstance(obj, str) or not "://" in obj: + return False + if allow_file_urls and obj.startswith('file://'): + return True + try: + res = requests.compat.urlparse(obj) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + except: + return False + return True + + +def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: + """Download the given URL and return a binary-mode file object to access the data.""" + assert num_attempts >= 1 + assert not (return_filename and (not cache)) + + # Doesn't look like an URL scheme so interpret it as a local filename. + if not re.match('^[a-z]+://', url): + return url if return_filename else open(url, "rb") + + # Handle file URLs. This code handles unusual file:// patterns that + # arise on Windows: + # + # file:///c:/foo.txt + # + # which would translate to a local '/c:/foo.txt' filename that's + # invalid. Drop the forward slash for such pathnames. + # + # If you touch this code path, you should test it on both Linux and + # Windows. + # + # Some internet resources suggest using urllib.request.url2pathname() but + # but that converts forward slashes to backslashes and this causes + # its own set of problems. + if url.startswith('file://'): + filename = urllib.parse.urlparse(url).path + if re.match(r'^/[a-zA-Z]:', filename): + filename = filename[1:] + return filename if return_filename else open(filename, "rb") + + assert is_url(url) + + # Lookup from cache. + if cache_dir is None: + cache_dir = make_cache_dir_path('downloads') + + url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() + if cache: + cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) + if len(cache_files) == 1: + filename = cache_files[0] + return filename if return_filename else open(filename, "rb") + + # Download. + url_name = None + url_data = None + with requests.Session() as session: + if verbose: + print("Downloading %s ..." 
% url, end="", flush=True) + for attempts_left in reversed(range(num_attempts)): + try: + with session.get(url) as res: + res.raise_for_status() + if len(res.content) == 0: + raise IOError("No data received") + + if len(res.content) < 8192: + content_str = res.content.decode("utf-8") + if "download_warning" in res.headers.get("Set-Cookie", ""): + links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] + if len(links) == 1: + url = requests.compat.urljoin(url, links[0]) + raise IOError("Google Drive virus checker nag") + if "Google Drive - Quota exceeded" in content_str: + raise IOError("Google Drive download quota exceeded -- please try again later") + + match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) + url_name = match[1] if match else url + url_data = res.content + if verbose: + print(" done") + break + except KeyboardInterrupt: + raise + except: + if not attempts_left: + if verbose: + print(" failed") + raise + if verbose: + print(".", end="", flush=True) + + # Save to cache. + if cache: + safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) + cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) + temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) + os.makedirs(cache_dir, exist_ok=True) + with open(temp_file, "wb") as f: + f.write(url_data) + os.replace(temp_file, cache_file) # atomic + if return_filename: + return cache_file + + # Return data as file object. + assert not return_filename + return io.BytesIO(url_data) diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/__init__.py b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/__init__.py new file mode 100755 index 0000000..ece0ea0 --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# empty diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/custom_ops.py b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/custom_ops.py new file mode 100755 index 0000000..4cc4e43 --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/custom_ops.py @@ -0,0 +1,126 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os +import glob +import torch +import torch.utils.cpp_extension +import importlib +import hashlib +import shutil +from pathlib import Path + +from torch.utils.file_baton import FileBaton + +#---------------------------------------------------------------------------- +# Global options. + +verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' + +#---------------------------------------------------------------------------- +# Internal helper funcs. 
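Before moving on to the compiled-op machinery below, a short sketch of how the dnnlib utilities above compose in practice. The import path assumes this repo's package layout, and the URL in the commented block is a placeholder rather than a real checkpoint:

    import torch
    from pytorch_pretrained_gans.stylegan2_ada_pytorch import dnnlib

    # EasyDict gives attribute-style access to ordinary dict entries.
    cfg = dnnlib.EasyDict(lr=0.002, betas=(0.0, 0.99))
    print(cfg.lr, cfg.betas)

    # construct_class_by_name() resolves a dotted name and forwards args/kwargs.
    params = [torch.nn.Parameter(torch.zeros(3))]
    opt = dnnlib.util.construct_class_by_name(
        params, class_name='torch.optim.Adam', lr=cfg.lr, betas=cfg.betas)
    print(type(opt).__name__)             # Adam

    # open_url() downloads once and caches under make_cache_dir_path('downloads'):
    #   import pickle
    #   with dnnlib.util.open_url('https://example.com/model.pkl') as f:  # placeholder URL
    #       data = pickle.load(f)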
+ +def _find_compiler_bindir(): + patterns = [ + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', + ] + for pattern in patterns: + matches = sorted(glob.glob(pattern)) + if len(matches): + return matches[-1] + return None + +#---------------------------------------------------------------------------- +# Main entry point for compiling and loading C++/CUDA plugins. + +_cached_plugins = dict() + +def get_plugin(module_name, sources, **build_kwargs): + assert verbosity in ['none', 'brief', 'full'] + + # Already cached? + if module_name in _cached_plugins: + return _cached_plugins[module_name] + + # Print status. + if verbosity == 'full': + print(f'Setting up PyTorch plugin "{module_name}"...') + elif verbosity == 'brief': + print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) + + try: # pylint: disable=too-many-nested-blocks + # Make sure we can find the necessary compiler binaries. + if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') + os.environ['PATH'] += ';' + compiler_bindir + + # Compile and load. + verbose_build = (verbosity == 'full') + + # Incremental build md5sum trickery. Copies all the input source files + # into a cached build directory under a combined md5 digest of the input + # source files. Copying is done only if the combined digest has changed. + # This keeps input file timestamps and filenames the same as in previous + # extension builds, allowing for fast incremental rebuilds. + # + # This optimization is done only in case all the source files reside in + # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR + # environment variable is set (we take this as a signal that the user + # actually cares about this.) + source_dirs_set = set(os.path.dirname(source) for source in sources) + if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): + all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) + + # Compute a combined hash digest for all source files in the same + # custom op directory (usually .cu, .cpp, .py and .h files). + hash_md5 = hashlib.md5() + for src in all_source_files: + with open(src, 'rb') as f: + hash_md5.update(f.read()) + build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) + + if not os.path.isdir(digest_build_dir): + os.makedirs(digest_build_dir, exist_ok=True) + baton = FileBaton(os.path.join(digest_build_dir, 'lock')) + if baton.try_acquire(): + try: + for src in all_source_files: + shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) + finally: + baton.release() + else: + # Someone else is copying source files under the digest dir, + # wait until done and continue. 
+ baton.wait() + digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] + torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, + verbose=verbose_build, sources=digest_sources, **build_kwargs) + else: + torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + module = importlib.import_module(module_name) + + except: + if verbosity == 'brief': + print('Failed!') + raise + + # Print status and add to cache. + if verbosity == 'full': + print(f'Done setting up PyTorch plugin "{module_name}".') + elif verbosity == 'brief': + print('Done.') + _cached_plugins[module_name] = module + return module + +#---------------------------------------------------------------------------- diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/misc.py b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/misc.py new file mode 100755 index 0000000..7829f4d --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/misc.py @@ -0,0 +1,262 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import re +import contextlib +import numpy as np +import torch +import warnings +import dnnlib + +#---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. + +_constant_cache = dict() + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + +#---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + +#---------------------------------------------------------------------------- +# Symbolic assert. 
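A quick sketch of the two helpers just defined: constant() memoizes constant tensors keyed on value, shape, dtype, device, and memory format, and nan_to_num() falls back to a clamp-based implementation on PyTorch versions older than 1.8. The import assumes the vendored directory is on sys.path, which is what the generator loader's add_utils_to_path() arranges:

    import torch
    from torch_utils import misc          # vendored torch_utils package

    a = misc.constant([1.0, 2.0, 4.0])
    b = misc.constant([1.0, 2.0, 4.0])
    print(a is b)                         # True: the cached tensor is reused

    x = torch.tensor([float('nan'), float('inf'), -float('inf'), 0.5])
    print(misc.nan_to_num(x))             # NaN -> 0, +/-Inf clamped to dtype limits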
+ +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +#---------------------------------------------------------------------------- +# Context manager to suppress known warnings in torch.jit.trace(). + +class suppress_tracer_warnings(warnings.catch_warnings): + def __enter__(self): + super().__enter__() + warnings.simplefilter('ignore', category=torch.jit.TracerWarning) + return self + +#---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). + +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + +#---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). + +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + decorator.__name__ = fn.__name__ + return decorator + +#---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +#---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. 
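InfiniteSampler never raises StopIteration, so the loop below decides how many steps to run; assert_shape() is shown alongside it since the two pair naturally in training code. Path setup is assumed as in the previous sketch:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch_utils import misc

    dataset = TensorDataset(torch.randn(100, 3, 8, 8))
    sampler = misc.InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
    loader = iter(DataLoader(dataset, batch_size=16, sampler=sampler))

    for step in range(5):                 # the loop, not the loader, decides when to stop
        (batch,) = next(loader)
        misc.assert_shape(batch, [16, 3, None, None])   # None allows any size for that dim
        print(step, tuple(batch.shape))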
+ +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + +#---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. + +@contextlib.contextmanager +def ddp_sync(module, sync): + assert isinstance(module, torch.nn.Module) + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield + +#---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. + +def check_ddp_consistency(module, ignore_regex=None): + assert isinstance(module, torch.nn.Module) + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + '.' + name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname + +#---------------------------------------------------------------------------- +# Print summary table of module hierarchy. + +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. + entries = [] + nesting = [0] + def pre_hook(_mod, _inputs): + nesting[0] += 1 + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} + + # Filter out redundant entries. + if skip_redundant: + entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + + # Construct table. 
+    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+    rows += [['---'] * len(rows[0])]
+    param_total = 0
+    buffer_total = 0
+    submodule_names = {mod: name for name, mod in module.named_modules()}
+    for e in entries:
+        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
+        param_size = sum(t.numel() for t in e.unique_params)
+        buffer_size = sum(t.numel() for t in e.unique_buffers)
+        output_shapes = [str(list(t.shape)) for t in e.outputs]
+        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+        rows += [[
+            name + (':0' if len(e.outputs) >= 2 else ''),
+            str(param_size) if param_size else '-',
+            str(buffer_size) if buffer_size else '-',
+            (output_shapes + ['-'])[0],
+            (output_dtypes + ['-'])[0],
+        ]]
+        for idx in range(1, len(e.outputs)):
+            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+        param_total += param_size
+        buffer_total += buffer_size
+    rows += [['---'] * len(rows[0])]
+    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+    # Print table.
+    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+    print()
+    for row in rows:
+        print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+    print()
+    return outputs
+
+#----------------------------------------------------------------------------
diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/__init__.py b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/__init__.py new file mode 100755 index 0000000..ece0ea0 --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/__init__.py @@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.cpp b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.cpp new file mode 100755 index 0000000..5d2425d --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.cpp @@ -0,0 +1,99 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
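Before the compiled bias_act op, a brief usage sketch for print_module_summary() above: it hooks every submodule, runs a single forward pass, and prints parameter counts, buffer counts, output shapes, and dtypes. A toy torch.nn module stands in for a generator, and path setup is assumed as before:

    import torch
    import torch.nn as nn
    from torch_utils import misc

    net = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, 4, kernel_size=3, padding=1),
    )
    x = torch.randn(2, 3, 16, 16)
    out = misc.print_module_summary(net, [x])   # inputs are passed as a list/tuple
    print(tuple(out.shape))                     # the forward output is returned unchanged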
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+    if (x.dim() != y.dim())
+        return false;
+    for (int64_t i = 0; i < x.dim(); i++)
+    {
+        if (x.size(i) != y.size(i))
+            return false;
+        if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+            return false;
+    }
+    return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+    // Validate arguments.
+    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+    TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+    TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+    TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+    TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+    TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+    TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+    TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+    TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+    // Validate layout.
+    TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+    TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+    TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+    TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+    TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+    // Create output tensor.
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+    torch::Tensor y = torch::empty_like(x);
+    TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+    // Initialize CUDA kernel parameters.
+    bias_act_kernel_params p;
+    p.x     = x.data_ptr();
+    p.b     = (b.numel()) ? b.data_ptr() : NULL;
+    p.xref  = (xref.numel()) ? xref.data_ptr() : NULL;
+    p.yref  = (yref.numel()) ? yref.data_ptr() : NULL;
+    p.dy    = (dy.numel()) ? dy.data_ptr() : NULL;
+    p.y     = y.data_ptr();
+    p.grad  = grad;
+    p.act   = act;
+    p.alpha = alpha;
+    p.gain  = gain;
+    p.clamp = clamp;
+    p.sizeX = (int)x.numel();
+    p.sizeB = (int)b.numel();
+    p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+    // Choose CUDA kernel.
+    void* kernel;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+    {
+        kernel = choose_bias_act_kernel<scalar_t>(p);
+    });
+    TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+    // Launch CUDA kernel.
+ p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("bias_act", &bias_act); +} + +//------------------------------------------------------------------------ diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.cu b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.cu new file mode 100755 index 0000000..dd8fc47 --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.cu @@ -0,0 +1,173 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +//------------------------------------------------------------------------ +// CUDA kernel. + +template +__global__ void bias_act_kernel(bias_act_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + int G = p.grad; + scalar_t alpha = (scalar_t)p.alpha; + scalar_t gain = (scalar_t)p.gain; + scalar_t clamp = (scalar_t)p.clamp; + scalar_t one = (scalar_t)1; + scalar_t two = (scalar_t)2; + scalar_t expRange = (scalar_t)80; + scalar_t halfExpRange = (scalar_t)40; + scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; + scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) + { + // Load. + scalar_t x = (scalar_t)((const T*)p.x)[xi]; + scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; + scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; + scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; + scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; + scalar_t yy = (gain != 0) ? yref / gain : 0; + scalar_t y = 0; + + // Apply bias. + ((G == 0) ? x : xref) += b; + + // linear + if (A == 1) + { + if (G == 0) y = x; + if (G == 1) y = x; + } + + // relu + if (A == 2) + { + if (G == 0) y = (x > 0) ? x : 0; + if (G == 1) y = (yy > 0) ? x : 0; + } + + // lrelu + if (A == 3) + { + if (G == 0) y = (x > 0) ? x : x * alpha; + if (G == 1) y = (yy > 0) ? x : x * alpha; + } + + // tanh + if (A == 4) + { + if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } + if (G == 1) y = x * (one - yy * yy); + if (G == 2) y = x * (one - yy * yy) * (-two * yy); + } + + // sigmoid + if (A == 5) + { + if (G == 0) y = (x < -expRange) ? 
0 : one / (exp(-x) + one); + if (G == 1) y = x * yy * (one - yy); + if (G == 2) y = x * yy * (one - yy) * (one - two * yy); + } + + // elu + if (A == 6) + { + if (G == 0) y = (x >= 0) ? x : exp(x) - one; + if (G == 1) y = (yy >= 0) ? x : x * (yy + one); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); + } + + // selu + if (A == 7) + { + if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); + if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); + } + + // softplus + if (A == 8) + { + if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); + if (G == 1) y = x * (one - exp(-yy)); + if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } + } + + // swish + if (A == 9) + { + if (G == 0) + y = (x < -expRange) ? 0 : x / (exp(-x) + one); + else + { + scalar_t c = exp(xref); + scalar_t d = c + one; + if (G == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); + yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; + } + } + + // Apply gain. + y *= gain * dy; + + // Clamp. + if (clamp >= 0) + { + if (G == 0) + y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; + else + y = (yref > -clamp & yref < clamp) ? y : 0; + } + + // Store. + ((T*)p.y)[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p) +{ + if (p.act == 1) return (void*)bias_act_kernel; + if (p.act == 2) return (void*)bias_act_kernel; + if (p.act == 3) return (void*)bias_act_kernel; + if (p.act == 4) return (void*)bias_act_kernel; + if (p.act == 5) return (void*)bias_act_kernel; + if (p.act == 6) return (void*)bias_act_kernel; + if (p.act == 7) return (void*)bias_act_kernel; + if (p.act == 8) return (void*)bias_act_kernel; + if (p.act == 9) return (void*)bias_act_kernel; + return NULL; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.h b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.h new file mode 100755 index 0000000..a32187e --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.h @@ -0,0 +1,38 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// CUDA kernel parameters. 
+ +struct bias_act_kernel_params +{ + const void* x; // [sizeX] + const void* b; // [sizeB] or NULL + const void* xref; // [sizeX] or NULL + const void* yref; // [sizeX] or NULL + const void* dy; // [sizeX] or NULL + void* y; // [sizeX] + + int grad; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.py b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.py new file mode 100755 index 0000000..4bcb409 --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.py @@ -0,0 +1,212 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom PyTorch ops for efficient bias and activation.""" + +import os +import warnings +import numpy as np +import torch +import dnnlib +import traceback + +from .. import custom_ops +from .. import misc + +#---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), + 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), + 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), + 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), + 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), + 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), + 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), + 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), + 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), +} + +#---------------------------------------------------------------------------- + +_inited = False +_plugin = None +_null_tensor = torch.empty([0]) + +def _init(): + global _inited, _plugin + if not _inited: + _inited = True + sources = ['bias_act.cpp', 'bias_act.cu'] + sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] + try: + _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) + except: + warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. 
Details:\n\n' + traceback.format_exc()) + return _plugin is not None + +#---------------------------------------------------------------------------- + +def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): + r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can be of any shape. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `dim`. + dim: The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. + See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. + See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying 1. + clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) + return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Slow reference implementation of `bias_act()` using standard TensorFlow ops. + """ + assert isinstance(x, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if b is not None: + assert isinstance(b, torch.Tensor) and b.ndim == 1 + assert 0 <= dim < x.ndim + assert b.shape[0] == x.shape[dim] + x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + x = spec.func(x, alpha=alpha) + + # Scale by gain. + gain = float(gain) + if gain != 1: + x = x * gain + + # Clamp. + if clamp >= 0: + x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type + return x + +#---------------------------------------------------------------------------- + +_bias_act_cuda_cache = dict() + +def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Fast CUDA implementation of `bias_act()` using custom ops. + """ + # Parse arguments. 
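A usage sketch for the public bias_act() API documented above. On a CPU tensor, or when the CUDA plugin fails to build, the call transparently drops to _bias_act_ref(), so leaving impl='cuda' at its default is safe; shapes below are arbitrary and path setup is assumed as before:

    import torch
    from torch_utils.ops import bias_act

    x = torch.randn(4, 512, 8, 8)
    b = torch.randn(512)

    # Fused bias + leaky ReLU + gain of 1, clamped to [-256, 256].
    y = bias_act.bias_act(x, b, dim=1, act='lrelu', gain=1.0, clamp=256)
    print(tuple(y.shape))                       # same shape as x

    # Reference path, useful for comparison in CPU-only environments.
    y_ref = bias_act._bias_act_ref(x=x, b=b, dim=1, act='lrelu', gain=1.0, clamp=256)
    print(torch.allclose(y, y_ref))             # True when both take the reference path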
+ assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Lookup from cache. + key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. + class BiasActCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: + y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. + class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format + dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. + _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + +#---------------------------------------------------------------------------- diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/conv2d_gradfix.py b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/conv2d_gradfix.py new file mode 100755 index 0000000..e95e10d --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/conv2d_gradfix.py @@ -0,0 +1,170 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. 
Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.conv2d` that supports +arbitrarily high order gradients with zero performance penalty.""" + +import warnings +import contextlib +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +#---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. +weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. + +@contextlib.contextmanager +def no_weight_gradients(): + global weight_gradients_disabled + old = weight_gradients_disabled + weight_gradients_disabled = True + yield + weight_gradients_disabled = old + +#---------------------------------------------------------------------------- + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(input): + assert isinstance(input, torch.Tensor) + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + if input.device.type != 'cuda': + return False + if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): + return True + warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') + return False + +def _tuple_of_ints(xs, ndim): + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim + assert len(xs) == ndim + assert all(isinstance(x, int) for x in xs) + return xs + +#---------------------------------------------------------------------------- + +_conv2d_gradfix_cache = dict() + +def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): + # Parse arguments. + ndim = 2 + weight_shape = tuple(weight_shape) + stride = _tuple_of_ints(stride, ndim) + padding = _tuple_of_ints(padding, ndim) + output_padding = _tuple_of_ints(output_padding, ndim) + dilation = _tuple_of_ints(dilation, ndim) + + # Lookup from cache. + key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) + if key in _conv2d_gradfix_cache: + return _conv2d_gradfix_cache[key] + + # Validate arguments. 
+ assert groups >= 1 + assert len(weight_shape) == ndim + 2 + assert all(stride[i] >= 1 for i in range(ndim)) + assert all(padding[i] >= 0 for i in range(ndim)) + assert all(dilation[i] >= 0 for i in range(ndim)) + if not transpose: + assert all(output_padding[i] == 0 for i in range(ndim)) + else: # transpose + assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) + + # Helpers. + common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + return [ + input_shape[i + 2] + - (output_shape[i + 2] - 1) * stride[i] + - (1 - 2 * padding[i]) + - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + # Forward & backward. + class Conv2d(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias): + assert weight.shape == weight_shape + if not transpose: + output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) + else: # transpose + output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) + ctx.save_for_backward(input, weight) + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + grad_input = None + grad_weight = None + grad_bias = None + + if ctx.needs_input_grad[0]: + p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) + grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) + assert grad_input.shape == input.shape + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output, input) + assert grad_weight.shape == weight_shape + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + + return grad_input, grad_weight, grad_bias + + # Gradient with respect to the weights. 
+ class Conv2dGradWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input): + op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') + flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] + grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) + assert grad_weight.shape == weight_shape + ctx.save_for_backward(grad_output, input) + return grad_weight + + @staticmethod + def backward(ctx, grad2_grad_weight): + grad_output, input = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) + assert grad2_grad_output.shape == grad_output.shape + + if ctx.needs_input_grad[1]: + p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) + grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) + assert grad2_input.shape == input.shape + + return grad2_grad_output, grad2_input + + _conv2d_gradfix_cache[key] = Conv2d + return Conv2d + +#---------------------------------------------------------------------------- diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/conv2d_resample.py b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/conv2d_resample.py new file mode 100755 index 0000000..cd47507 --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/conv2d_resample.py @@ -0,0 +1,156 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""2D convolution with optional up/downsampling.""" + +import torch + +from .. import misc +from . import conv2d_gradfix +from . import upfirdn2d +from .upfirdn2d import _parse_padding +from .upfirdn2d import _get_filter_size + +#---------------------------------------------------------------------------- + +def _get_weight_shape(w): + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + shape = [int(sz) for sz in w.shape] + misc.assert_shape(w, shape) + return shape + +#---------------------------------------------------------------------------- + +def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. + """ + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + + # Flip weight if requested. + if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + w = w.flip([2, 3]) + + # Workaround performance pitfall in cuDNN 8.0.5, triggered when using + # 1x1 kernel + memory_format=channels_last + less than 64 channels. 
+ if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: + if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: + if out_channels <= 4 and groups == 1: + in_shape = x.shape + x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) + x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) + else: + x = x.to(memory_format=torch.contiguous_format) + w = w.to(memory_format=torch.contiguous_format) + x = conv2d_gradfix.conv2d(x, w, groups=groups) + return x.to(memory_format=torch.channels_last) + + # Otherwise => execute using conv2d_gradfix. + op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d + return op(x, w, stride=stride, padding=padding, groups=groups) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): + r"""2D convolution with optional up/downsampling. + + Padding is performed only once at the beginning, not between the operations. + + Args: + x: Input tensor of shape + `[batch_size, in_channels, in_height, in_width]`. + w: Weight tensor of shape + `[out_channels, in_channels//groups, kernel_height, kernel_width]`. + f: Low-pass filter for up/downsampling. Must be prepared beforehand by + calling upfirdn2d.setup_filter(). None = identity (default). + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + groups: Split input channels into N groups (default: 1). + flip_weight: False = convolution, True = correlation (default: True). + flip_filter: False = convolution, True = correlation (default: False). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and (x.ndim == 4) + assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) + assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) + assert isinstance(up, int) and (up >= 1) + assert isinstance(down, int) and (down >= 1) + assert isinstance(groups, int) and (groups >= 1) + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + fw, fh = _get_filter_size(f) + px0, px1, py0, py1 = _parse_padding(padding) + + # Adjust padding to account for up/downsampling. + if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + + # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. + if kw == 1 and kh == 1 and (down > 1 and up == 1): + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 
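    # (Convolving before zero-insertion upsampling touches up**2 times fewer
    # pixels; the gain=up**2 passed to upfirdn2d below restores the signal
    # magnitude lost by inserting zeros, assuming a normalized low-pass filter.)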
+ if kw == 1 and kh == 1 and (up > 1 and down == 1): + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + return x + + # Fast path: downsampling only => use strided convolution. + if down > 1 and up == 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: upsampling with optional downsampling => use transpose strided convolution. + if up > 1: + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) + w = w.transpose(1, 2) + w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) + px0 -= kw - 1 + px1 -= kw - up + py0 -= kh - 1 + py1 -= kh - up + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + + # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. + if up == 1 and down == 1: + if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: + return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) + + # Fallback: Generic reference implementation. + x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + +#---------------------------------------------------------------------------- diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/fma.py b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/fma.py new file mode 100755 index 0000000..2eeac58 --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/fma.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
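As a quick reference for `conv2d_resample()` above, a minimal CPU sketch (assumptions: the package imports resolve, including the upstream `dnnlib` module used by `torch_utils.misc`, and the pure-PyTorch fallback paths are taken since no CUDA extension is built here):

    import torch
    from pytorch_pretrained_gans.stylegan2_ada_pytorch.torch_utils.ops import conv2d_resample, upfirdn2d

    x = torch.randn(1, 8, 16, 16)              # [batch, in_channels, H, W]
    w = torch.randn(16, 8, 3, 3)               # [out_channels, in_channels // groups, kh, kw]
    f = upfirdn2d.setup_filter([1, 3, 3, 1])   # low-pass filter, normalized to sum to 1
    y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
    print(y.shape)                             # expected: torch.Size([1, 16, 32, 32])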
+ +"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" + +import torch + +#---------------------------------------------------------------------------- + +def fma(a, b, c): # => a * b + c + return _FusedMultiplyAdd.apply(a, b, c) + +#---------------------------------------------------------------------------- + +class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c + @staticmethod + def forward(ctx, a, b, c): # pylint: disable=arguments-differ + out = torch.addcmul(c, a, b) + ctx.save_for_backward(a, b) + ctx.c_shape = c.shape + return out + + @staticmethod + def backward(ctx, dout): # pylint: disable=arguments-differ + a, b = ctx.saved_tensors + c_shape = ctx.c_shape + da = None + db = None + dc = None + + if ctx.needs_input_grad[0]: + da = _unbroadcast(dout * b, a.shape) + + if ctx.needs_input_grad[1]: + db = _unbroadcast(dout * a, b.shape) + + if ctx.needs_input_grad[2]: + dc = _unbroadcast(dout, c_shape) + + return da, db, dc + +#---------------------------------------------------------------------------- + +def _unbroadcast(x, shape): + extra_dims = x.ndim - len(shape) + assert extra_dims >= 0 + dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] + if len(dim): + x = x.sum(dim=dim, keepdim=True) + if extra_dims: + x = x.reshape(-1, *x.shape[extra_dims+1:]) + assert x.shape == shape + return x + +#---------------------------------------------------------------------------- diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/grid_sample_gradfix.py b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/grid_sample_gradfix.py new file mode 100755 index 0000000..ca6b341 --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/grid_sample_gradfix.py @@ -0,0 +1,83 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.grid_sample` that +supports arbitrarily high order gradients between the input and output. +Only works on 2D images and assumes +`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" + +import warnings +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +#---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. + +#---------------------------------------------------------------------------- + +def grid_sample(input, grid): + if _should_use_custom_op(): + return _GridSample2dForward.apply(input, grid) + return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(): + if not enabled: + return False + if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): + return True + warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. 
Falling back to torch.nn.functional.grid_sample().') + return False + +#---------------------------------------------------------------------------- + +class _GridSample2dForward(torch.autograd.Function): + @staticmethod + def forward(ctx, input, grid): + assert input.ndim == 4 + assert grid.ndim == 4 + output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + ctx.save_for_backward(input, grid) + return output + + @staticmethod + def backward(ctx, grad_output): + input, grid = ctx.saved_tensors + grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) + return grad_input, grad_grid + +#---------------------------------------------------------------------------- + +class _GridSample2dBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input, grid): + op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) + ctx.save_for_backward(grid) + return grad_input, grad_grid + + @staticmethod + def backward(ctx, grad2_grad_input, grad2_grad_grid): + _ = grad2_grad_grid # unused + grid, = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + grad2_grid = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) + + assert not ctx.needs_input_grad[2] + return grad2_grad_output, grad2_input, grad2_grid + +#---------------------------------------------------------------------------- diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.cpp b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.cpp new file mode 100755 index 0000000..2d7177f --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ + +static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) +{ + // Validate arguments. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); + TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(f.dim() == 2, "f must be rank 2"); + TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); + TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); + TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); + + // Create output tensor. 
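    // (outW/outH below follow the standard upfirdn output size:
    //  out = (in * up + pad0 + pad1 - filterSize) / down + 1 in integer arithmetic,
    //  with the TORCH_CHECK that follows guarding the degenerate cases.)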
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; + int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; + TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); + TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); + + // Initialize CUDA kernel parameters. + upfirdn2d_kernel_params p; + p.x = x.data_ptr(); + p.f = f.data_ptr(); + p.y = y.data_ptr(); + p.up = make_int2(upx, upy); + p.down = make_int2(downx, downy); + p.pad0 = make_int2(padx0, pady0); + p.flip = (flip) ? 1 : 0; + p.gain = gain; + p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); + p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); + p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); + p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); + p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; + p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; + + // Choose CUDA kernel. + upfirdn2d_kernel_spec spec; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + spec = choose_upfirdn2d_kernel(p); + }); + + // Set looping options. + p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; + p.loopMinor = spec.loopMinor; + p.loopX = spec.loopX; + p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; + p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; + + // Compute grid size. + dim3 blockSize, gridSize; + if (spec.tileOutW < 0) // large + { + blockSize = dim3(4, 32, 1); + gridSize = dim3( + ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, + (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, + p.launchMajor); + } + else // small + { + blockSize = dim3(256, 1, 1); + gridSize = dim3( + ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, + (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, + p.launchMajor); + } + + // Launch CUDA kernel. + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("upfirdn2d", &upfirdn2d); +} + +//------------------------------------------------------------------------ diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.cu b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.cu new file mode 100755 index 0000000..ebdd987 --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.cu @@ -0,0 +1,350 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +static __device__ __forceinline__ int floor_div(int a, int b) +{ + int t = 1 - a / b; + return (a + t * b) / b - t; +} + +//------------------------------------------------------------------------ +// Generic CUDA implementation for large filters. + +template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + + // Calculate thread index. + int minorBase = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorBase / p.launchMinor; + minorBase -= outY * p.launchMinor; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Setup Y receptive field. + int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; + int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); + int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; + int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; + if (p.flip) + filterY = p.filterSize.y - 1 - filterY; + + // Loop over major, minor, and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) + { + int nc = major * p.sizeMinor + minor; + int n = nc / p.inSize.z; + int c = nc - n * p.inSize.z; + for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) + { + // Setup X receptive field. + int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; + int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); + int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; + int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; + if (p.flip) + filterX = p.filterSize.x - 1 - filterX; + + // Initialize pointers. + const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; + int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; + int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; + + // Inner loop. + scalar_t v = 0; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + v += (scalar_t)(*xp) * (scalar_t)(*fp); + xp += p.inStride.x; + fp += filterStepX; + } + xp += p.inStride.y - w * p.inStride.x; + fp += filterStepY - w * filterStepX; + } + + // Store result. + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// Specialized CUDA implementation for small filters. 
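// (Strategy of the small-filter kernel below: each thread block first stages the
//  flipped filter and a tile of the input in shared memory, then every thread
//  accumulates output pixels for a tileOutW x tileOutH x loopMinor slice, so each
//  input element is read from global memory only once per tile.)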
+ +template +static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; + __shared__ volatile scalar_t sf[filterH][filterW]; + __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; + + // Calculate tile index. + int minorBase = blockIdx.x; + int tileOutY = minorBase / p.launchMinor; + minorBase -= tileOutY * p.launchMinor; + minorBase *= loopMinor; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Load filter (flipped). + for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) + { + int fy = tapIdx / filterW; + int fx = tapIdx - fy * filterW; + scalar_t v = 0; + if (fx < p.filterSize.x & fy < p.filterSize.y) + { + int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; + int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; + v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; + } + sf[fy][fx] = v; + } + + // Loop over major and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + { + int baseNC = major * p.sizeMinor + minorBase; + int n = baseNC / p.inSize.z; + int baseC = baseNC - n * p.inSize.z; + for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) + { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; + int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; + int tileInX = floor_div(tileMidX, upx); + int tileInY = floor_div(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) + { + int relC = inIdx; + int relInX = relC / loopMinor; + int relInY = relInX / tileInW; + relC -= relInX * loopMinor; + relInX -= relInY * tileInW; + int c = baseC + relC; + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) + v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + sx[relInY][relInX][relC] = v; + } + + // Loop over output pixels. + __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) + { + int relC = outIdx; + int relOutX = relC / loopMinor; + int relOutY = relOutX / tileOutW; + relC -= relOutX * loopMinor; + relOutX -= relOutY * tileOutW; + int c = baseC + relC; + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floor_div(midX, upx); + int inY = floor_div(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int filterX = (inX + 1) * upx - midX - 1; // flipped + int filterY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. 
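                // (Only filterW/upx * filterH/upy taps contribute per output pixel:
                //  after zero-insertion upsampling the remaining taps align with the
                //  inserted zeros, so the unrolled loops skip them by stepping through
                //  the shared filter in strides of upx and upy.)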
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) + { + scalar_t v = 0; + #pragma unroll + for (int y = 0; y < filterH / upy; y++) + #pragma unroll + for (int x = 0; x < filterW / upx; x++) + v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } + } + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) +{ + int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; + + upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous + if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last + + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous 
+ { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + } + if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + } + if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + } + if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 
8,8,8, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + } + return spec; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.h b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.h new file mode 100755 index 0000000..c9e2032 --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.h @@ -0,0 +1,59 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. 
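// (sizeMajor/sizeMinor are set in upfirdn2d.cpp above: for contiguous input the
//  major loop runs over batch*channels and sizeMinor is 1; for channels_last input
//  (inStride.z == 1) the major loop runs over batch only and the minor loop runs
//  over channels, which the small kernels process loopMinor channels at a time.)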
+ +struct upfirdn2d_kernel_params +{ + const void* x; + const float* f; + void* y; + + int2 up; + int2 down; + int2 pad0; + int flip; + float gain; + + int4 inSize; // [width, height, channel, batch] + int4 inStride; + int2 filterSize; // [width, height] + int2 filterStride; + int4 outSize; // [width, height, channel, batch] + int4 outStride; + int sizeMinor; + int sizeMajor; + + int loopMinor; + int loopMajor; + int loopX; + int launchMinor; + int launchMajor; +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. + +struct upfirdn2d_kernel_spec +{ + void* kernel; + int tileOutW; + int tileOutH; + int loopMinor; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.py b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.py new file mode 100755 index 0000000..ceeac2b --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.py @@ -0,0 +1,384 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom PyTorch ops for efficient resampling of 2D images.""" + +import os +import warnings +import numpy as np +import torch +import traceback + +from .. import custom_ops +from .. import misc +from . import conv2d_gradfix + +#---------------------------------------------------------------------------- + +_inited = False +_plugin = None + +def _init(): + global _inited, _plugin + if not _inited: + sources = ['upfirdn2d.cpp', 'upfirdn2d.cu'] + sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] + try: + _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) + except: + warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. 
Details:\n\n' + traceback.format_exc()) + return _plugin is not None + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + with misc.suppress_tracer_warnings(): + fw = int(fw) + fh = int(fh) + misc.assert_shape(f, [fh, fw][:f.ndim]) + assert fw >= 1 and fh >= 1 + return fw, fh + +#---------------------------------------------------------------------------- + +def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): + r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. + + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + +#---------------------------------------------------------------------------- + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. 
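    Per axis, the output size works out to
    `out = (in * up + pad_before + pad_after - filter_taps) // down + 1`.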
+ + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) + x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + +#---------------------------------------------------------------------------- + +_upfirdn2d_cuda_cache = dict() + +def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): + """Fast CUDA implementation of `upfirdn2d()` using custom ops. + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. 
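    # (Same pattern as conv2d_gradfix above: one autograd.Function subclass is
    # created and cached per unique (up, down, padding, flip, gain) combination,
    # and its backward re-enters _upfirdn2d_cuda with up/down swapped and the
    # padding adjusted, which is how gradients of arbitrary order are supported.)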
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + if key in _upfirdn2d_cuda_cache: + return _upfirdn2d_cuda_cache[key] + + # Forward op. + class Upfirdn2dCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + else: + y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) + y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. + _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda + return Upfirdn2dCuda + +#---------------------------------------------------------------------------- + +def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Filter a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape matches the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + fw // 2, + padx1 + (fw - 1) // 2, + pady0 + fh // 2, + pady1 + (fh - 1) // 2, + ] + return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. 
+ f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + upx, upy = _parse_scaling(up) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) + +#---------------------------------------------------------------------------- + +def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Downsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a fraction of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the input. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/persistence.py b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/persistence.py new file mode 100755 index 0000000..0186cfd --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/persistence.py @@ -0,0 +1,251 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
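Similarly, a small sketch for the resampling helpers `setup_filter()`, `upsample2d()`, and `downsample2d()` defined above (assumptions: `impl='ref'` keeps everything on the slow reference path so no CUDA build is required, and the package imports resolve):

    import torch
    from pytorch_pretrained_gans.stylegan2_ada_pytorch.torch_utils.ops import upfirdn2d

    x = torch.randn(2, 3, 32, 32)                             # [batch, channels, H, W]
    f = upfirdn2d.setup_filter([1, 3, 3, 1])                  # normalized 4-tap low-pass filter
    up = upfirdn2d.upsample2d(x, f, up=2, impl='ref')         # -> torch.Size([2, 3, 64, 64])
    down = upfirdn2d.downsample2d(up, f, down=2, impl='ref')  # -> torch.Size([2, 3, 32, 32])
    print(up.shape, down.shape)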
+ +"""Facilities for pickling Python code alongside other data. + +The pickled code is automatically imported into a separate Python module +during unpickling. This way, any previously exported pickles will remain +usable even if the original code is no longer available, or if the current +version of the code is not consistent with what was originally pickled.""" + +import sys +import pickle +import io +import inspect +import copy +import uuid +import types +import dnnlib + +#---------------------------------------------------------------------------- + +_version = 6 # internal version number +_decorators = set() # {decorator_class, ...} +_import_hooks = [] # [hook_function, ...] +_module_to_src_dict = dict() # {module: src, ...} +_src_to_module_dict = dict() # {src: module, ...} + +#---------------------------------------------------------------------------- + +def persistent_class(orig_class): + r"""Class decorator that extends a given class to save its source code + when pickled. + + Example: + + from torch_utils import persistence + + @persistence.persistent_class + class MyNetwork(torch.nn.Module): + def __init__(self, num_inputs, num_outputs): + super().__init__() + self.fc = MyLayer(num_inputs, num_outputs) + ... + + @persistence.persistent_class + class MyLayer(torch.nn.Module): + ... + + When pickled, any instance of `MyNetwork` and `MyLayer` will save its + source code alongside other internal state (e.g., parameters, buffers, + and submodules). This way, any previously exported pickle will remain + usable even if the class definitions have been modified or are no + longer available. + + The decorator saves the source code of the entire Python module + containing the decorated class. It does *not* save the source code of + any imported modules. Thus, the imported modules must be available + during unpickling, also including `torch_utils.persistence` itself. + + It is ok to call functions defined in the same module from the + decorated class. However, if the decorated class depends on other + classes defined in the same module, they must be decorated as well. + This is illustrated in the above example in the case of `MyLayer`. + + It is also possible to employ the decorator just-in-time before + calling the constructor. For example: + + cls = MyLayer + if want_to_make_it_persistent: + cls = persistence.persistent_class(cls) + layer = cls(num_inputs, num_outputs) + + As an additional feature, the decorator also keeps track of the + arguments that were used to construct each instance of the decorated + class. The arguments can be queried via `obj.init_args` and + `obj.init_kwargs`, and they are automatically pickled alongside other + object state. 
A typical use case is to first unpickle a previous + instance of a persistent class, and then upgrade it to use the latest + version of the source code: + + with open('old_pickle.pkl', 'rb') as f: + old_net = pickle.load(f) + new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) + misc.copy_params_and_buffers(old_net, new_net, require_all=True) + """ + assert isinstance(orig_class, type) + if is_persistent(orig_class): + return orig_class + + assert orig_class.__module__ in sys.modules + orig_module = sys.modules[orig_class.__module__] + orig_module_src = _module_to_src(orig_module) + + class Decorator(orig_class): + _orig_module_src = orig_module_src + _orig_class_name = orig_class.__name__ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._init_args = copy.deepcopy(args) + self._init_kwargs = copy.deepcopy(kwargs) + assert orig_class.__name__ in orig_module.__dict__ + _check_pickleable(self.__reduce__()) + + @property + def init_args(self): + return copy.deepcopy(self._init_args) + + @property + def init_kwargs(self): + return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) + + def __reduce__(self): + fields = list(super().__reduce__()) + fields += [None] * max(3 - len(fields), 0) + if fields[0] is not _reconstruct_persistent_obj: + meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) + fields[0] = _reconstruct_persistent_obj # reconstruct func + fields[1] = (meta,) # reconstruct args + fields[2] = None # state dict + return tuple(fields) + + Decorator.__name__ = orig_class.__name__ + _decorators.add(Decorator) + return Decorator + +#---------------------------------------------------------------------------- + +def is_persistent(obj): + r"""Test whether the given object or class is persistent, i.e., + whether it will save its source code when pickled. + """ + try: + if obj in _decorators: + return True + except TypeError: + pass + return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck + +#---------------------------------------------------------------------------- + +def import_hook(hook): + r"""Register an import hook that is called whenever a persistent object + is being unpickled. A typical use case is to patch the pickled source + code to avoid errors and inconsistencies when the API of some imported + module has changed. + + The hook should have the following signature: + + hook(meta) -> modified meta + + `meta` is an instance of `dnnlib.EasyDict` with the following fields: + + type: Type of the persistent object, e.g. `'class'`. + version: Internal version number of `torch_utils.persistence`. + module_src Original source code of the Python module. + class_name: Class name in the original Python module. + state: Internal state of the object. + + Example: + + @persistence.import_hook + def wreck_my_network(meta): + if meta.class_name == 'MyNetwork': + print('MyNetwork is being imported. I will wreck it!') + meta.module_src = meta.module_src.replace("True", "False") + return meta + """ + assert callable(hook) + _import_hooks.append(hook) + +#---------------------------------------------------------------------------- + +def _reconstruct_persistent_obj(meta): + r"""Hook that is called internally by the `pickle` module to unpickle + a persistent object. 
+ """ + meta = dnnlib.EasyDict(meta) + meta.state = dnnlib.EasyDict(meta.state) + for hook in _import_hooks: + meta = hook(meta) + assert meta is not None + + assert meta.version == _version + module = _src_to_module(meta.module_src) + + assert meta.type == 'class' + orig_class = module.__dict__[meta.class_name] + decorator_class = persistent_class(orig_class) + obj = decorator_class.__new__(decorator_class) + + setstate = getattr(obj, '__setstate__', None) + if callable(setstate): + setstate(meta.state) # pylint: disable=not-callable + else: + obj.__dict__.update(meta.state) + return obj + +#---------------------------------------------------------------------------- + +def _module_to_src(module): + r"""Query the source code of a given Python module. + """ + src = _module_to_src_dict.get(module, None) + if src is None: + src = inspect.getsource(module) + _module_to_src_dict[module] = src + _src_to_module_dict[src] = module + return src + +def _src_to_module(src): + r"""Get or create a Python module for the given source code. + """ + module = _src_to_module_dict.get(src, None) + if module is None: + module_name = "_imported_module_" + uuid.uuid4().hex + module = types.ModuleType(module_name) + sys.modules[module_name] = module + _module_to_src_dict[module] = src + _src_to_module_dict[src] = module + exec(src, module.__dict__) # pylint: disable=exec-used + return module + +#---------------------------------------------------------------------------- + +def _check_pickleable(obj): + r"""Check that the given object is pickleable, raising an exception if + it is not. This function is expected to be considerably more efficient + than actually pickling the object. + """ + def recurse(obj): + if isinstance(obj, (list, tuple, set)): + return [recurse(x) for x in obj] + if isinstance(obj, dict): + return [[recurse(x), recurse(y)] for x, y in obj.items()] + if isinstance(obj, (str, int, float, bool, bytes, bytearray)): + return None # Python primitive types are pickleable. + if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']: + return None # NumPy arrays and PyTorch tensors are pickleable. + if is_persistent(obj): + return None # Persistent objects are pickleable, by virtue of the constructor check. + return obj + with io.BytesIO() as f: + pickle.dump(recurse(obj), f) + +#---------------------------------------------------------------------------- diff --git a/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/training_stats.py b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/training_stats.py new file mode 100755 index 0000000..26f467f --- /dev/null +++ b/pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/training_stats.py @@ -0,0 +1,268 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Facilities for reporting and collecting training statistics across +multiple processes and devices. The interface is designed to minimize +synchronization overhead as well as the amount of boilerplate in user +code.""" + +import re +import numpy as np +import torch +import dnnlib + +from . 
import misc + +#---------------------------------------------------------------------------- + +_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] +_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. +_counter_dtype = torch.float64 # Data type to use for the internal counters. +_rank = 0 # Rank of the current process. +_sync_device = None # Device to use for multiprocess communication. None = single-process. +_sync_called = False # Has _sync() been called yet? +_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor +_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor + +#---------------------------------------------------------------------------- + +def init_multiprocessing(rank, sync_device): + r"""Initializes `torch_utils.training_stats` for collecting statistics + across multiple processes. + + This function must be called after + `torch.distributed.init_process_group()` and before `Collector.update()`. + The call is not necessary if multi-process collection is not needed. + + Args: + rank: Rank of the current process. + sync_device: PyTorch device to use for inter-process + communication, or None to disable multi-process + collection. Typically `torch.device('cuda', rank)`. + """ + global _rank, _sync_device + assert not _sync_called + _rank = rank + _sync_device = sync_device + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def report(name, value): + r"""Broadcasts the given set of scalars to all interested instances of + `Collector`, across device and process boundaries. + + This function is expected to be extremely cheap and can be safely + called from anywhere in the training loop, loss function, or inside a + `torch.nn.Module`. + + Warning: The current implementation expects the set of unique names to + be consistent across processes. Please make sure that `report()` is + called at least once for each unique name by each process, and in the + same order. If a given process has no scalars to broadcast, it can do + `report(name, [])` (empty list). + + Args: + name: Arbitrary string specifying the name of the statistic. + Averages are accumulated separately for each unique name. + value: Arbitrary set of scalars. Can be a list, tuple, + NumPy array, PyTorch tensor, or Python scalar. + + Returns: + The same `value` that was passed in. + """ + if name not in _counters: + _counters[name] = dict() + + elems = torch.as_tensor(value) + if elems.numel() == 0: + return value + + elems = elems.detach().flatten().to(_reduce_dtype) + moments = torch.stack([ + torch.ones_like(elems).sum(), + elems.sum(), + elems.square().sum(), + ]) + assert moments.ndim == 1 and moments.shape[0] == _num_moments + moments = moments.to(_counter_dtype) + + device = moments.device + if device not in _counters[name]: + _counters[name][device] = torch.zeros_like(moments) + _counters[name][device].add_(moments) + return value + +#---------------------------------------------------------------------------- + +def report0(name, value): + r"""Broadcasts the given set of scalars by the first process (`rank = 0`), + but ignores any scalars provided by the other processes. + See `report()` for further details. 
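+
+    A hypothetical usage sketch (the statistic names and variables are
+    illustrative, not part of the library):
+
+        training_stats.report('Loss/G/loss', gen_loss)       # every rank contributes
+        training_stats.report0('Timing/epoch_sec', elapsed)  # only rank 0 contributes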
+ """ + report(name, value if _rank == 0 else []) + return value + +#---------------------------------------------------------------------------- + +class Collector: + r"""Collects the scalars broadcasted by `report()` and `report0()` and + computes their long-term averages (mean and standard deviation) over + user-defined periods of time. + + The averages are first collected into internal counters that are not + directly visible to the user. They are then copied to the user-visible + state as a result of calling `update()` and can then be queried using + `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the + internal counters for the next round, so that the user-visible state + effectively reflects averages collected between the last two calls to + `update()`. + + Args: + regex: Regular expression defining which statistics to + collect. The default is to collect everything. + keep_previous: Whether to retain the previous averages if no + scalars were collected on a given round + (default: True). + """ + def __init__(self, regex='.*', keep_previous=True): + self._regex = re.compile(regex) + self._keep_previous = keep_previous + self._cumulative = dict() + self._moments = dict() + self.update() + self._moments.clear() + + def names(self): + r"""Returns the names of all statistics broadcasted so far that + match the regular expression specified at construction time. + """ + return [name for name in _counters if self._regex.fullmatch(name)] + + def update(self): + r"""Copies current values of the internal counters to the + user-visible state and resets them for the next round. + + If `keep_previous=True` was specified at construction time, the + operation is skipped for statistics that have received no scalars + since the last update, retaining their previous averages. + + This method performs a number of GPU-to-CPU transfers and one + `torch.distributed.all_reduce()`. It is intended to be called + periodically in the main training loop, typically once every + N training steps. + """ + if not self._keep_previous: + self._moments.clear() + for name, cumulative in _sync(self.names()): + if name not in self._cumulative: + self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + delta = cumulative - self._cumulative[name] + self._cumulative[name].copy_(cumulative) + if float(delta[0]) != 0: + self._moments[name] = delta + + def _get_delta(self, name): + r"""Returns the raw moments that were accumulated for the given + statistic between the last two calls to `update()`, or zero if + no scalars were collected. + """ + assert self._regex.fullmatch(name) + if name not in self._moments: + self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + return self._moments[name] + + def num(self, name): + r"""Returns the number of scalars that were accumulated for the given + statistic between the last two calls to `update()`, or zero if + no scalars were collected. + """ + delta = self._get_delta(name) + return int(delta[0]) + + def mean(self, name): + r"""Returns the mean of the scalars that were accumulated for the + given statistic between the last two calls to `update()`, or NaN if + no scalars were collected. + """ + delta = self._get_delta(name) + if int(delta[0]) == 0: + return float('nan') + return float(delta[1] / delta[0]) + + def std(self, name): + r"""Returns the standard deviation of the scalars that were + accumulated for the given statistic between the last two calls to + `update()`, or NaN if no scalars were collected. 
+ """ + delta = self._get_delta(name) + if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): + return float('nan') + if int(delta[0]) == 1: + return float(0) + mean = float(delta[1] / delta[0]) + raw_var = float(delta[2] / delta[0]) + return np.sqrt(max(raw_var - np.square(mean), 0)) + + def as_dict(self): + r"""Returns the averages accumulated between the last two calls to + `update()` as an `dnnlib.EasyDict`. The contents are as follows: + + dnnlib.EasyDict( + NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), + ... + ) + """ + stats = dnnlib.EasyDict() + for name in self.names(): + stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) + return stats + + def __getitem__(self, name): + r"""Convenience getter. + `collector[name]` is a synonym for `collector.mean(name)`. + """ + return self.mean(name) + +#---------------------------------------------------------------------------- + +def _sync(names): + r"""Synchronize the global cumulative counters across devices and + processes. Called internally by `Collector.update()`. + """ + if len(names) == 0: + return [] + global _sync_called + _sync_called = True + + # Collect deltas within current rank. + deltas = [] + device = _sync_device if _sync_device is not None else torch.device('cpu') + for name in names: + delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) + for counter in _counters[name].values(): + delta.add_(counter.to(device)) + counter.copy_(torch.zeros_like(counter)) + deltas.append(delta) + deltas = torch.stack(deltas) + + # Sum deltas across ranks. + if _sync_device is not None: + torch.distributed.all_reduce(deltas) + + # Update cumulative values. + deltas = deltas.cpu() + for idx, name in enumerate(names): + if name not in _cumulative: + _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + _cumulative[name].add_(deltas[idx]) + + # Return name-value pairs. + return [(name, _cumulative[name]) for name in names] + +#---------------------------------------------------------------------------- diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..7eb8f63 --- /dev/null +++ b/setup.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python + +from setuptools import setup, find_packages + +setup( + name='pytorch_pretrained_gans', + version='0.0.1', + description='Project', + author='Luke Melas-Kyriazi', + author_email='', + url='https://github.com/lukemelas/', + install_requires=[], + packages=find_packages(), +)