From cc55d3796ea83f8a7d6fcf7f14166724bccd7c7e Mon Sep 17 00:00:00 2001 From: Minematas Date: Tue, 28 Apr 2020 16:45:13 +0300 Subject: [PATCH] Added resume training args --- run_training.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/run_training.py b/run_training.py index bc4c0a2b..c6559434 100755 --- a/run_training.py +++ b/run_training.py @@ -33,7 +33,19 @@ #---------------------------------------------------------------------------- -def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, mirror_augment, metrics): +def run(dataset, + data_dir, + result_dir, + config_id, + num_gpus, + total_kimg, + gamma, + mirror_augment, + metrics, + resume_pkl, + resume_kimg, + resume_time, + resume_with_new_nets): train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop. G = EasyDict(func_name='training.networks_stylegan2.G_main') # Options for generator network. D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2') # Options for discriminator network. @@ -50,6 +62,10 @@ def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, m train.total_kimg = total_kimg train.mirror_augment = mirror_augment train.image_snapshot_ticks = train.network_snapshot_ticks = 10 + train.resume_pkl = resume_pkl + train.resume_kimg = resume_kimg + train.resume_time = resume_time + train.resume_with_new_nets = resume_with_new_nets sched.G_lrate_base = sched.D_lrate_base = 0.002 sched.minibatch_size_base = 32 sched.minibatch_gpu_base = 4 @@ -168,6 +184,10 @@ def main(): parser.add_argument('--gamma', help='R1 regularization weight (default is config dependent)', default=None, type=float) parser.add_argument('--mirror-augment', help='Mirror augment (default: %(default)s)', default=False, metavar='BOOL', type=_str_to_bool) parser.add_argument('--metrics', help='Comma-separated list of metrics or "none" (default: %(default)s)', default='fid50k', type=_parse_comma_sep) + parser.add_argument('--resume-pkl', help='Network pickle to resume training from', default=None) + parser.add_argument('--resume-kimg', help='Assumed training progress at the beginning. Affects reporting and training schedule.', default=0.0, type=float) + parser.add_argument('--resume-time', help='Assumed wallclock time at the beginning. Affects reporting.', default=0.0, type=float) + parser.add_argument('--resume-with-new-nets', help='Construct new networks according to G_args and D_args before resuming training (default: %(default)s)', default=False, metavar='BOOL', type=_str_to_bool) args = parser.parse_args()