forked from miraflow/DistilDIRE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
executable file
·23 lines (19 loc) · 997 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import wandb
from utils.config import cfg
def main(run, cfg):
    """Evaluate a pretrained checkpoint on the test dataset and save the results.

    Builds a ``TMIMGOnlyDataset`` over ``cfg.dataset_test_root`` (with
    ``istrain=False``), wraps it in a batch-size-1 ``DataLoader``, loads the
    weights at ``cfg.pretrained_weights`` into a ``Trainer`` and calls
    ``Trainer.validate`` with ``save=True``. The results file name is
    ``{cfg.root_dir}/{cfg.datasets_test}_{<checkpoint file name>}_results.txt``.

    Args:
        run: Passed straight through to ``Trainer``. The command-line entry
            point passes ``None``; presumably a wandb run handle otherwise
            (NOTE(review): confirm against ``utils/trainer.py``).
        cfg: Config object. This function reads ``dataset_test_root``,
            ``pretrained_weights``, ``root_dir`` and ``datasets_test`` and
            hands the whole object to ``Trainer``.

    Raises:
        ValueError: If ``cfg.pretrained_weights`` is empty or ``None``.
    """
    import os

    # Check the checkpoint path before importing the training stack or
    # building the dataset, so a bad invocation fails fast. A real exception
    # (not ``assert``) keeps the check active under ``python -O``.
    if not cfg.pretrained_weights:
        raise ValueError("Give proper checkpoint path")

    from utils.trainer import Trainer
    from torch.utils.data import DataLoader
    from dataset import TMIMGOnlyDataset

    print(cfg.dataset_test_root)
    # Other test datasets exported by ``dataset`` that can be swapped in here:
    #   TMEPSOnlyDataset(cfg.dataset_test_root, False)
    #   TMDistilDireDataset(cfg.dataset_test_root, prepared_dire=True)
    dataset = TMIMGOnlyDataset(cfg.dataset_test_root, istrain=False)
    # Evaluation needs no shuffling; a fixed order keeps the saved results
    # file reproducible from run to run.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)

    # The same loader fills both loader slots of ``Trainer``. The meaning of
    # the trailing positional args (0, False, 1) is defined by ``Trainer`` --
    # NOTE(review): confirm against utils/trainer.py.
    trainer = Trainer(cfg, dataloader, dataloader, run, 0, False, 1)
    trainer.load_networks(cfg.pretrained_weights)

    checkpoint_name = os.path.basename(cfg.pretrained_weights)
    save_name = f"{cfg.root_dir}/{cfg.datasets_test}_{checkpoint_name}_results.txt"
    trainer.validate(False, save=True, save_name=save_name)
if __name__ == "__main__":
    # Command-line entry point: evaluate using the module-level ``cfg`` and
    # pass ``None`` as the run handle.
    main(run=None, cfg=cfg)