diff --git a/train.py b/train.py
index 5b36d792..6baa31d3 100755
--- a/train.py
+++ b/train.py
@@ -44,8 +44,10 @@ def setup_training_options(
 
     # Base config.
     cfg        = None, # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'cifarbaseline'
+    cifar_tuning = None, # Enforce CIFAR-specific architecture tuning: <bool>, default = False
     gamma      = None, # Override R1 gamma: <float>, default = depends on cfg
     kimg       = None, # Override training duration: <int>, default = depends on cfg
+    cfg_map    = None, # Override config map: <int>, default = depends on cfg
 
     # Discriminator augmentation.
     aug        = None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed', 'adarv'
@@ -161,6 +163,7 @@ def setup_training_options(
 
     cfg_specs = {
         'auto':         dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2), # populated dynamically based on 'gpus' and 'res'
+        'auto_no_ramp': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=None, map=2),
         'stylegan2':    dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # uses mixed-precision, unlike original StyleGAN2
         'paper256':     dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8),
         'paper512':     dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8),
@@ -171,7 +174,7 @@ def setup_training_options(
 
     assert cfg in cfg_specs
     spec = dnnlib.EasyDict(cfg_specs[cfg])
-    if cfg == 'auto':
+    if cfg.startswith('auto'):
         desc += f'{gpus:d}'
         spec.ref_gpus = gpus
         spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus) # keep gpu memory consumption at bay
@@ -195,7 +198,14 @@ def setup_training_options(
     args.G_args.num_fp16_res = args.D_args.num_fp16_res = 4 # enable mixed-precision training
     args.G_args.conv_clamp = args.D_args.conv_clamp = 256 # clamp activations to avoid float16 overflow
 
-    if cfg == 'cifar':
+    if cifar_tuning is None:
+        cifar_tuning = False
+    else:
+        assert isinstance(cifar_tuning, bool)
+    if cifar_tuning:
+        desc += '-tuning'
+
+    if cifar_tuning or cfg == 'cifar':
         args.loss_args.pl_weight = 0 # disable path length regularization
         args.G_args.style_mixing_prob = None # disable style mixing
         args.D_args.architecture = 'orig' # disable residual skip connections
@@ -214,6 +224,12 @@ def setup_training_options(
         desc += f'-kimg{kimg:d}'
         args.total_kimg = kimg
 
+    if cfg_map is not None:
+        assert isinstance(cfg_map, int)
+        if not cfg_map >= 1:
+            raise UserError('--cfg_map must be at least 1')
+        args.G_args.mapping_layers = cfg_map
+
     # ---------------------------------------------------
     # Discriminator augmentation: aug, p, target, augpipe
     # ---------------------------------------------------
@@ -246,9 +262,9 @@ def setup_training_options(
 
     if p is not None:
         assert isinstance(p, float)
-        if aug != 'fixed':
-            raise UserError('--p can only be specified with --aug=fixed')
-        if not 0 <= p <= 1:
+        if resume != 'latest' and aug != 'fixed':
+            raise UserError('--p can only be specified with --resume=latest or --aug=fixed')
+        if not 0 <= p <= 1:
             raise UserError('--p must be between 0 and 1')
         desc += f'-p{p:g}'
         args.augment_args.initial_strength = p
@@ -530,9 +546,11 @@ def main():
     group.add_argument('--metricdata', help='Dataset to evaluate metrics against (optional)', metavar='PATH')
 
     group = parser.add_argument_group('base config')
-    group.add_argument('--cfg', help='Base config (default: auto)', choices=['auto', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'cifarbaseline'])
+    group.add_argument('--cfg', help='Base config (default: auto)', choices=['auto', 'auto_no_ramp', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'cifarbaseline'])
+    group.add_argument('--cifar_tuning', help='Enforce CIFAR-specific architecture tuning (default: false)', type=_str_to_bool, metavar='BOOL')
     group.add_argument('--gamma', help='Override R1 gamma', type=float, metavar='FLOAT')
     group.add_argument('--kimg', help='Override training duration', type=int, metavar='INT')
+    group.add_argument('--cfg_map', help='Override config map', type=int, metavar='INT')
 
     group = parser.add_argument_group('discriminator augmentation')
     group.add_argument('--aug', help='Augmentation mode (default: ada)', choices=['noaug', 'ada', 'fixed', 'adarv'])
diff --git a/training/misc.py b/training/misc.py
new file mode 100644
index 00000000..46617236
--- /dev/null
+++ b/training/misc.py
@@ -0,0 +1,34 @@
+import glob
+import os
+import re
+
+from pathlib import Path
+
+def get_parent_dir(run_dir):
+    out_dir = Path(run_dir).parent
+
+    return out_dir
+
+def locate_latest_pkl(out_dir):
+    all_pickle_names = sorted(glob.glob(os.path.join(out_dir, '0*', 'network-*.pkl')))
+
+    try:
+        latest_pickle_name = all_pickle_names[-1]
+    except IndexError:
+        latest_pickle_name = None
+
+    return latest_pickle_name
+
+def parse_kimg_from_network_name(network_pickle_name):
+
+    if network_pickle_name is not None:
+        resume_run_id = os.path.basename(os.path.dirname(network_pickle_name))
+        RE_KIMG = re.compile(r'network-snapshot-(\d+)\.pkl')
+        try:
+            kimg = int(RE_KIMG.match(os.path.basename(network_pickle_name)).group(1))
+        except AttributeError:
+            kimg = 0.0
+    else:
+        kimg = 0.0
+
+    return float(kimg)
diff --git a/training/training_loop.py b/training/training_loop.py
index f70c11f8..d9f90308 100755
--- a/training/training_loop.py
+++ b/training/training_loop.py
@@ -19,6 +19,7 @@
 from dnnlib.tflib.autosummary import autosummary
 
 from training import dataset
+from training import misc
 
 #----------------------------------------------------------------------------
 # Select size and contents of the image snapshot grids that are exported
@@ -121,6 +122,15 @@ def training_loop(
         G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
         D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
         Gs = G.clone('Gs')
+
+        if resume_pkl == 'latest':
+            out_dir = misc.get_parent_dir(run_dir)
+            resume_pkl = misc.locate_latest_pkl(out_dir)
+
+        resume_kimg = misc.parse_kimg_from_network_name(resume_pkl)
+        if resume_kimg > 0:
+            print(f'Resuming from kimg = {resume_kimg}')
+
         if resume_pkl is not None:
             print(f'Resuming from "{resume_pkl}"')
             with dnnlib.util.open_url(resume_pkl) as f:
@@ -133,10 +143,10 @@ def training_loop(
 
     print('Exporting sample images...')
     grid_size, grid_reals, grid_labels = setup_snapshot_image_grid(training_set)
-    save_image_grid(grid_reals, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)
+    save_image_grid(grid_reals, os.path.join(run_dir, 'reals.jpg'), drange=[0,255], grid_size=grid_size)
     grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=minibatch_gpu)
-    save_image_grid(grid_fakes, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
+    save_image_grid(grid_fakes, os.path.join(run_dir, 'fakes_init.jpg'), drange=[-1,1], grid_size=grid_size)
 
     print(f'Replicating networks across {num_gpus} GPUs...')
     G_gpus = [G]
@@ -217,10 +227,10 @@ def training_loop(
     print(f'Training for {total_kimg} kimg...')
     print()
     if progress_fn is not None:
-        progress_fn(0, total_kimg)
+        progress_fn(int(resume_kimg), total_kimg)
     tick_start_time = time.time()
     maintenance_time = tick_start_time - start_time
-    cur_nimg = 0
+    cur_nimg = int(resume_kimg * 1000)
     cur_tick = -1
     tick_start_nimg = cur_nimg
     running_mb_counter = 0
@@ -301,7 +311,7 @@ def training_loop(
             # Save snapshots.
             if image_snapshot_ticks is not None and (done or cur_tick % image_snapshot_ticks == 0):
                 grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=minibatch_gpu)
-                save_image_grid(grid_fakes, os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}.png'), drange=[-1,1], grid_size=grid_size)
+                save_image_grid(grid_fakes, os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}.jpg'), drange=[-1,1], grid_size=grid_size)
             if network_snapshot_ticks is not None and (done or cur_tick % network_snapshot_ticks == 0):
                 pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg // 1000:06d}.pkl')
                 with open(pkl, 'wb') as f: