Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert repo into Python package for pip install #177

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,24 @@
## StyleGAN2-ADA — `pip`-installable version of the official PyTorch implementation

I have modified the official PyTorch implementation so that you can `pip install` this repository as a dependency and import its classes and functions from the `stylegan2_ada_pytorch` package.

### Requirements

* Linux and Windows are supported, but we recommend Linux for performance and compatibility reasons.
* 1–8 high-end NVIDIA GPUs with at least 12 GB of memory. We have done all testing and development using NVIDIA DGX-1 with 8 Tesla V100 GPUs.
* 64-bit Python 3.7 and PyTorch 1.7.1. See [https://pytorch.org/](https://pytorch.org/) for PyTorch install instructions.
* CUDA toolkit 11.0 or later. Use at least version 11.1 if running on RTX 3090. (Why is a separate CUDA toolkit installation required? See comments in [#2](https://github.com/NVlabs/stylegan2-ada-pytorch/issues/2#issuecomment-779457121).)

### Installation

From the repository's root directory (`stylegan2-ada-pytorch`), run `python -m pip install .`.

### Original official implementation

The original implementation is available [here](https://github.com/NVlabs/stylegan2-ada-pytorch); its `README.md` is reproduced below.

***

## StyleGAN2-ADA — Official PyTorch implementation

![Teaser image](./docs/stylegan2-ada-teaser-1024x252.png)
Expand Down Expand Up @@ -151,7 +172,7 @@ w = G.mapping(z, c, truncation_psi=0.5, truncation_cutoff=8)
img = G.synthesis(w, noise_mode='const', force_fp32=True)
```

Please refer to [`generate.py`](./generate.py), [`style_mixing.py`](./style_mixing.py), and [`projector.py`](./projector.py) for further examples.
Please refer to [`generate.py`](stylegan2_ada_pytorch/generate.py), [`style_mixing.py`](stylegan2_ada_pytorch/style_mixing.py), and [`projector.py`](stylegan2_ada_pytorch/projector.py) for further examples.

## Preparing datasets

Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[build-system]
requires = [
"setuptools>=42",
"wheel"
]
build-backend = "setuptools.build_meta"
30 changes: 30 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
[metadata]
name = stylegan2-ada-pytorch
version = 1.0.0
description = StyleGAN2-ADA - Official PyTorch implementation
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.com/NVlabs/stylegan2-ada-pytorch
project_urls =
Bug Tracker = https://github.com/NVlabs/stylegan2-ada-pytorch/issues
classifiers =
Programming Language :: Python :: 3
License :: OSI Approved :: MIT License
Operating System :: OS Independent

[options]
package_dir =
= .
packages = find:
python_requires = >=3.6
install_requires =
torch >=1.7.0
click
requests
tqdm
pyspng
ninja
imageio-ffmpeg ==0.4.3

[options.packages.find]
where = .
Empty file.
14 changes: 6 additions & 8 deletions calc_metrics.py → stylegan2_ada_pytorch/calc_metrics.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@
import tempfile
import copy
import torch
import dnnlib

import legacy
from metrics import metric_main
from metrics import metric_utils
from torch_utils import training_stats
from torch_utils import custom_ops
from torch_utils import misc
from stylegan2_ada_pytorch import legacy, dnnlib
from stylegan2_ada_pytorch.metrics import metric_main, metric_utils
from stylegan2_ada_pytorch.torch_utils import training_stats
from stylegan2_ada_pytorch.torch_utils import custom_ops, misc


#----------------------------------------------------------------------------

Expand Down Expand Up @@ -61,7 +59,7 @@ def subprocess_fn(rank, args, temp_dir):
print(f'Calculating {metric}...')
progress = metric_utils.ProgressMonitor(verbose=args.verbose)
result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs,
num_gpus=args.num_gpus, rank=rank, device=device, progress=progress)
num_gpus=args.num_gpus, rank=rank, device=device, progress=progress)
if rank == 0:
metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
if rank == 0 and args.verbose:
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions generate.py → stylegan2_ada_pytorch/generate.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from typing import List, Optional

import click
import dnnlib
import numpy as np
import PIL.Image
import torch

import legacy
from stylegan2_ada_pytorch import legacy, dnnlib


#----------------------------------------------------------------------------

Expand Down
9 changes: 5 additions & 4 deletions legacy.py → stylegan2_ada_pytorch/legacy.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import copy
import numpy as np
import torch
import dnnlib
from torch_utils import misc
from stylegan2_ada_pytorch import dnnlib
from stylegan2_ada_pytorch.torch_utils import misc


#----------------------------------------------------------------------------

Expand Down Expand Up @@ -165,7 +166,7 @@ def kwarg(tf_name, default=None, none=None):
#for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

# Convert params.
from training import networks
from stylegan2_ada_pytorch.training import networks
G = networks.Generator(**kwargs).eval().requires_grad_(False)
# pylint: disable=unnecessary-lambda
_populate_module_params(G,
Expand Down Expand Up @@ -262,7 +263,7 @@ def kwarg(tf_name, default=None):
#for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

# Convert params.
from training import networks
from stylegan2_ada_pytorch.training import networks
D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
# pylint: disable=unnecessary-lambda
_populate_module_params(D,
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion metrics/metric_main.py → stylegan2_ada_pytorch/metrics/metric_main.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
import json
import torch
import dnnlib
from .. import dnnlib

from . import metric_utils
from . import frechet_inception_distance
Expand Down
5 changes: 3 additions & 2 deletions metrics/metric_utils.py → ...egan2_ada_pytorch/metrics/metric_utils.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import uuid
import numpy as np
import torch
import dnnlib
from stylegan2_ada_pytorch import dnnlib


#----------------------------------------------------------------------------

Expand Down Expand Up @@ -156,7 +157,7 @@ def update(self, cur_items):
total_time = cur_time - self.start_time
time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
if (self.verbose) and (self.tag is not None):
print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item * 1e3:.2f}')
self.batch_time = cur_time
self.batch_items = cur_items

Expand Down
2 changes: 1 addition & 1 deletion metrics/perceptual_path_length.py → ...pytorch/metrics/perceptual_path_length.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import copy
import numpy as np
import torch
import dnnlib
from .. import dnnlib
from . import metric_utils

#----------------------------------------------------------------------------
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions projector.py → stylegan2_ada_pytorch/projector.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import torch
import torch.nn.functional as F

import dnnlib
import legacy
from stylegan2_ada_pytorch import legacy, dnnlib


def project(
G,
Expand Down
4 changes: 2 additions & 2 deletions style_mixing.py → stylegan2_ada_pytorch/style_mixing.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from typing import List

import click
import dnnlib
import numpy as np
import PIL.Image
import torch

import legacy
from stylegan2_ada_pytorch import legacy, dnnlib


#----------------------------------------------------------------------------

Expand Down
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion torch_utils/misc.py → stylegan2_ada_pytorch/torch_utils/misc.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
import torch
import warnings
import dnnlib
from stylegan2_ada_pytorch import dnnlib

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
20 changes: 10 additions & 10 deletions torch_utils/ops/bias_act.py → ...2_ada_pytorch/torch_utils/ops/bias_act.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import warnings
import numpy as np
import torch
import dnnlib
from stylegan2_ada_pytorch import dnnlib
import traceback

from .. import custom_ops
Expand All @@ -21,15 +21,15 @@
#----------------------------------------------------------------------------

activation_funcs = {
'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
}

#----------------------------------------------------------------------------
Expand Down
File renamed without changes.
10 changes: 5 additions & 5 deletions torch_utils/ops/conv2d_resample.py → ...ytorch/torch_utils/ops/conv2d_resample.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,19 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight

# Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
if kw == 1 and kh == 1 and (down > 1 and up == 1):
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
return x

# Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
if kw == 1 and kh == 1 and (up > 1 and down == 1):
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter)
return x

# Fast path: downsampling only => use strided convolution.
if down > 1 and up == 1:
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
return x

Expand All @@ -136,7 +136,7 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
pxt = max(min(-px0, -px1), 0)
pyt = max(min(-py0, -py1), 0)
x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], gain=up ** 2, flip_filter=flip_filter)
if down > 1:
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
return x
Expand All @@ -147,7 +147,7 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)

# Fallback: Generic reference implementation.
x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter)
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
if down > 1:
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion torch_utils/persistence.py → ...n2_ada_pytorch/torch_utils/persistence.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import copy
import uuid
import types
import dnnlib
from stylegan2_ada_pytorch import dnnlib

#----------------------------------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion torch_utils/training_stats.py → ...ada_pytorch/torch_utils/training_stats.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import re
import numpy as np
import torch
import dnnlib
from .. import dnnlib

from . import misc

Expand Down
15 changes: 8 additions & 7 deletions train.py → stylegan2_ada_pytorch/train.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
import json
import tempfile
import torch
import dnnlib
from stylegan2_ada_pytorch import dnnlib

from stylegan2_ada_pytorch.training import training_loop
from stylegan2_ada_pytorch.metrics import metric_main
from stylegan2_ada_pytorch.torch_utils import training_stats
from stylegan2_ada_pytorch.torch_utils import custom_ops

from training import training_loop
from metrics import metric_main
from torch_utils import training_stats
from torch_utils import custom_ops

#----------------------------------------------------------------------------

Expand Down Expand Up @@ -182,8 +183,8 @@ def setup_training_loop_kwargs(
args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256 # clamp activations to avoid float16 overflow
args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0, 0.99], eps=1e-8)
args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0, 0.99], eps=1e-8)
args.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)

args.total_kimg = spec.kimg
Expand Down
File renamed without changes.
16 changes: 7 additions & 9 deletions training/augment.py → stylegan2_ada_pytorch/training/augment.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
import numpy as np
import scipy.signal
import torch
from torch_utils import persistence
from torch_utils import misc
from torch_utils.ops import upfirdn2d
from torch_utils.ops import grid_sample_gradfix
from torch_utils.ops import conv2d_gradfix
from stylegan2_ada_pytorch.torch_utils import persistence, misc
from stylegan2_ada_pytorch.torch_utils.ops import grid_sample_gradfix, upfirdn2d
from stylegan2_ada_pytorch.torch_utils.ops import conv2d_gradfix

#----------------------------------------------------------------------------
# Coefficients of various wavelet decomposition low-pass filters.
Expand Down Expand Up @@ -279,7 +277,7 @@ def forward(self, images, debug_percentile=None):
margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
margin = margin.max(misc.constant([0, 0] * 2, device=device))
margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
margin = margin.min(misc.constant([width - 1, height - 1] * 2, device=device))
mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)

# Pad image and adjust origin.
Expand All @@ -298,7 +296,7 @@ def forward(self, images, debug_percentile=None):
images = grid_sample_gradfix.grid_sample(images, grid)

# Downsample and crop.
images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)
images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad * 2, flip_filter=True)

# --------------------------------------------
# Select parameters for color transformations.
Expand Down Expand Up @@ -395,8 +393,8 @@ def forward(self, images, debug_percentile=None):
p = self.Hz_fbank.shape[1] // 2
images = images.reshape([1, batch_size * num_channels, height, width])
images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size * num_channels)
images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size * num_channels)
images = images.reshape([batch_size, num_channels, height, width])

# ------------------------
Expand Down
2 changes: 1 addition & 1 deletion training/dataset.py → stylegan2_ada_pytorch/training/dataset.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import PIL.Image
import json
import torch
import dnnlib
from stylegan2_ada_pytorch import dnnlib

try:
import pyspng
Expand Down
Loading