Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resume from the latest pickle #6

Closed
wants to merge 8 commits into from
30 changes: 24 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ def setup_training_options(

# Base config.
cfg = None, # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'cifarbaseline'
cifar_tuning = None # Enforce CIFAR-specific architecture tuning: <bool>, default = False
gamma = None, # Override R1 gamma: <float>, default = depends on cfg
kimg = None, # Override training duration: <int>, default = depends on cfg
cfg_map = None, # Override config map: <int>, default = depends on cfg

# Discriminator augmentation.
aug = None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed', 'adarv'
Expand Down Expand Up @@ -161,6 +163,7 @@ def setup_training_options(

cfg_specs = {
'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2), # populated dynamically based on 'gpus' and 'res'
'auto_no_ramp': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=None, map=2),
'stylegan2': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # uses mixed-precision, unlike original StyleGAN2
'paper256': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8),
'paper512': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8),
Expand All @@ -171,7 +174,7 @@ def setup_training_options(

assert cfg in cfg_specs
spec = dnnlib.EasyDict(cfg_specs[cfg])
if cfg == 'auto':
if cfg.startswith('auto'):
desc += f'{gpus:d}'
spec.ref_gpus = gpus
spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus) # keep gpu memory consumption at bay
Expand All @@ -195,7 +198,14 @@ def setup_training_options(
args.G_args.num_fp16_res = args.D_args.num_fp16_res = 4 # enable mixed-precision training
args.G_args.conv_clamp = args.D_args.conv_clamp = 256 # clamp activations to avoid float16 overflow

if cfg == 'cifar':
if cifar_tuning is None:
cifar_tuning = False
else:
assert isinstance(cifar_tuning, bool)
if cifar_tuning:
desc += '-tuning'

if cifar_tuning or cfg == 'cifar':
args.loss_args.pl_weight = 0 # disable path length regularization
args.G_args.style_mixing_prob = None # disable style mixing
args.D_args.architecture = 'orig' # disable residual skip connections
Expand All @@ -214,6 +224,12 @@ def setup_training_options(
desc += f'-kimg{kimg:d}'
args.total_kimg = kimg

if cfg_map is not None:
assert isinstance(cfg_map, int)
if not cfg_map >= 1:
raise UserError('--cfg_map must be at least 1')
args.G_args.mapping_layers = cfg_map

# ---------------------------------------------------
# Discriminator augmentation: aug, p, target, augpipe
# ---------------------------------------------------
Expand Down Expand Up @@ -246,9 +262,9 @@ def setup_training_options(

if p is not None:
assert isinstance(p, float)
if aug != 'fixed':
raise UserError('--p can only be specified with --aug=fixed')
if not 0 <= p <= 1:
if resume != 'latest' and aug != 'fixed':
raise UserError('--p can only be specified with --resume=latest or --aug=fixed')
if resume != 'latest' and not 0 <= p <= 1:
raise UserError('--p must be between 0 and 1')
desc += f'-p{p:g}'
args.augment_args.initial_strength = p
Expand Down Expand Up @@ -530,9 +546,11 @@ def main():
group.add_argument('--metricdata', help='Dataset to evaluate metrics against (optional)', metavar='PATH')

group = parser.add_argument_group('base config')
group.add_argument('--cfg', help='Base config (default: auto)', choices=['auto', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'cifarbaseline'])
group.add_argument('--cfg', help='Base config (default: auto)', choices=['auto', 'auto_no_ramp', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'cifarbaseline'])
group.add_argument('--cifar_tuning', help='Enforce CIFAR-specific architecture tuning (default: false)', type=_str_to_bool, metavar='BOOL')
group.add_argument('--gamma', help='Override R1 gamma', type=float, metavar='FLOAT')
group.add_argument('--kimg', help='Override training duration', type=int, metavar='INT')
group.add_argument('--cfg_map', help='Override config map', type=int, metavar='INT')

group = parser.add_argument_group('discriminator augmentation')
group.add_argument('--aug', help='Augmentation mode (default: ada)', choices=['noaug', 'ada', 'fixed', 'adarv'])
Expand Down
34 changes: 34 additions & 0 deletions training/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import glob
import os
import re

from pathlib import Path

def get_parent_dir(run_dir):
    """Return the directory that contains *run_dir*, as a pathlib.Path."""
    return Path(run_dir).parent

def locate_latest_pkl(out_dir):
    """Find the most recent network pickle under *out_dir*.

    Scans every run subdirectory whose name starts with '0' for files
    matching 'network-*.pkl' and returns the lexicographically last path
    (snapshot names embed a zero-padded kimg counter, so lexicographic
    order equals chronological order). Returns None when no pickle exists.
    """
    candidates = sorted(glob.glob(os.path.join(out_dir, '0*', 'network-*.pkl')))
    if not candidates:
        return None
    return candidates[-1]

def parse_kimg_from_network_name(network_pickle_name):
    """Extract the training progress (kimg) encoded in a snapshot filename.

    Args:
        network_pickle_name: Path to a pickle named like
            'network-snapshot-NNNNNN.pkl', or None.

    Returns:
        The snapshot's kimg counter as a float, or 0.0 when the argument
        is None or the filename does not match the snapshot pattern
        (e.g. 'network-final.pkl').
    """
    if network_pickle_name is None:
        return 0.0
    # Raw string with an escaped dot: the original pattern was a non-raw
    # literal ('\d' is an invalid escape on modern Python) and its bare '.'
    # matched any character instead of the literal dot before 'pkl'.
    match = re.match(r'network-snapshot-(\d+)\.pkl',
                     os.path.basename(network_pickle_name))
    if match is None:
        return 0.0
    return float(match.group(1))
20 changes: 15 additions & 5 deletions training/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dnnlib.tflib.autosummary import autosummary

from training import dataset
from training import misc

#----------------------------------------------------------------------------
# Select size and contents of the image snapshot grids that are exported
Expand Down Expand Up @@ -121,6 +122,15 @@ def training_loop(
G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
Gs = G.clone('Gs')

if resume_pkl == 'latest':
out_dir = misc.get_parent_dir(run_dir)
resume_pkl = misc.locate_latest_pkl(out_dir)

resume_kimg = misc.parse_kimg_from_network_name(resume_pkl)
if resume_kimg > 0:
print(f'Resuming from kimg = {resume_kimg}')

if resume_pkl is not None:
print(f'Resuming from "{resume_pkl}"')
with dnnlib.util.open_url(resume_pkl) as f:
Expand All @@ -133,10 +143,10 @@ def training_loop(

print('Exporting sample images...')
grid_size, grid_reals, grid_labels = setup_snapshot_image_grid(training_set)
save_image_grid(grid_reals, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)
save_image_grid(grid_reals, os.path.join(run_dir, 'reals.jpg'), drange=[0,255], grid_size=grid_size)
grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=minibatch_gpu)
save_image_grid(grid_fakes, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
save_image_grid(grid_fakes, os.path.join(run_dir, 'fakes_init.jpg'), drange=[-1,1], grid_size=grid_size)

print(f'Replicating networks across {num_gpus} GPUs...')
G_gpus = [G]
Expand Down Expand Up @@ -217,10 +227,10 @@ def training_loop(
print(f'Training for {total_kimg} kimg...')
print()
if progress_fn is not None:
progress_fn(0, total_kimg)
progress_fn(int(resume_kimg), total_kimg)
tick_start_time = time.time()
maintenance_time = tick_start_time - start_time
cur_nimg = 0
cur_nimg = int(resume_kimg * 1000)
cur_tick = -1
tick_start_nimg = cur_nimg
running_mb_counter = 0
Expand Down Expand Up @@ -301,7 +311,7 @@ def training_loop(
# Save snapshots.
if image_snapshot_ticks is not None and (done or cur_tick % image_snapshot_ticks == 0):
grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=minibatch_gpu)
save_image_grid(grid_fakes, os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}.png'), drange=[-1,1], grid_size=grid_size)
save_image_grid(grid_fakes, os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}.jpg'), drange=[-1,1], grid_size=grid_size)
if network_snapshot_ticks is not None and (done or cur_tick % network_snapshot_ticks == 0):
pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg // 1000:06d}.pkl')
with open(pkl, 'wb') as f:
Expand Down