From ee70266fa5c3df6565bffc97842539963ebc8a6b Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 12 Jan 2021 21:01:25 +0000 Subject: [PATCH 1/3] Add W&B tracking for training pipeline --- dnnlib/tflib/autosummary.py | 5 +++-- dnnlib/wandb_utils.py | 30 ++++++++++++++++++++++++++++++ training/training_loop.py | 17 +++++++++++++---- 3 files changed, 46 insertions(+), 6 deletions(-) create mode 100644 dnnlib/wandb_utils.py diff --git a/dnnlib/tflib/autosummary.py b/dnnlib/tflib/autosummary.py index 6b0d80b3..deeb586b 100755 --- a/dnnlib/tflib/autosummary.py +++ b/dnnlib/tflib/autosummary.py @@ -187,5 +187,6 @@ def save_summaries(file_writer, global_step=None): file_writer.add_summary(layout) with tf.device(None), tf.control_dependencies(None): _merge_op = tf.summary.merge_all() - - file_writer.add_summary(_merge_op.eval(), global_step) + _merge_op_eval = _merge_op.eval() + file_writer.add_summary(_merge_op_eval, global_step) + return _merge_op_eval diff --git a/dnnlib/wandb_utils.py b/dnnlib/wandb_utils.py new file mode 100644 index 00000000..ea642047 --- /dev/null +++ b/dnnlib/wandb_utils.py @@ -0,0 +1,30 @@ +import wandb +from wandb.integration.tensorboard import tf_summary_to_dict + +class WandbLogger(): + def __init__(self, project, name, config, group=None): + self.run = wandb.init(project=project, name=name, config=config, group=group) if not wandb.run else wandb.run + self.log_dict = {} + + def log_scalar(self, log_dict): + for key, value in log_dict.items(): + self.log_dict[key] = value + + def log_tf_summary(self, summary): + tf_log_dict = wandb.integration.tensorboard.tf_summary_to_dict(summary) + if tf_log_dict is None: + return + for key, value in tf_log_dict.items(): + self.log_dict[key] = value + + def log_image(self, path, name): + self.log_dict[name] = wandb.Image(path) + + def log_model_artifact(self, path, step): + model_artifact = wandb.Artifact('run_'+wandb.run.id+'_checkpoints', type='model', metadata={'cur_nimg': step}) + 
model_artifact.add_file(path, name='network-snapshot-%06d.pkl' % (step)) + wandb.log_artifact(model_artifact) + + def flush(self): + wandb.log(self.log_dict) + self.log_dict = {} \ No newline at end of file diff --git a/training/training_loop.py b/training/training_loop.py index c2d88cf0..18aee5b1 100755 --- a/training/training_loop.py +++ b/training/training_loop.py @@ -15,6 +15,7 @@ from training import dataset from training import misc from metrics import metric_base +from dnnlib.wandb_utils import WandbLogger #---------------------------------------------------------------------------- # Just-in-time processing of training images before feeding them to the networks. @@ -133,15 +134,18 @@ def training_loop( resume_time = 0.0, # Assumed wallclock time at the beginning. Affects reporting. resume_with_new_nets = False): # Construct new networks according to G_args and D_args before resuming training? - # Initialize dnnlib and TensorFlow. + # Initialize dnnlib, TensorFlow and WandbLogger tflib.init_tf(tf_config) num_gpus = dnnlib.submit_config.num_gpus + wandb_logger = WandbLogger(project='stylegan2', name='train-'+dataset_args['tfrecord_dir'], + config=dnnlib.submit_config,group="training") # Load training set. training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args) grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args) misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) - + wandb_logger.log_image(dnnlib.make_run_dir_path('reals.png'), 'Real Samples') + # Construct or load networks. 
with tf.device('/gpu:0'): if resume_pkl is None or resume_with_new_nets: @@ -161,7 +165,8 @@ def training_loop( grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size) - + wandb_logger.log_image(dnnlib.make_run_dir_path('fakes_init.png'), 'Initial Fakes') + # Setup training inputs. print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): @@ -335,14 +340,18 @@ def training_loop( if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done): grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) + wandb_logger.log_image(dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)), 'fake') if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done): pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000)) misc.save_pkl((G, D, Gs), pkl) + wandb_logger.log_model_artifact(pkl, cur_nimg // 1000) metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, tf_config=tf_config) # Update summaries and RunContext. 
metrics.update_autosummaries() - tflib.autosummary.save_summaries(summary_log, cur_nimg) + summary = tflib.autosummary.save_summaries(summary_log, cur_nimg) + wandb_logger.log_tf_summary(summary) + wandb_logger.flush() dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time From 1fc8b1ce886350b58980f31b8f98d677b5188e82 Mon Sep 17 00:00:00 2001 From: AYush Chaurasia Date: Tue, 19 Jan 2021 16:25:25 +0000 Subject: [PATCH 2/3] Add W&B for generators and projectors --- dnnlib/wandb_utils.py | 36 +++++++++++++++++++++--------------- run_generator.py | 16 +++++++++------- run_projector.py | 9 +++++++++ training/training_loop.py | 8 ++++---- 4 files changed, 43 insertions(+), 26 deletions(-) diff --git a/dnnlib/wandb_utils.py b/dnnlib/wandb_utils.py index ea642047..503bd8ef 100644 --- a/dnnlib/wandb_utils.py +++ b/dnnlib/wandb_utils.py @@ -1,25 +1,19 @@ +import imghdr + import wandb from wandb.integration.tensorboard import tf_summary_to_dict class WandbLogger(): - def __init__(self, project, name, config, group=None): - self.run = wandb.init(project=project, name=name, config=config, group=group) if not wandb.run else wandb.run + def __init__(self, project, name, config, job_type=None): + self.run = wandb.init(project=project, name=name, config=config, job_type=job_type) if not wandb.run else wandb.run self.log_dict = {} - - def log_scalar(self, log_dict): - for key, value in log_dict.items(): - self.log_dict[key] = value def log_tf_summary(self, summary): tf_log_dict = wandb.integration.tensorboard.tf_summary_to_dict(summary) - if tf_log_dict is None: - return - for key, value in tf_log_dict.items(): - self.log_dict[key] = value - - def log_image(self, path, name): - self.log_dict[name] = wandb.Image(path) - + if tf_log_dict: + for key, value in tf_log_dict.items(): + self.log_dict[key] = value + def log_model_artifact(self, path, step):
model_artifact = wandb.Artifact('run_'+wandb.run.id+'_checkpoints', type='model', metadata={'cur_nimg': step}) model_artifact.add_file(path, name='network-snapshot-%06d.pkl' % (step)) @@ -27,4 +21,16 @@ def log_model_artifact(self, path, step): def flush(self): wandb.log(self.log_dict) - self.log_dict = {} \ No newline at end of file + self.log_dict = {} + + def log(self, log_dict, flush=False): + for key, value in log_dict.items(): + if imghdr.what(value): # Check if the value is an image + self.log_dict[key] = wandb.Image(value) + else: + self.log_dict[key] = value + if flush: + self.flush() + + + diff --git a/run_generator.py b/run_generator.py index 339796c9..6c6f9a1f 100755 --- a/run_generator.py +++ b/run_generator.py @@ -13,8 +13,11 @@ import sys import pretrained_networks +from dnnlib.wandb_utils import WandbLogger #---------------------------------------------------------------------------- +wandb_logger = WandbLogger(project='stylegan2', name='generation', + config=None, job_type='generation') def generate_images(network_pkl, seeds, truncation_psi): print('Loading networks from "%s"...' % network_pkl) @@ -30,11 +33,11 @@ def generate_images(network_pkl, seeds, truncation_psi): for seed_idx, seed in enumerate(seeds): print('Generating image for seed %d (%d/%d) ...' 
% (seed, seed_idx, len(seeds))) rnd = np.random.RandomState(seed) - z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component] + z = rnd.randn(1, *Gs.input_shape[1:]) #[minibatch, component] tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width] images = Gs.run(z, None, **Gs_kwargs) # [minibatch, height, width, channel] PIL.Image.fromarray(images[0], 'RGB').save(dnnlib.make_run_dir_path('seed%04d.png' % seed)) - + wandb_logger.log({'seed%04d.png' % seed: dnnlib.make_run_dir_path('seed%04d.png' % seed)}, flush=True) #---------------------------------------------------------------------------- def style_mixing_example(network_pkl, row_seeds, col_seeds, truncation_psi, col_styles, minibatch_size=4): @@ -69,7 +72,8 @@ def style_mixing_example(network_pkl, row_seeds, col_seeds, truncation_psi, col_ print('Saving images...') for (row_seed, col_seed), image in image_dict.items(): PIL.Image.fromarray(image, 'RGB').save(dnnlib.make_run_dir_path('%d-%d.png' % (row_seed, col_seed))) - + wandb_logger.log({'Img_%d-%d'%(row_seed, col_seed): dnnlib.make_run_dir_path('%d-%d.png' % (row_seed, col_seed))}, + flush=True) print('Saving image grid...') _N, _C, H, W = Gs.output_shape canvas = PIL.Image.new('RGB', (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), 'black') @@ -84,6 +88,7 @@ def style_mixing_example(network_pkl, row_seeds, col_seeds, truncation_psi, col_ key = (row_seed, row_seed) canvas.paste(PIL.Image.fromarray(image_dict[key], 'RGB'), (W * col_idx, H * row_idx)) canvas.save(dnnlib.make_run_dir_path('grid.png')) + wandb_logger.log({'Images Grid': dnnlib.make_run_dir_path('grid.png')}, flush=True) #---------------------------------------------------------------------------- @@ -143,12 +148,11 @@ def main(): args = parser.parse_args() kwargs = vars(args) + wandb_logger.run.config.update(kwargs) subcmd = kwargs.pop('command') - if subcmd is None: print ('Error: missing subcommand. 
Re-run with --help for usage.') sys.exit(1) - sc = dnnlib.SubmitConfig() sc.num_gpus = 1 sc.submit_target = dnnlib.SubmitTarget.LOCAL @@ -161,10 +165,8 @@ def main(): 'style-mixing-example': 'run_generator.style_mixing_example' } dnnlib.submit_run(sc, func_name_map[subcmd], **kwargs) - #---------------------------------------------------------------------------- if __name__ == "__main__": main() - #---------------------------------------------------------------------------- diff --git a/run_projector.py b/run_projector.py index 5fd89ed7..57143153 100755 --- a/run_projector.py +++ b/run_projector.py @@ -10,15 +10,20 @@ import dnnlib.tflib as tflib import re import sys +from os.path import basename import projector import pretrained_networks from training import dataset from training import misc +from dnnlib.wandb_utils import WandbLogger #---------------------------------------------------------------------------- +wandb_logger = WandbLogger(project='stylegan2', name='generation', config=None, job_type='generation') + def project_image(proj, targets, png_prefix, num_snapshots): + print("proj.num_steps==>",basename(png_prefix)) snapshot_steps = set(proj.num_steps - np.linspace(0, proj.num_steps, num_snapshots, endpoint=False, dtype=int)) misc.save_image_grid(targets, png_prefix + 'target.png', drange=[-1,1]) proj.start(targets) @@ -27,6 +32,9 @@ def project_image(proj, targets, png_prefix, num_snapshots): proj.step() if proj.get_cur_step() in snapshot_steps: misc.save_image_grid(proj.get_images(), png_prefix + 'step%04d.png' % proj.get_cur_step(), drange=[-1,1]) + wandb_logger.log({'Projections/'+basename(png_prefix): + png_prefix + 'step%04d.png' % proj.get_cur_step()},flush=True) + wandb_logger.log({'Targets/'+basename(png_prefix): png_prefix + 'target.png'}, flush=True) print('\r%-30s\r' % '', end='', flush=True) #---------------------------------------------------------------------------- @@ -127,6 +135,7 @@ def main(): sys.exit(1) kwargs = vars(args) + 
wandb_logger.run.config.update(kwargs) sc = dnnlib.SubmitConfig() sc.num_gpus = 1 sc.submit_target = dnnlib.SubmitTarget.LOCAL diff --git a/training/training_loop.py b/training/training_loop.py index 18aee5b1..41f56435 100755 --- a/training/training_loop.py +++ b/training/training_loop.py @@ -138,13 +138,13 @@ def training_loop( tflib.init_tf(tf_config) num_gpus = dnnlib.submit_config.num_gpus wandb_logger = WandbLogger(project='stylegan2', name='train-'+dataset_args['tfrecord_dir'], - config=dnnlib.submit_config,group="training") + config=dnnlib.submit_config, job_type="training") # Load training set. training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args) grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args) misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) - wandb_logger.log_image(dnnlib.make_run_dir_path('reals.png'), 'Real Samples') + wandb_logger.log({'Real Samples':dnnlib.make_run_dir_path('reals.png')}) # Construct or load networks. with tf.device('/gpu:0'): @@ -165,7 +165,7 @@ def training_loop( grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size) - wandb_logger.log_image(dnnlib.make_run_dir_path('fakes_init.png'), 'Initial Fakes') + wandb_logger.log({'Initial Fakes': dnnlib.make_run_dir_path('fakes_init.png')}) # Setup training inputs. 
print('Building TensorFlow graph...') @@ -340,7 +340,7 @@ def training_loop( if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done): grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) - wandb_logger.log_image(dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)), 'fake') + wandb_logger.log({'fake': dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000))}) if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done): pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000)) misc.save_pkl((G, D, Gs), pkl) From eddf297734030cecf5dff8380bd4c8ce92d5d62a Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Thu, 21 Jan 2021 00:36:49 +0530 Subject: [PATCH 3/3] Update run_projector.py --- run_projector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/run_projector.py b/run_projector.py index 57143153..9be006fb 100755 --- a/run_projector.py +++ b/run_projector.py @@ -23,7 +23,6 @@ wandb_logger = WandbLogger(project='stylegan2', name='generation', config=None, job_type='generation') def project_image(proj, targets, png_prefix, num_snapshots): - print("proj.num_steps==>",basename(png_prefix)) snapshot_steps = set(proj.num_steps - np.linspace(0, proj.num_steps, num_snapshots, endpoint=False, dtype=int)) misc.save_image_grid(targets, png_prefix + 'target.png', drange=[-1,1]) proj.start(targets)