mlm.py
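"""Fine-tune a masked language model with Hugging Face Transformers.

Loads train/validation/test files with `datasets.load_dataset`, tokenizes them
with a project-specific preprocessing function from `mlm_util`, and trains and
evaluates with `transformers.Trainer`, reporting perplexity alongside the
project's custom metrics.
"""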
import math
import logging

from datasets import load_dataset
from transformers import (
    MODEL_FOR_MASKED_LM_MAPPING,
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    is_torch_tpu_available,
    set_seed,
)

from mlm_util import ModelArguments, DataTrainingArguments, ExtendedTrainingArguments
from mlm_util import get_metric_function, get_preprocess_function
from mlm_util import get_token2id_mapping, MyDataCollator
from util import is_main_process, init_logger, init_output_dir
from util import format_args, preprocess_logits_for_metrics

logger = logging.getLogger(__name__)

MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ExtendedTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Set up logging.
    log_level = training_args.get_process_log_level()
    if is_main_process(training_args.local_rank):
        init_output_dir(training_args)
    with training_args.main_process_first(desc="getting logger"):
        logger = init_logger(training_args, log_level)

    # Log a short summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
        f"n_gpu: {training_args.n_gpu}, distributed training: {bool(training_args.local_rank != -1)}, "
        f"16-bits training: {training_args.fp16}"
    )

    # Log the full argument sets (on the main process only):
    if is_main_process(training_args.local_rank):
        logger.info(format_args(training_args))
        logger.info(format_args(data_args))
        logger.info(format_args(model_args))

    # Set seed before initializing the model.
    set_seed(training_args.seed)

    # Collect the local data files and load them with the matching builder.
    data_files = {}
    if data_args.train_file is not None:
        data_files["train"] = data_args.train_file
    if data_args.validation_file is not None:
        data_files["validation"] = data_args.validation_file
    if data_args.test_file is not None:
        data_files["test"] = data_args.test_file
    # Infer the dataset builder (e.g. "csv" or "json") from the first
    # provided file rather than assuming train_file is always set.
    extension = next(iter(data_files.values())).split(".")[-1]
    raw_datasets = load_dataset(
        extension,
        data_files=data_files,
        cache_dir=model_args.cache_dir,
        use_auth_token=True if model_args.use_auth_token else None,
    )
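    # NOTE: load_dataset treats the extension as a builder name, so e.g. a
    # train_file of "data/train.json" selects the "json" builder and
    # "data/train.csv" the "csv" builder; other extensions would need the
    # builder passed explicitly.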

    # Load the pretrained config, tokenizer, and masked-LM model.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model = AutoModelForMaskedLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Register the task-specific "[none]" token and grow the embedding matrix.
    tokenizer.add_tokens("[none]")
    model.resize_token_embeddings(len(tokenizer))
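    # NOTE: the resize must come after add_tokens; otherwise the id assigned
    # to "[none]" would index past the end of the embedding matrix.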

    # Preprocessing the datasets: first we tokenize all the texts.
    column_names = raw_datasets["train"].column_names
    token2id_mapping = get_token2id_mapping(tokenizer)
    tokenize_function = get_preprocess_function(data_args, tokenizer, logger, token2id_mapping)
    with training_args.main_process_first(desc="dataset map tokenization"):
        tokenized_datasets = raw_datasets.map(
            tokenize_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset line_by_line",
        )

    # Build the (optionally capped) train/validation/test splits.
    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = tokenized_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        valid_dataset = tokenized_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(valid_dataset), data_args.max_eval_samples)
            valid_dataset = valid_dataset.select(range(max_eval_samples))
        # The test split is evaluated alongside validation below and is capped
        # by the same max_eval_samples limit.
        test_dataset = tokenized_datasets["test"]
        if data_args.max_eval_samples is not None:
            max_test_samples = min(len(test_dataset), data_args.max_eval_samples)
            test_dataset = test_dataset.select(range(max_test_samples))

    compute_metrics = get_metric_function()

    # Data collator
    data_collator = MyDataCollator(tokenizer=tokenizer)

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=valid_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics
        if training_args.do_eval and not is_torch_tpu_available()
        else None,
    )
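    # NOTE: as in the upstream HF example scripts, metric computation is
    # skipped on TPU; preprocess_logits_for_metrics (from util, implementation
    # not shown here) presumably reduces the logits (e.g. to predicted ids)
    # before Trainer accumulates them across eval steps, keeping evaluation
    # memory bounded.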

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.save_model()  # Saves the tokenizer too for easy upload
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation: run on both the validation and test splits.
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        for eval_name, eval_dataset in zip(["valid", "test"], [valid_dataset, test_dataset]):
            metrics = trainer.evaluate(eval_dataset=eval_dataset)
            metrics = {key.replace("eval", eval_name): score for key, score in metrics.items()}
            max_eval_samples = (
                data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
            )
            metrics[f"{eval_name}_samples"] = min(max_eval_samples, len(eval_dataset))
            try:
                perplexity = math.exp(metrics[f"{eval_name}_loss"])
            except OverflowError:
                perplexity = float("inf")
            # Prefix with the split name so the valid and test values stay
            # distinct when Trainer combines metrics into all_results.json.
            metrics[f"{eval_name}_perplexity"] = perplexity
            trainer.log_metrics(eval_name, metrics)
            trainer.save_metrics(eval_name, metrics)


if __name__ == "__main__":
    main()
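
# Example invocation (a sketch: the TrainingArguments flags below are standard,
# while the model/data flags assume the field names visible above on
# ModelArguments and DataTrainingArguments; adjust to their definitions in
# mlm_util):
#
#   python mlm.py \
#       --model_name_or_path bert-base-uncased \
#       --train_file data/train.json \
#       --validation_file data/valid.json \
#       --test_file data/test.json \
#       --output_dir outputs/mlm \
#       --do_train --do_eval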