Skip to content

Commit

Permalink
Add more training arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
mobiuscreek committed May 3, 2020
1 parent e3e2900 commit c8c9db1
Showing 1 changed file with 218 additions and 0 deletions.
218 changes: 218 additions & 0 deletions train_stylegan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

import argparse
import copy
import os
import sys

import dnnlib
from dnnlib import EasyDict

from metrics.metric_defaults import metric_defaults

#----------------------------------------------------------------------------

_valid_configs = [
# Table 1
'config-a', # Baseline StyleGAN
'config-b', # + Weight demodulation
'config-c', # + Lazy regularization
'config-d', # + Path length regularization
'config-e', # + No growing, new G & D arch.
'config-f', # + Large networks (default)

# Table 2
'config-e-Gorig-Dorig', 'config-e-Gorig-Dresnet', 'config-e-Gorig-Dskip',
'config-e-Gresnet-Dorig', 'config-e-Gresnet-Dresnet', 'config-e-Gresnet-Dskip',
'config-e-Gskip-Dorig', 'config-e-Gskip-Dresnet', 'config-e-Gskip-Dskip',
]

#----------------------------------------------------------------------------

def run(dataset,
data_dir,
result_dir,
config_id,
num_gpus,
total_kimg,
gamma,
mirror_augment,
metrics,
resume_pkl,
resume_kimg,
resume_time,
resume_with_new_nets):

train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop.
G = EasyDict(func_name='training.networks_stylegan2.G_main') # Options for generator network.
D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2') # Options for discriminator network.
G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer.
D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer.
G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg') # Options for generator loss.
D_loss = EasyDict(func_name='training.loss.D_logistic_r1') # Options for discriminator loss.
sched = EasyDict() # Options for TrainingSchedule.
grid = EasyDict(size='8k', layout='random') # Options for setup_snapshot_image_grid().
sc = dnnlib.SubmitConfig() # Options for dnnlib.submit_run().
tf_config = {'rnd.np_random_seed': 1000} # Options for tflib.init_tf().

train.data_dir = data_dir
train.total_kimg = total_kimg
train.mirror_augment = mirror_augment
train.image_snapshot_ticks = train.network_snapshot_ticks = 10
train.resume_pkl = resume_pkl
train.resume_kimg = resume_kimg
train.resume_time = resume_time
train.resume_with_new_nets = resume_with_new_nets
sched.G_lrate_base = sched.D_lrate_base = 0.002
sched.minibatch_size_base = 32
sched.minibatch_gpu_base = 4
D_loss.gamma = 10
metrics = [metric_defaults[x] for x in metrics]
desc = 'stylegan2'

desc += '-' + dataset
dataset_args = EasyDict(tfrecord_dir=dataset)

assert num_gpus in [1, 2, 4, 8]
sc.num_gpus = num_gpus
desc += '-%dgpu' % num_gpus

assert config_id in _valid_configs
desc += '-' + config_id

# Configs A-E: Shrink networks to match original StyleGAN.
if config_id != 'config-f':
G.fmap_base = D.fmap_base = 8 << 10

# Config E: Set gamma to 100 and override G & D architecture.
if config_id.startswith('config-e'):
D_loss.gamma = 100
if 'Gorig' in config_id: G.architecture = 'orig'
if 'Gskip' in config_id: G.architecture = 'skip' # (default)
if 'Gresnet' in config_id: G.architecture = 'resnet'
if 'Dorig' in config_id: D.architecture = 'orig'
if 'Dskip' in config_id: D.architecture = 'skip'
if 'Dresnet' in config_id: D.architecture = 'resnet' # (default)

# Configs A-D: Enable progressive growing and switch to networks that support it.
if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
sched.lod_initial_resolution = 8
sched.G_lrate_base = sched.D_lrate_base = 0.001
sched.G_lrate_dict = sched.D_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
sched.minibatch_size_base = 32 # (default)
sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
sched.minibatch_gpu_base = 4 # (default)
sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
G.synthesis_func = 'G_synthesis_stylegan_revised'
D.func_name = 'training.networks_stylegan2.D_stylegan'

# Configs A-C: Disable path length regularization.
if config_id in ['config-a', 'config-b', 'config-c']:
G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

# Configs A-B: Disable lazy regularization.
if config_id in ['config-a', 'config-b']:
train.lazy_regularization = False

# Config A: Switch to original StyleGAN networks.
if config_id == 'config-a':
G = EasyDict(func_name='training.networks_stylegan.G_style')
D = EasyDict(func_name='training.networks_stylegan.D_basic')

if gamma is not None:
D_loss.gamma = gamma

sc.submit_target = dnnlib.SubmitTarget.LOCAL
sc.local.do_not_copy_source_files = True
kwargs = EasyDict(train)
kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss)
kwargs.update(dataset_args=dataset_args, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config)
kwargs.submit_config = copy.deepcopy(sc)
kwargs.submit_config.run_dir_root = result_dir
kwargs.submit_config.run_desc = desc
dnnlib.submit_run(**kwargs)

#----------------------------------------------------------------------------

def _str_to_bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')

def _parse_comma_sep(s):
if s is None or s.lower() == 'none' or s == '':
return []
return s.split(',')

#----------------------------------------------------------------------------

_examples = '''examples:
# Train StyleGAN2 using the FFHQ dataset
python %(prog)s --num-gpus=8 --data-dir=~/datasets --config=config-f --dataset=ffhq --mirror-augment=true
valid configs:
''' + ', '.join(_valid_configs) + '''
valid metrics:
''' + ', '.join(sorted([x for x in metric_defaults.keys()])) + '''
'''

def main():
parser = argparse.ArgumentParser(
description='Train StyleGAN2.',
epilog=_examples,
formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR')
parser.add_argument('--data-dir', help='Dataset root directory', required=True)
parser.add_argument('--dataset', help='Training dataset', required=True)
parser.add_argument('--config', help='Training config (default: %(default)s)', default='config-f', required=True, dest='config_id', metavar='CONFIG')
parser.add_argument('--num-gpus', help='Number of GPUs (default: %(default)s)', default=1, type=int, metavar='N')
parser.add_argument('--total-kimg', help='Training length in thousands of images (default: %(default)s)', metavar='KIMG', default=25000, type=int)
parser.add_argument('--gamma', help='R1 regularization weight (default is config dependent)', default=None, type=float)
parser.add_argument('--mirror-augment', help='Mirror augment (default: %(default)s)', default=False, metavar='BOOL', type=_str_to_bool)
parser.add_argument('--metrics', help='Comma-separated list of metrics or "none" (default: %(default)s)', default='fid50k', type=_parse_comma_sep)
parser.add_argument('--resume-pkl', help='Network pickle to resume training from', default=None)
parser.add_argument('--resume-kimg', help='Assumed training progress at the beginning. Affects reporting and training schedule.', default=0.0, type=float)
parser.add_argument('--resume-time', help='Assumed wallclock time at the beginning. Affects reporting.', default=0.0, type=float)
parser.add_argument('--resume-with-new-nets', help='Construct new networks according to G_args and D_args before resuming training (default: %(default)s)', default=False, metavar='BOOL', type=_str_to_bool)


args = parser.parse_args()

if not os.path.exists(args.data_dir):
print ('Error: dataset root directory does not exist.')
sys.exit(1)

if args.config_id not in _valid_configs:
print ('Error: --config value must be one of: ', ', '.join(_valid_configs))
sys.exit(1)

for metric in args.metrics:
if metric not in metric_defaults:
print ('Error: unknown metric \'%s\'' % metric)
sys.exit(1)

run(**vars(args))

#----------------------------------------------------------------------------

if __name__ == "__main__":
main()

#----------------------------------------------------------------------------


0 comments on commit c8c9db1

Please sign in to comment.