pretrain.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2020/9/16 11:20
# @Author : Huatao
# @Email : [email protected]
# @File : pretrain.py
# @Description : self-supervised pretraining of LIMU-BERT with a masked-reconstruction (masked LM) objective
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import train
from models import LIMUBertModel4Pretrain
from utils import (get_device, LIBERTDataset4Pretrain, handle_argv,
                   load_pretrain_data_config, prepare_pretrain_dataset,
                   Preprocess4Normalization, Preprocess4Mask)


def main(args, training_rate):
    # Load raw sensor windows, labels, and the configs (training, model, masking, dataset).
    data, labels, train_cfg, model_cfg, mask_cfg, dataset_cfg = load_pretrain_data_config(args)

    # Preprocessing pipeline: per-channel normalization, then random masking of frames.
    pipeline = [Preprocess4Normalization(model_cfg.feature_num), Preprocess4Mask(mask_cfg)]
    # pipeline = [Preprocess4Mask(mask_cfg)]

    # Split the windows into a pretraining set and a held-out set
    # (training_rate is the fraction used for pretraining).
    data_train, label_train, data_test, label_test = prepare_pretrain_dataset(
        data, labels, training_rate, seed=train_cfg.seed)
    data_set_train = LIBERTDataset4Pretrain(data_train, pipeline=pipeline)
    data_set_test = LIBERTDataset4Pretrain(data_test, pipeline=pipeline)
    data_loader_train = DataLoader(data_set_train, shuffle=True, batch_size=train_cfg.batch_size)
    data_loader_test = DataLoader(data_set_test, shuffle=False, batch_size=train_cfg.batch_size)
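    # Each batch is expected to be a triple (mask_seqs, masked_pos, seqs):
    # the masked input windows, the indices of the masked frames, and the
    # ground-truth values at those frames (unpacked in the callbacks below).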

    model = LIMUBertModel4Pretrain(model_cfg)
    # Per-element MSE: reduction='none' keeps the full loss tensor so it can be
    # reduced (e.g. averaged over the masked positions) later.
    criterion = nn.MSELoss(reduction='none')
    optimizer = torch.optim.Adam(params=model.parameters(), lr=train_cfg.lr)
    device = get_device(args.gpu)
    trainer = train.Trainer(train_cfg, model, optimizer, args.save_path, device)
    def func_loss(model, batch):
        mask_seqs, masked_pos, seqs = batch
        seq_recon = model(mask_seqs, masked_pos)  # reconstruct the masked frames
        loss_lm = criterion(seq_recon, seqs)  # per-element masked-LM loss
        return loss_lm

    def func_forward(model, batch):
        mask_seqs, masked_pos, seqs = batch
        seq_recon = model(mask_seqs, masked_pos)
        return seq_recon, seqs

    def func_evaluate(seqs, predict_seqs):
        loss_lm = criterion(predict_seqs, seqs)
        return loss_lm.mean().cpu().numpy()
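
    # How Trainer.pretrain presumably wires these callbacks together (an
    # illustrative sketch only; the actual loop lives in train.Trainer):
    #
    #   for batch in data_loader_train:
    #       loss = func_loss(model, batch).mean()   # optimize mean reconstruction error
    #       optimizer.zero_grad(); loss.backward(); optimizer.step()
    #   for batch in data_loader_test:
    #       predict, true = func_forward(model, batch)
    #       score = func_evaluate(true, predict)    # scalar validation MSE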
    # Resume from an existing checkpoint if one was passed on the command line.
    model_file = getattr(args, 'pretrain_model', None)
    trainer.pretrain(func_loss, func_forward, func_evaluate,
                     data_loader_train, data_loader_test, model_file=model_file)


if __name__ == "__main__":
    mode = "base"
    args = handle_argv('pretrain_' + mode, 'pretrain.json', mode)
    training_rate = 0.8  # fraction of windows used for pretraining
    main(args, training_rate)
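
# Typical invocation, assuming the positional-argument layout defined by
# handle_argv in utils.py (model version, dataset name, dataset version);
# the exact arguments below are an assumption, not confirmed by this file:
#   python pretrain.py v1 uci 20_120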