ASE_model.py
import torch
import torch.nn as nn
import torch.nn.init as init
from transformers.models.bart.modeling_bart import BartClassificationHead


class ASEModel(nn.Module):
    def __init__(self, bart, tokenizer):
        super(ASEModel, self).__init__()
        # BART backbone
        self.model = bart
        self.tokenizer = tokenizer
        # uncertainty parameters: one learnable weight per task
        # (ranking + three generation tasks)
        self.loss_weights = nn.Parameter(torch.ones(4))
        # classification head provided by Hugging Face:
        # (input_dim, inner_dim, num_classes, dropout)
        self.classification_head = BartClassificationHead(
            768, 768, 1, 0.1
        )
        init.xavier_normal_(self.classification_head.dense.weight)
        init.xavier_normal_(self.classification_head.out_proj.weight)
        self.classifier = nn.Linear(768, 1)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.dropout = nn.Dropout(0.1)

    def generation_forward(self, encoder_input_ids, encode_attention_mask, labels):
        # obtain the generation loss;
        # BART automatically shifts the labels to build decoder_input_ids
        bart_inputs = {'input_ids': encoder_input_ids,
                       'attention_mask': encode_attention_mask,
                       'output_hidden_states': True,
                       'labels': labels}
        bart_outputs = self.model(**bart_inputs)
        loss = bart_outputs.loss
        return loss

    def weighted_loss(self, loss, index):
        # uncertainty-based weighting of the loss for task `index`
        return (loss / (self.loss_weights[index] * 2)) + (self.loss_weights[index] + 1).log()

    def hinge_loss(self, scores, margin, mask):
        # hinge ranking loss: column 0 holds the positive score,
        # the remaining columns hold the negative scores
        loss = torch.nn.functional.relu(
            margin - (torch.unsqueeze(scores[:, 0], -1) - scores[:, 1:]) * mask)
        return torch.mean(loss)

    def forward(self, batch_data, is_test=False):
        if is_test:
            # only the encoder is used at inference time
            encoder_input_ids_rank = batch_data["encoder_input_ids_rank"]
            encode_attention_mask_rank = batch_data["encode_attention_mask_rank"]
            eos_position = batch_data["eos_position"]
            bart_inputs = {'input_ids': encoder_input_ids_rank,
                           'attention_mask': encode_attention_mask_rank,
                           'output_hidden_states': True}
            bart_outputs = self.model.model(**bart_inputs)
            encoder_hidden = bart_outputs.encoder_last_hidden_state
            # the output at the [CLS] (EOS) position
            classification_head_token = (eos_position == 1)
            # go through the MLP
            eos_hidden = encoder_hidden[classification_head_token, :]
            y_pred = self.classification_head(eos_hidden).squeeze(1)
            return y_pred
        else:
            scores = []
            gen_losses = []
            loss_mask = None
            for batch in batch_data:
                encoder_input_ids_rank = batch["encoder_input_ids_rank"]
                encode_attention_mask_rank = batch["encode_attention_mask_rank"]
                encoder_input_ids_gen_fq = batch["encoder_input_ids_gen_fq"]
                encoder_input_ids_gen_cd = batch["encoder_input_ids_gen_cd"]
                encoder_input_ids_gen_sq = batch["encoder_input_ids_gen_sq"]
                eos_position = batch["eos_position"]
                next_q_labels = batch["next_q_labels"]
                click_doc_labels = batch["click_doc_labels"]
                previous_q_labels = batch["previous_q_labels"]
                simq_labels = batch["simq_labels"]
                loss_mask = batch["loss_mask"]
                # Ranking
                bart_inputs = {'input_ids': encoder_input_ids_rank,
                               'attention_mask': encode_attention_mask_rank,
                               'output_hidden_states': True}
                bart_outputs = self.model.model(**bart_inputs)
                encoder_hidden = bart_outputs.encoder_last_hidden_state
                # the output at the [CLS] (EOS) position
                classification_head_token = (eos_position == 1)
                eos_hidden = encoder_hidden[classification_head_token, :]
                y_pred = self.classification_head(eos_hidden)
                scores.append(y_pred)
                # Generation losses of the three auxiliary tasks
                gen_loss1 = self.generation_forward(encoder_input_ids_gen_fq, encode_attention_mask_rank, next_q_labels)
                gen_loss2 = self.generation_forward(encoder_input_ids_gen_cd, encode_attention_mask_rank, click_doc_labels)
                gen_loss3 = self.generation_forward(encoder_input_ids_gen_sq, encode_attention_mask_rank, simq_labels)
                gen_loss = self.weighted_loss(gen_loss1, 1) + self.weighted_loss(gen_loss2, 2) + self.weighted_loss(gen_loss3, 3)
                gen_losses.append(gen_loss)
            # Ranking loss over the collected scores
            batch_scores = torch.cat(scores, dim=-1)
            ranking_loss = self.hinge_loss(batch_scores, 1, loss_mask)
            w_ranking_loss = self.weighted_loss(ranking_loss, 0)
            return sum(gen_losses) / len(gen_losses), w_ranking_loss
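
A minimal inference sketch, not part of the original file: it assumes the backbone is facebook/bart-base (hidden size 768, matching the classification head above) and fabricates a tiny test batch whose keys follow what forward() reads in the is_test branch. The query strings and the way eos_position is built are illustrative assumptions, not the repository's data pipeline.

    # usage sketch (assumptions: facebook/bart-base backbone, hand-built batch)
    import torch
    from transformers import BartForConditionalGeneration, BartTokenizer

    bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = ASEModel(bart, tokenizer)
    model.eval()

    queries = ["best pizza in new york", "weather tomorrow"]  # hypothetical inputs
    enc = tokenizer(queries, padding=True, return_tensors="pt")

    # mark the final (EOS) token of each sequence as the classification position
    eos_position = torch.zeros_like(enc["input_ids"])
    lengths = enc["attention_mask"].sum(dim=1) - 1
    eos_position[torch.arange(len(queries)), lengths] = 1

    batch = {
        "encoder_input_ids_rank": enc["input_ids"],
        "encode_attention_mask_rank": enc["attention_mask"],
        "eos_position": eos_position,
    }
    with torch.no_grad():
        scores = model(batch, is_test=True)  # one ranking score per query
    print(scores.shape)  # torch.Size([2])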