Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Weights and Biases Experiment tracking #18

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions dnnlib/tflib/autosummary.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,5 +187,6 @@ def save_summaries(file_writer, global_step=None):
file_writer.add_summary(layout)
with tf.device(None), tf.control_dependencies(None):
_merge_op = tf.summary.merge_all()

file_writer.add_summary(_merge_op.eval(), global_step)
_merge_op_eval = _merge_op.eval()
file_writer.add_summary(_merge_op_eval, global_step)
return _merge_op_eval
36 changes: 36 additions & 0 deletions dnnlib/wandb_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import imghdr

import wandb
from wandb.integration.tensorboard import tf_summary_to_dict

class WandbLogger:
    """Thin wrapper around a Weights & Biases run that batches metric logging.

    Values passed to log() / log_tf_summary() accumulate in ``self.log_dict``
    and are sent to W&B in a single call when flush() is invoked, so metrics
    logged from several call sites within one training tick land on the same
    W&B step.
    """

    def __init__(self, project, name, config, job_type=None):
        # Reuse an already-active run (several modules construct a logger at
        # import time) instead of starting a second W&B run per process.
        self.run = wandb.init(project=project, name=name, config=config, job_type=job_type) if not wandb.run else wandb.run
        self.log_dict = {}  # pending key -> value pairs, sent on flush()

    def log_tf_summary(self, summary):
        """Convert a TensorFlow summary protobuf and queue its contents."""
        # Use the explicitly imported helper; the original mixed the bare
        # name (in the import) with the full dotted path (at the call site).
        tf_log_dict = tf_summary_to_dict(summary)
        if tf_log_dict:
            self.log_dict.update(tf_log_dict)

    def log_model_artifact(self, path, step):
        """Upload a network snapshot pickle as a versioned W&B model artifact.

        step is the current training progress (kimg); it is recorded in the
        artifact metadata under 'cur_nimg' and in the snapshot filename.
        """
        model_artifact = wandb.Artifact('run_'+wandb.run.id+'_checkpoints', type='model', metadata={'cur_nimg': step})
        model_artifact.add_file(path, name='network-snapshot-%06d.pkl' % (step))
        wandb.log_artifact(model_artifact)

    def flush(self):
        """Send every queued value to W&B as one step and clear the queue."""
        wandb.log(self.log_dict)
        self.log_dict = {}

    def log(self, log_dict, flush=False):
        """Queue key/value pairs; values that are image file paths become wandb.Image.

        The original called imghdr.what(value) unguarded, which raises
        TypeError for non-path values (e.g. plain numbers) and
        FileNotFoundError for missing files — any non-path metric crashed the
        logger. Such values are now logged verbatim instead.
        """
        for key, value in log_dict.items():
            try:
                # imghdr.what returns the image type, or None for non-images.
                is_image = isinstance(value, (str, bytes)) and imghdr.what(value) is not None
            except (OSError, TypeError, ValueError):
                is_image = False
            self.log_dict[key] = wandb.Image(value) if is_image else value
        if flush:
            self.flush()



16 changes: 9 additions & 7 deletions run_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
import sys

import pretrained_networks
from dnnlib.wandb_utils import WandbLogger

#----------------------------------------------------------------------------
wandb_logger = WandbLogger(project='stylegan2', name='generation',
config=None, job_type='generation')

def generate_images(network_pkl, seeds, truncation_psi):
print('Loading networks from "%s"...' % network_pkl)
Expand All @@ -30,11 +33,11 @@ def generate_images(network_pkl, seeds, truncation_psi):
for seed_idx, seed in enumerate(seeds):
print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
rnd = np.random.RandomState(seed)
z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
z = rnd.randn(1, *Gs.input_shape[1:]) #[minibatch, component]
tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]
images = Gs.run(z, None, **Gs_kwargs) # [minibatch, height, width, channel]
PIL.Image.fromarray(images[0], 'RGB').save(dnnlib.make_run_dir_path('seed%04d.png' % seed))

wandb_logger.log({'seed%04d.png' % seed: dnnlib.make_run_dir_path('seed%04d.png' % seed)}, flush=True)
#----------------------------------------------------------------------------

def style_mixing_example(network_pkl, row_seeds, col_seeds, truncation_psi, col_styles, minibatch_size=4):
Expand Down Expand Up @@ -69,7 +72,8 @@ def style_mixing_example(network_pkl, row_seeds, col_seeds, truncation_psi, col_
print('Saving images...')
for (row_seed, col_seed), image in image_dict.items():
PIL.Image.fromarray(image, 'RGB').save(dnnlib.make_run_dir_path('%d-%d.png' % (row_seed, col_seed)))

wandb_logger.log({'Img_%d-%d'%(row_seed, col_seed): dnnlib.make_run_dir_path('%d-%d.png' % (row_seed, col_seed))},
flush=True)
print('Saving image grid...')
_N, _C, H, W = Gs.output_shape
canvas = PIL.Image.new('RGB', (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), 'black')
Expand All @@ -84,6 +88,7 @@ def style_mixing_example(network_pkl, row_seeds, col_seeds, truncation_psi, col_
key = (row_seed, row_seed)
canvas.paste(PIL.Image.fromarray(image_dict[key], 'RGB'), (W * col_idx, H * row_idx))
canvas.save(dnnlib.make_run_dir_path('grid.png'))
wandb_logger.log({'Images Grid': dnnlib.make_run_dir_path('grid.png')}, flush=True)

#----------------------------------------------------------------------------

Expand Down Expand Up @@ -143,12 +148,11 @@ def main():

args = parser.parse_args()
kwargs = vars(args)
wandb_logger.run.config.update(kwargs)
subcmd = kwargs.pop('command')

if subcmd is None:
print ('Error: missing subcommand. Re-run with --help for usage.')
sys.exit(1)

sc = dnnlib.SubmitConfig()
sc.num_gpus = 1
sc.submit_target = dnnlib.SubmitTarget.LOCAL
Expand All @@ -161,10 +165,8 @@ def main():
'style-mixing-example': 'run_generator.style_mixing_example'
}
dnnlib.submit_run(sc, func_name_map[subcmd], **kwargs)

#----------------------------------------------------------------------------

if __name__ == "__main__":
main()

#----------------------------------------------------------------------------
8 changes: 8 additions & 0 deletions run_projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@
import dnnlib.tflib as tflib
import re
import sys
from os.path import basename

import projector
import pretrained_networks
from training import dataset
from training import misc
from dnnlib.wandb_utils import WandbLogger

#----------------------------------------------------------------------------

wandb_logger = WandbLogger(project='stylegan2', name='generation', config=None, job_type='generation')

def project_image(proj, targets, png_prefix, num_snapshots):
snapshot_steps = set(proj.num_steps - np.linspace(0, proj.num_steps, num_snapshots, endpoint=False, dtype=int))
misc.save_image_grid(targets, png_prefix + 'target.png', drange=[-1,1])
Expand All @@ -27,6 +31,9 @@ def project_image(proj, targets, png_prefix, num_snapshots):
proj.step()
if proj.get_cur_step() in snapshot_steps:
misc.save_image_grid(proj.get_images(), png_prefix + 'step%04d.png' % proj.get_cur_step(), drange=[-1,1])
wandb_logger.log({'Projections/'+basename(png_prefix):
png_prefix + 'step%04d.png' % proj.get_cur_step()},flush=True)
wandb_logger.log({'Targets/'+basename(png_prefix): png_prefix + 'target.png'}, flush=True)
print('\r%-30s\r' % '', end='', flush=True)

#----------------------------------------------------------------------------
Expand Down Expand Up @@ -127,6 +134,7 @@ def main():
sys.exit(1)

kwargs = vars(args)
wandb_logger.run.config.update(kwargs)
sc = dnnlib.SubmitConfig()
sc.num_gpus = 1
sc.submit_target = dnnlib.SubmitTarget.LOCAL
Expand Down
17 changes: 13 additions & 4 deletions training/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from training import dataset
from training import misc
from metrics import metric_base
from dnnlib.wandb_utils import WandbLogger

#----------------------------------------------------------------------------
# Just-in-time processing of training images before feeding them to the networks.
Expand Down Expand Up @@ -133,15 +134,18 @@ def training_loop(
resume_time = 0.0, # Assumed wallclock time at the beginning. Affects reporting.
resume_with_new_nets = False): # Construct new networks according to G_args and D_args before resuming training?

# Initialize dnnlib and TensorFlow.
# Initialize dnnlib, TensorFlow and WandbLogger
tflib.init_tf(tf_config)
num_gpus = dnnlib.submit_config.num_gpus
wandb_logger = WandbLogger(project='stylegan2', name='train-'+dataset_args['tfrecord_dir'],
config=dnnlib.submit_config, job_type="training")

# Load training set.
training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)

wandb_logger.log({'Real Samples':dnnlib.make_run_dir_path('reals.png')})

# Construct or load networks.
with tf.device('/gpu:0'):
if resume_pkl is None or resume_with_new_nets:
Expand All @@ -161,7 +165,8 @@ def training_loop(
grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size)

wandb_logger.log({'Initial Fakes': dnnlib.make_run_dir_path('fakes_init.png')})

# Setup training inputs.
print('Building TensorFlow graph...')
with tf.name_scope('Inputs'), tf.device('/cpu:0'):
Expand Down Expand Up @@ -335,14 +340,18 @@ def training_loop(
if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
wandb_logger.log({'fake': dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000))})
if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
misc.save_pkl((G, D, Gs), pkl)
wandb_logger.log_model_artifact(pkl, cur_nimg // 1000)
metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, tf_config=tf_config)

# Update summaries and RunContext.
metrics.update_autosummaries()
tflib.autosummary.save_summaries(summary_log, cur_nimg)
summary = tflib.autosummary.save_summaries(summary_log, cur_nimg)
wandb_logger.log_tf_summary(summary)
wandb_logger.flush()
dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

Expand Down