# process_utils.py
import os
import json
import json5
import logging
from dataclasses import dataclass
from transformers import AutoTokenizer
from tqdm import tqdm
from typing import List, Union
from itertools import chain
# Disable tokenizer parallelism to avoid fork-related warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@dataclass
class XPGTItem:
    """A single evaluation item: task, instruction, input, and the reference/hypothesis outputs."""
    task: str
    instruction: str
    input: str
    ref_output: Union[str, List[str]]
    hypo_output: str
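# Illustrative sketch (hypothetical values): constructing an item for scoring.
#
#   item = XPGTItem(
#       task="summarization",
#       instruction="Summarize the article.",
#       input="Long source article ...",
#       ref_output="A reference summary ...",
#       hypo_output="The model's summary ...",
#   )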
# Message map functions
def default_msg_map(cur_message: dict, messages: List[dict]):
    """ Map the current message and the old messages to the new messages for query
    Args:
        cur_message (dict): the message to add, with "role" and "content" keys
        messages (List[dict]): the messages list before this query
    Returns:
        new_messages (List[dict]): the messages list with cur_message appended
    """
new_messages = messages + [{
"role": cur_message['role'],
"content": cur_message['content']}
]
return new_messages
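# Illustrative sketch (hypothetical messages): default_msg_map appends without
# mutating the original list, so it can be chained turn by turn.
#
#   msgs = default_msg_map({"role": "system", "content": "You are a grader."}, [])
#   msgs = default_msg_map({"role": "user", "content": "Rate this output."}, msgs)
#   # msgs now holds both messages, in order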
# Postprocess functions
def default_postprocess(content: str):
return content
def json_postprocess(content: str):
    """Extract and parse the JSON object embedded in a response string."""
    # Take the substring between the first "{" and the last "}"
    json_content = content[content.find("{"):content.rfind("}") + 1]
    try:
        return json.loads(json_content)
    except json.decoder.JSONDecodeError:
        try:
            # json5 tolerates lax syntax (trailing commas, single quotes, ...)
            return json5.loads(json_content)
        except Exception:
            # Give up and return the raw content unparsed
            return content
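# Illustrative sketch (hypothetical reply text): extracting the embedded JSON
# object from a chatty model response.
#
#   reply = 'Sure! Here is the result: {"errors": {}}'
#   # json_postprocess(reply) == {"errors": {}}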
# Lazily-initialized module-level tokenizer (see truncate_texts)
tokenizer = None
def truncate_texts(texts: Union[List[str], List[List[str]]], max_length: int = None):
"""
Truncate the texts to the max length.
Args:
texts (List[str] or List[List[str]]): The list of texts.
max_length (int): The max length.
Returns:
List[str]: The truncated texts.
"""
if max_length is None:
return texts
    # Skip truncation when there is nothing to truncate:
    # every sublist is empty, or every text is None
    flat_texts = list(chain(*texts)) if isinstance(texts[0], list) else texts
    if len(flat_texts) == 0 or all(x is None for x in flat_texts):
        logging.warning("All texts are None, skip truncating")
        return texts
    # Use the Llama-2 tokenizer by default (lazily loaded; the checkpoint is
    # gated, so Hugging Face authentication is required)
    global tokenizer
    disable_tqdm = len(texts) < 1000
    logging.warning(f"Truncating texts to max length {max_length}")
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Llama-2-7b-hf", use_auth_token=True)
    # Tokenize each text (or each text within a sublist)
token_ids = []
for text in tqdm(texts, desc="Truncating texts (tokenizing)", disable=disable_tqdm):
if isinstance(text, list):
token_ids.append(
[tokenizer.encode(x, add_special_tokens=False) for x in text])
else:
token_ids.append(tokenizer.encode(text, add_special_tokens=False))
    # Truncate over-length token sequences and decode back to text
truncated_texts = []
    for i, _token_ids in tqdm(enumerate(token_ids), total=len(token_ids),
                              desc="Truncating texts (truncating)", disable=disable_tqdm):
        if len(_token_ids) and isinstance(_token_ids[0], list):
truncated_texts.append([])
for _token_id in _token_ids:
if len(_token_id) > max_length:
truncated_text = tokenizer.decode(
_token_id[:max_length], skip_special_tokens=True)
truncated_text = truncated_text + " ..."
else:
truncated_text = tokenizer.decode(
_token_id, skip_special_tokens=True)
truncated_texts[i].append(truncated_text)
else:
if len(_token_ids) > max_length:
truncated_text = tokenizer.decode(
_token_ids[:max_length], skip_special_tokens=True)
truncated_text = truncated_text + " ..."
else:
truncated_text = tokenizer.decode(
_token_ids, skip_special_tokens=True)
truncated_texts.append(truncated_text)
return truncated_texts
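# Illustrative sketch (hypothetical strings): texts longer than max_length
# Llama-2 tokens are cut and suffixed with " ...".
#
#   out = truncate_texts(["hi", "word " * 5000], max_length=100)
#   # out[0] == "hi"; out[1] ends with " ..."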
def truncate_items(items: List[XPGTItem], max_lengths: dict):
    """
    Truncate the texts in the items to their per-field max lengths.
    Args:
        items (List[XPGTItem]): The list of items.
        max_lengths (dict): The max token length per field; keys may include
            "input", "instruction", "ref_output", and "hypo_output".
    Returns:
        List[XPGTItem]: The truncated items (modified in place).
    """
truncated_inputs = truncate_texts(
[item.input for item in items], max_lengths.get("input", None))
truncated_insts = truncate_texts(
[item.instruction for item in items], max_lengths.get("instruction", None))
truncated_ref_outputs = truncate_texts(
[item.ref_output for item in items], max_lengths.get("ref_output", None))
truncated_hypo_outputs = truncate_texts(
[item.hypo_output for item in items], max_lengths.get("hypo_output", None))
for i, item in enumerate(items):
item.instruction = truncated_insts[i]
item.input = truncated_inputs[i]
item.ref_output = truncated_ref_outputs[i]
item.hypo_output = truncated_hypo_outputs[i]
return items
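# Illustrative sketch (hypothetical limits): truncate only the input and
# hypothesis fields; fields without a limit are left untouched.
#
#   items = truncate_items(items, {"input": 1024, "hypo_output": 512})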
def get_query_messages(messages: List[dict], queried_messages: List[dict]):
"""
Args:
messages (List[dict]): the messages list to add for query
queried_messages (List[dict]): the messages list already queried, which contains the query responses also,
Returns:
new_messages (List[dict]): the new messages list to query
postprocess (function): the postprocess function for the query response
"""
if len(queried_messages) == 0:
last_prompt_idx = -1
else:
        assert len(queried_messages) >= 2, \
            "queried_messages should have at least 2 messages, i.e., the user (system) message and the response"
last_prompt = queried_messages[-2]['content']
prompt_texts = [x['content'] for x in messages]
last_prompt_idx = prompt_texts.index(last_prompt)
if last_prompt_idx == len(messages) - 1:
return None
new_messages = queried_messages.copy()
for idx in range(last_prompt_idx + 1, len(messages)):
new_messages = messages[idx]["map_func"](messages[idx], new_messages)
if messages[idx]["do_query"]:
break
return new_messages, messages[idx]["postprocess"]
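# Illustrative sketch (hypothetical entry): each message dict carries its map
# function, query flag, and postprocess function; building stops at the first
# do_query=True message after the resume point.
#
#   messages = [{"role": "user", "content": "...", "map_func": default_msg_map,
#                "do_query": True, "postprocess": json_postprocess}]
#   result = get_query_messages(messages, queried_messages=[])
#   # result == (new_messages, json_postprocess), or None when all queried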
def get_xgptscore_from_json(json_content: dict):
    """
    Args:
        json_content (dict): the json content
    Returns:
        xgptscore (float): the xgptscore, i.e. the negated sum of the score
            reductions over all errors; None if the content could not be parsed
    """
if isinstance(json_content, str):
return None
try:
xgptscore = 0
for error in json_content['errors'].values():
if error['score_reduction'] == "N/A":
continue
xgptscore -= error['score_reduction']
return xgptscore
except Exception:
return None
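# Illustrative sketch (hypothetical errors): reductions of 2 and 1 yield an
# xgptscore of -3.
#
#   content = {"errors": {"e1": {"score_reduction": 2},
#                         "e2": {"score_reduction": 1}}}
#   # get_xgptscore_from_json(content) == -3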
def get_xgptscore_from_json_star(json_content: dict):
    """
    Args:
        json_content (dict): the json content, mapping aspect names to dicts
            that carry a "Score" field
    Returns:
        res (dict): the per-aspect scores ("xgptscore_<aspect>") plus their
            sum under "xgptscore"
    """
xgptscore = 0
res = {}
for aspect_key, aspect in json_content.items():
if isinstance(aspect, dict):
score = aspect['Score']
try:
score = float(score)
except Exception:
score = 0
xgptscore += score
res["xgptscore_" + aspect_key] = score
res["xgptscore"] = xgptscore
return res
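# Illustrative sketch (hypothetical aspects): per-aspect "Score" fields are
# summed; non-numeric scores count as 0.
#
#   content = {"fluency": {"Score": "4"}, "relevance": {"Score": 3}}
#   # -> {"xgptscore_fluency": 4.0, "xgptscore_relevance": 3.0, "xgptscore": 7.0}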
def get_xgptscore_from_json_per_aspect(json_content: dict):
    """
    Args:
        json_content (dict): the json content
    Returns:
        res (dict): the per-aspect score reductions ("xgptscore_<aspect>",
            negated) plus their sum under "xgptscore"; None if the content
            is not a dict
    """
if not isinstance(json_content, dict):
return None
xgptscore = 0
res = {}
for error in json_content['errors'].values():
if error['error_aspect'] is not None:
if ("xgptscore_" + error['error_aspect'] not in res):
res["xgptscore_" + error['error_aspect']] = 0
res["xgptscore_" + error['error_aspect']] -= error['score_reduction']
xgptscore -= error['score_reduction']
res["xgptscore"] = xgptscore
return res
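# Illustrative sketch (hypothetical errors): reductions are grouped by each
# error's aspect and negated.
#
#   content = {"errors": {"e1": {"error_aspect": "fluency", "score_reduction": 2}}}
#   # get_xgptscore_from_json_per_aspect(content) == {"xgptscore_fluency": -2, "xgptscore": -2}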