# Copyright (c) 2.2022. Yinyu Nie
# License: MIT
import torch
import wandb
from time import time
import os
import numpy as np
from net_utils.utils import CheckpointIO, AverageMeter
from models.optimizers import load_optimizer, load_scheduler
from net_utils.utils import load_device, load_model, load_trainer, load_tester
from models.ours.dataloader import THREEDFRONT, ScanNet, collate_fn
import h5py


class Demo(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.is_master = cfg.is_master

        '''Load save path and checkpoint handler.'''
        cfg.info('Data save path: %s' % (os.getcwd()))
        cfg.info('Loading checkpoint handler.')
        self.checkpoint = CheckpointIO(cfg, self.is_master)

        '''Load device'''
        cfg.info('Loading device settings.')
        self.device = load_device(cfg)

        '''Load data'''
        cfg.info('Reading demo samples.')
        self.all_samples = self.read_samples(self.cfg.demo_samples)

        '''Load model'''
        cfg.info('Loading model.')
        self.net = load_model(cfg, device=self.device)
        self.checkpoint.register_modules(net=self.net)
        cfg.info(self.net)

        '''Freeze every network part except the latent input codes.'''
        for net_type, subnet in self.net.items():
            if net_type in ['latent_input']:
                continue
            self.cfg.info('%s is frozen.' % (net_type))
            for param in subnet.parameters():
                param.requires_grad = False

        '''Read network weights (finetune mode)'''
        self.checkpoint.parse_checkpoint(device=self.device)

        '''Load sub tester for a specific method.'''
        cfg.info('Loading method tester.')
        self.subtester = load_tester(cfg=cfg, net=self.net, device=self.device)

        '''Output network size'''
        self.subtester.show_net_n_params()

        # Initialize wandb logging on the master process only.
        if self.is_master and cfg.config.log.if_wandb:
            cfg.info('Loading wandb.')
            wandb.init(project=cfg.config.method, name=cfg.config.exp_name, config=cfg.config)
            # wandb.watch(self.net)
    def log_wandb(self, loss, phase):
        '''Log each loss term to wandb under a phase prefix.'''
        wandb.log({'%s/%s' % (phase, key): value for key, value in loss.items()})
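
    # For example (hypothetical values): log_wandb({'total': 0.42, 'mask': 0.10}, 'finetune')
    # would log {'finetune/total': 0.42, 'finetune/mask': 0.10} to wandb.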
    def read_samples(self, sample_files):
        '''Read the slice of demo samples assigned to this batch job.'''
        batch_id = self.cfg.config.demo.batch_id
        batch_num = self.cfg.config.demo.batch_num
        sublist = np.array_split(np.arange(len(sample_files)), batch_num)[batch_id]
        sample_files = [sample_files[idx] for idx in sublist]

        samples = []
        for sample_file in sample_files:
            processed_sample = self.read_data(sample_file, self.cfg.config.data.dataset)
            samples.append(processed_sample)
        return samples
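
    # Worked example for the batching above (hypothetical numbers): with 10 sample
    # files, batch_num=3 and batch_id=0, np.array_split yields index chunks
    # [0, 1, 2, 3], [4, 5, 6] and [7, 8, 9], so this job processes only the first
    # four samples. Several demo jobs can thus split one sample list without overlap.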
    def read_sample(self, sample_name, dataset_name):
        if dataset_name == '3D-Front':
            room_type, img, cam_K, cam_T, insts = THREEDFRONT.parse_hdf5(self.cfg, sample_name)
            image_size = self.cfg.image_size
        elif dataset_name == 'ScanNet':
            room_type, img, cam_K, cam_T, insts, image_size = ScanNet.parse_hdf5(self.cfg, sample_name)
        else:
            raise NotImplementedError
        return room_type, img, cam_K, cam_T, insts, image_size
    def load_unique_inst_marks(self, sample_path, dataset_name):
        '''Read the unique instance marks of a sample.'''
        if dataset_name == '3D-Front':
            room_name = '_'.join(os.path.basename(sample_path).split('_')[:-1])
            unique_inst_marks = self.cfg.unique_inst_mark[room_name]
        elif dataset_name == 'ScanNet':
            with h5py.File(sample_path, "r") as sample_data:
                unique_inst_marks = sample_data['unique_inst_marks'][:].tolist()
        else:
            raise NotImplementedError
        return unique_inst_marks
    def read_data(self, sample_path, dataset_name):
        '''Read multi-view data of one room and collate it into a batch.'''
        # Sample other views of the same room.
        room_name = '_'.join(sample_path.name.split('_')[:-1])
        candidates = [file for file in sample_path.parent.iterdir()
                      if room_name in file.name and file.name != sample_path.name]
        if len(candidates):
            view_ids = np.random.choice(len(candidates), self.cfg.config.data.n_views - 1,
                                        replace=len(candidates) < self.cfg.config.data.n_views - 1)
            view_files = [sample_path] + [candidates[idx] for idx in view_ids]
        else:
            view_files = [sample_path] * self.cfg.config.data.n_views

        parsed_data = []
        for view_file in view_files:
            room_type, img, cam_K, cam_T, insts, image_size = self.read_sample(view_file, dataset_name)
            sample_name = '.'.join(os.path.basename(view_file).split('.')[:-1])
            parsed_data.append((img, cam_K, cam_T, insts, image_size, sample_name))
        room_type_idx = self.cfg.room_types.index(room_type)
        unique_marks = sorted(set(sum([view[3]['inst_marks'] for view in parsed_data], [])))

        # Re-organize instances following track ids.
        parsed_data = THREEDFRONT.track_insts(parsed_data, unique_marks)

        keywords = ['sample_name', 'cam_K', 'image_size', 'cam_T', 'box2ds_tr', 'inst_marks']
        if self.cfg.config.start_deform:
            keywords.append('masks_tr')
        if self.cfg.config.data.dataset == 'ScanNet':
            keywords.append('render_mask_tr')

        views_data = []
        for parsed_view in parsed_data:
            view_data = THREEDFRONT.get_view_data(self.cfg, parsed_view, self.cfg.config.data.downsample_ratio)
            view_data = {**{k: view_data[k] for k in keywords},
                         **{'room_idx': 0, 'max_len': len(unique_marks), 'room_type_idx': room_type_idx}}
            views_data.append(view_data)
        return collate_fn([views_data])
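
    # Note on the view sampling above: np.random.choice draws the n_views - 1 extra
    # views without replacement when enough distinct candidate views exist, and with
    # replacement (i.e. duplicated views) when the room has fewer candidates than
    # requested; a room with a single view is simply repeated n_views times.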
    def run(self):
        '''Finetune latent codes and output results.'''
        self.cfg.info('Start to finetune latent codes.')

        '''Time meter setup.'''
        sample_timemeter = AverageMeter()
        epoch_timemeter = AverageMeter()

        phase = 'train'
        start_epoch = 0
        total_epochs = self.cfg.config.demo.epochs
        # ---------------------------------------------------------------------------------------
        for sample in self.all_samples:
            sample_start = time()
            torch.cuda.empty_cache()
            self.cfg.info('=' * 100)
            self.cfg.info('Processing: %s.' % (sample['sample_name'][0][0]))
            self.cfg.info('Loading optimizer.')

            '''Load optimizer'''
            optimizer = load_optimizer(config=self.cfg.config, net=self.net)

            '''Load scheduler'''
            self.cfg.info('Loading optimizer scheduler.')
            scheduler = load_scheduler(cfg=self.cfg, optimizer=optimizer)

            '''Load sub trainer for a specific method.'''
            self.cfg.info('Loading method trainer.')
            subtrainer = load_trainer(cfg=self.cfg, net=self.net, optimizer=optimizer, device=self.device)

            '''Freeze network'''
            # Set the training mode, then freeze every network part except the latent input codes.
            subtrainer.set_mode(phase)
            for net_type, subnet in self.net.items():
                if net_type in ['latent_input']:
                    continue
                for child in subnet.children():
                    child.train(False)

            '''Start to finetune latent code'''
            self.cfg.info('Start to finetune latent code.')
            # ---------------------------------------------------------------------------------------
            min_eval_loss = self.checkpoint.get('min_loss')
            loss = {'total': min_eval_loss}
            pred_gt_matching = None
            if_mask_loss = False
            losses = []
            for epoch in range(start_epoch, total_epochs):
                epoch_start = time()
                if (epoch % self.cfg.config.log.print_step) == 0:
                    self.cfg.info('-' * 100)
                    self.cfg.info('Epoch (%d/%s):' % (epoch, total_epochs - 1))
                    subtrainer.show_lr()

                # After the warm-up epochs, reuse the matching from the previous step
                # and switch the mask loss on (this assumes mask_flag >= 0, so that
                # extra_output already exists when this branch is first taken).
                if epoch > self.cfg.config.demo.mask_flag:
                    pred_gt_matching = extra_output['pred_gt_matching']
                    if_mask_loss = True

                loss, extra_output = subtrainer.train_step(sample, stage='latent_only',
                                                           start_deform=self.cfg.config.start_deform,
                                                           return_matching=True,
                                                           pred_gt_matching=pred_gt_matching,
                                                           if_mask_loss=if_mask_loss)
                losses.append(loss)
                if loss['total'] < min_eval_loss:
                    min_eval_loss = loss['total']

                '''Display epoch info'''
                if (epoch % self.cfg.config.log.print_step) == 0:
                    epoch_timemeter.update(time() - epoch_start)
                    self.cfg.info(
                        'Latent_lr: {Latent_lr:s} | {phase:s} | Epoch: [{0}/{1}] | Loss: {loss:s} | '
                        'Epoch Time {epoch_time:.3f}'.format(
                            epoch, total_epochs, phase='finetune', loss=str(loss),
                            epoch_time=epoch_timemeter.avg,
                            Latent_lr=str(scheduler['latent_input'].get_last_lr()[:2])))

                if self.is_master and self.cfg.config.log.if_wandb:
                    self.log_wandb(loss, sample['sample_name'][0][0])

                scheduler['latent_input'].step()

            sample_timemeter.update(time() - sample_start)
            self.cfg.info('-' * 100)
            self.cfg.info('{sample:s}: Best loss is {best_loss:.3f} | Last loss is {last_loss:.3f} | '
                          'Avg fitting time: {time:.3f}'.format(
                              sample=sample['sample_name'][0][0], best_loss=min_eval_loss,
                              last_loss=loss['total'], time=sample_timemeter.avg))

            '''Output visualizations'''
            self.cfg.info('=' * 100)
            self.cfg.info('Export visualizations.')
            # Switch to test mode and render predictions without gradients.
            self.subtester.set_mode('test')
            with torch.no_grad():
                _, est_data = self.subtester.test_step(sample, start_deform=self.cfg.config.start_deform,
                                                       pred_gt_matching=extra_output['pred_gt_matching'])
            self.subtester.visualize_step(phase, epoch, sample, est_data,
                                          dump_dir=self.cfg.config.demo.output_dir)
        # ---------------------------------------------------------------------------------------
        if self.is_master and self.cfg.config.log.if_wandb:
            wandb.finish()
        self.cfg.info('Testing finished.')
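

# A minimal usage sketch (hypothetical): the real entry point builds `cfg` via the
# project's configuration loader, which is not shown in this file. `load_config`
# and the config path below are assumptions for illustration only.
#
# if __name__ == '__main__':
#     from configs.config_utils import load_config  # hypothetical helper
#     cfg = load_config('configs/demo.yaml')        # hypothetical config path
#     demo = Demo(cfg)
#     demo.run()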