diff --git a/run_training.py b/run_training.py index bc4c0a2b..eb03e97c 100755 --- a/run_training.py +++ b/run_training.py @@ -33,7 +33,7 @@ #---------------------------------------------------------------------------- -def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, mirror_augment, metrics): +def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, mirror_augment, metrics, resume_pkl=None): train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop. G = EasyDict(func_name='training.networks_stylegan2.G_main') # Options for generator network. D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2') # Options for discriminator network. @@ -50,6 +50,7 @@ def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, m train.total_kimg = total_kimg train.mirror_augment = mirror_augment train.image_snapshot_ticks = train.network_snapshot_ticks = 10 + train.resume_pkl = resume_pkl sched.G_lrate_base = sched.D_lrate_base = 0.002 sched.minibatch_size_base = 32 sched.minibatch_gpu_base = 4 @@ -168,6 +169,7 @@ def main(): parser.add_argument('--gamma', help='R1 regularization weight (default is config dependent)', default=None, type=float) parser.add_argument('--mirror-augment', help='Mirror augment (default: %(default)s)', default=False, metavar='BOOL', type=_str_to_bool) parser.add_argument('--metrics', help='Comma-separated list of metrics or "none" (default: %(default)s)', default='fid50k', type=_parse_comma_sep) + parser.add_argument('--resume-pkl', help="Resume from pkl file", default=None) args = parser.parse_args()