-
Notifications
You must be signed in to change notification settings - Fork 5
/
train.py
executable file
·58 lines (45 loc) · 1.94 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import torch
def main(run, cfg):
    """Build the train/validation data pipelines and launch training.

    Selects a dataset class from the mode flags on ``cfg``, wraps the train
    and validation datasets in ``DistributedSampler``-backed ``DataLoader``s,
    constructs a ``Trainer``, optionally loads pretrained weights, and calls
    ``trainer.train()``.

    Args:
        run: Experiment-tracking handle forwarded unchanged to ``Trainer``.
            The script entry point passes a wandb run on local rank 0 and
            ``None`` on every other rank.
        cfg: Configuration object. Attributes read here: ``reproduce_dire``,
            ``only_eps``, ``only_img``, ``dataset_root``,
            ``dataset_test_root``, ``batch_size``, ``kd`` and
            ``pretrained_weights``.

    Raises:
        KeyError: If ``LOCAL_RANK`` or ``WORLD_SIZE`` is missing from the
            environment.
    """
    # Imported locally so this function does not depend on names that are
    # only bound when the file is executed as a script.
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    from dataset import (
        TMDireDataset,
        TMDistilDireDataset,
        TMEPSOnlyDataset,
        TMIMGOnlyDataset,
    )
    from utils.trainer import Trainer

    # Same environment variables the script entry point reads.
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    # Mode flags are checked in priority order; the first truthy one wins.
    if cfg.reproduce_dire:
        dataset = TMDireDataset(cfg.dataset_root)
        val_dataset = TMDireDataset(cfg.dataset_test_root)
    elif cfg.only_eps:
        dataset = TMEPSOnlyDataset(cfg.dataset_root)
        # Fix: validation previously used cfg.dataset_root (the training
        # data); every other mode validates on cfg.dataset_test_root.
        val_dataset = TMEPSOnlyDataset(cfg.dataset_test_root)
    elif cfg.only_img:
        dataset = TMIMGOnlyDataset(cfg.dataset_root, istrain=True)
        val_dataset = TMIMGOnlyDataset(cfg.dataset_test_root, istrain=False)
    else:
        dataset = TMDistilDireDataset(cfg.dataset_root)
        val_dataset = TMDistilDireDataset(cfg.dataset_test_root)

    # DistributedSampler with no explicit rank/replica count takes them from
    # the current process group, which must therefore be initialised already.
    sampler = DistributedSampler(dataset)
    val_sampler = DistributedSampler(val_dataset)

    dataloader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        sampler=sampler,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.batch_size,
        sampler=val_sampler,
    )

    # NOTE(review): the positional ``True`` is passed through as in the
    # original call; its meaning is defined by Trainer's signature, which is
    # not visible here - confirm before changing.
    trainer = Trainer(cfg, dataloader, val_loader, run, local_rank, True, world_size, cfg.kd)
    if cfg.pretrained_weights:
        trainer.load_networks(cfg.pretrained_weights)
    trainer.train()
if __name__ == "__main__":
    # Script entry point: set up the NCCL process group from the launcher's
    # environment, pin this process to its GPU, start experiment tracking on
    # local rank 0, then hand over to main().
    import torch.distributed as dist
    import wandb
    from torch.utils.data import DataLoader
    from dataset import TMDistilDireDataset, TMDireDataset, TMEPSOnlyDataset, TMIMGOnlyDataset, JOINEDDistilDireDataset

    # The imports above and the two rank variables below are bound at module
    # scope, so functions in this file that look them up as globals can
    # resolve them when the file is run as a script.
    dist.init_process_group(backend='nccl', init_method='env://')
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    torch.cuda.set_device(local_rank)
    dist.barrier()

    # Kept after the barrier, matching the original statement order.
    # NOTE(review): presumably utils.config does work at import time
    # (e.g. argument parsing) - confirm before moving this import.
    from utils.config import cfg

    # Only the process with local rank 0 opens a wandb run; every other
    # process passes None to main().
    run = None
    if local_rank == 0:
        run = wandb.init(project='distil-dire', config=cfg, dir=cfg.exp_dir)

    main(run, cfg)

    # Release the process group after a successful run. The guard keeps this
    # a no-op if training code has already torn the group down. Deliberately
    # not in a ``finally``: on failure the process should exit promptly so
    # the launcher can observe the error.
    if dist.is_initialized():
        dist.destroy_process_group()