train_chembl_multitask.py
from onnxruntime.quantization import quantize_dynamic
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.utils.data as D
import pytorch_lightning as pl
import tables as tb
from sklearn.metrics import (
    matthews_corrcoef,
    confusion_matrix,
    f1_score,
    roc_auc_score,
    accuracy_score,
)
from sklearn.model_selection import KFold
from collections import Counter
import json

CHEMBL_VERSION = 34
PATH = "."
DATA_FILE = f"mt_data_{CHEMBL_VERSION}.h5"
N_WORKERS = 6  # prefetches data in parallel to have batches ready for training
BATCH_SIZE = 32  # https://twitter.com/ylecun/status/989610208497360896
LR = 4  # Learning rate. Big value because of the way we are weighting the targets
FP_SIZE = 1024


# PyTorch Dataset that reads batches from a PyTables file
class ChEMBLDataset(D.Dataset):
    def __init__(self, file_path):
        self.file_path = file_path
        with tb.open_file(self.file_path, mode="r") as t_file:
            self.length = t_file.root.fps.shape[0]
            self.n_targets = t_file.root.labels.shape[1]

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        with tb.open_file(self.file_path, mode="r") as t_file:
            structure = t_file.root.fps[index]
            labels = t_file.root.labels[index]
        return structure, labels
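
# Note on the expected HDF5/PyTables layout (inferred from the nodes this script
# reads; exact shapes and dtypes are set in the separate data preparation step):
#   /fps                fingerprint matrix, one row of FP_SIZE bits per molecule
#   /labels             (n_molecules, n_targets) matrix; negative values mark
#                       unlabeled molecule/target pairs (see the >= 0 masks below)
#   /weights            one loss weight per target
#   /target_chembl_ids  target ChEMBL IDs, later used as ONNX output names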


class ChEMBLMultiTask(pl.LightningModule):
    """
    Architecture borrowed from: https://arxiv.org/abs/1502.02072
    """

    def __init__(self, n_tasks, weights=None):
        super().__init__()
        self.n_tasks = n_tasks
        self.fc1 = nn.Linear(FP_SIZE, 2000)
        self.fc2 = nn.Linear(2000, 100)
        self.dropout = nn.Dropout(0.25)
        self.test_step_outputs = []
        # add an independent output for each task in the output layer
        for n_m in range(n_tasks):
            self.add_module(f"y{n_m}o", nn.Linear(100, 1))
        if weights is not None:
            self.criterion = [
                nn.BCELoss(weight=w) for w in torch.tensor(weights).float()
            ]
        else:
            self.criterion = [nn.BCELoss() for _ in range(n_tasks)]

    def forward(self, x):
        h1 = self.dropout(F.relu(self.fc1(x)))
        h2 = F.relu(self.fc2(h1))
        out = [
            torch.sigmoid(getattr(self, f"y{n_m}o")(h2)) for n_m in range(self.n_tasks)
        ]
        return out

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=LR)
        return optimizer

    def training_step(self, batch, batch_idx):
        fps, labels = batch
        logits = self.forward(fps)
        loss = torch.tensor(0.0)
        for j, crit in enumerate(self.criterion):
            # mask keeping labeled molecules for each target
            mask = labels[:, j] >= 0.0
            if len(labels[:, j][mask]) != 0:
                # the loss is the sum of all targets' losses: if there are labeled
                # samples for this target in the batch, add its loss
                loss += crit(logits[j][mask], labels[:, j][mask].view(-1, 1))
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        fps, labels = batch
        outs = self.forward(fps)
        y = []
        y_hat = []
        y_hat_proba = []
        for j, out in enumerate(outs):
            mask = labels[:, j] >= 0.0
            y_pred = torch.where(out[mask] > 0.5, torch.ones(1), torch.zeros(1)).view(
                1, -1
            )
            if y_pred.shape[1] > 0:
                for l in labels[:, j][mask].long().tolist():
                    y.append(l)
                for p in y_pred.view(-1, 1).tolist():
                    y_hat.append(int(p[0]))
                for p in out[mask].view(-1, 1).tolist():
                    y_hat_proba.append(float(p[0]))
        tn, fp, fn, tp = confusion_matrix(y, y_hat).ravel()
        sens = tp / (tp + fn)
        spec = tn / (tn + fp)
        prec = tp / (tp + fp)
        f1 = f1_score(y, y_hat)
        acc = accuracy_score(y, y_hat)
        mcc = matthews_corrcoef(y, y_hat)
        auc = roc_auc_score(y, y_hat_proba)
        metrics = {
            "test_acc": torch.tensor(acc),
            "test_sens": torch.tensor(sens),
            "test_spec": torch.tensor(spec),
            "test_prec": torch.tensor(prec),
            "test_f1": torch.tensor(f1),
            "test_mcc": torch.tensor(mcc),
            "test_auc": torch.tensor(auc),
        }
        self.log_dict(metrics)
        self.test_step_outputs.append(metrics)
        return metrics

    def on_test_epoch_end(self):
        # average the metrics accumulated over all test batches
        sums = Counter()
        counters = Counter()
        for itemset in self.test_step_outputs:
            sums.update(itemset)
            counters.update(itemset.keys())
        metrics = {x: float(sums[x]) / counters[x] for x in sums.keys()}
        return metrics


if __name__ == "__main__":
    # each task loss is weighted inversely proportional to its number of datapoints,
    # borrowed from:
    # http://www.datascienceassn.org/sites/default/files/Deep%20Learning%20as%20an%20Opportunity%20in%20Virtual%20Screening.pdf
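    # A minimal sketch of what such an inverse-frequency weighting could look like
    # (an assumption about the general idea; the exact formula used to build the
    # stored weights lives in the data preparation code, not in this script):
    #   n_labeled = (label_matrix >= 0).sum(axis=0)   # datapoints per target
    #   weights = n_labeled.max() / n_labeled         # rarer targets weigh more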
    with tb.open_file(f"{PATH}/{DATA_FILE}", mode="r") as t_file:
        weights = t_file.root.weights[:]

    dataset = ChEMBLDataset(f"{PATH}/{DATA_FILE}")
    indices = list(range(len(dataset)))

    metrics = []
    kfold = KFold(n_splits=5, shuffle=True)
    for train_idx, test_idx in kfold.split(indices):
        train_sampler = D.sampler.SubsetRandomSampler(train_idx)
        test_sampler = D.sampler.SubsetRandomSampler(test_idx)
        train_loader = DataLoader(
            dataset, batch_size=BATCH_SIZE, num_workers=N_WORKERS, sampler=train_sampler
        )
        test_loader = DataLoader(
            dataset, batch_size=1000, num_workers=N_WORKERS, sampler=test_sampler
        )

        model = ChEMBLMultiTask(len(weights), weights)
        # this shallow model trains quicker on CPU
        trainer = pl.Trainer(max_epochs=3, accelerator="cpu")
        trainer.fit(model, train_dataloaders=train_loader)
        mm = trainer.test(dataloaders=test_loader)
        metrics.append(mm)

    # average the metrics across folds
    metrics = [item for sublist in metrics for item in sublist]
    sums = Counter()
    counters = Counter()
    for itemset in metrics:
        sums.update(itemset)
        counters.update(itemset.keys())
    performance = {x: float(sums[x]) / counters[x] for x in sums.keys()}

    with open(f"performance_{CHEMBL_VERSION}.json", "w") as f:
        json.dump(performance, f)

    # Train the model with the whole dataset and export to ONNX format
    final_train_sampler = D.sampler.SubsetRandomSampler(indices)
    final_train_loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        num_workers=N_WORKERS,
        sampler=final_train_sampler,
    )

    model = ChEMBLMultiTask(len(weights), weights)
    # this shallow model trains quicker on CPU
    trainer = pl.Trainer(max_epochs=3, accelerator="cpu")
    trainer.fit(model, train_dataloaders=final_train_loader)

    with tb.open_file(f"mt_data_{CHEMBL_VERSION}.h5", mode="r") as t_file:
        output_names = t_file.root.target_chembl_ids[:]

    model.to_onnx(
        f"./chembl_{CHEMBL_VERSION}_multitask.onnx",
        torch.ones(FP_SIZE),
        export_params=True,
        input_names=["input"],
        output_names=output_names,
    )

    # dynamic 8-bit quantization of the exported model
    model_fp32 = f"./chembl_{CHEMBL_VERSION}_multitask.onnx"
    model_quant = f"./chembl_{CHEMBL_VERSION}_multitask_q8.onnx"
    quantized_model = quantize_dynamic(model_fp32, model_quant)
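
# A minimal usage sketch (not part of the original script) for sanity-checking the
# exported/quantized model with onnxruntime. The input name "input" matches the
# export above; the exact fingerprint type is defined in the data preparation step,
# so the random FP_SIZE-bit vector below is only a stand-in:
#
#   import numpy as np
#   import onnxruntime as ort
#
#   session = ort.InferenceSession(f"./chembl_{CHEMBL_VERSION}_multitask_q8.onnx")
#   fp = np.random.randint(0, 2, size=FP_SIZE).astype(np.float32)  # dummy fingerprint
#   preds = session.run(None, {"input": fp})
#   for output, pred in zip(session.get_outputs(), preds):
#       print(output.name, float(pred[0]))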