-
Notifications
You must be signed in to change notification settings - Fork 0
/
write_tfrecord.py
122 lines (105 loc) · 5.42 KB
/
write_tfrecord.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import tensorflow as tf
import os
# Parsing schema for the source TFRecord file: tells the decoder which tf.io
# feature spec to use for each feature name. FixedLenFeature([]) entries are
# single scalar values; VarLenFeature entries are variable-length lists, which
# the parser returns as SparseTensors (densified in read_and_decode).
feature_description = dict(
    id=tf.io.FixedLenFeature([], tf.string),
    tag_id=tf.io.VarLenFeature(tf.int64),
    category_id=tf.io.FixedLenFeature([], tf.int64),
    title=tf.io.FixedLenFeature([], tf.string),
    asr_text=tf.io.FixedLenFeature([], tf.string),
    frame_feature=tf.io.VarLenFeature(tf.string),
)
def read_and_decode(example_string):
    """Decode one serialized example read from the source TFRecord file.

    The example is parsed with the module-level ``feature_description`` schema.

    Args:
        example_string: one serialized tf.train.Example (a scalar string tensor,
            as yielded by tf.data.TFRecordDataset).

    Returns:
        Tuple ``(id, tag_id, category_id, frame_feature, title, asr_text)``.
        ``id``, ``title`` and ``asr_text`` are bytes. ``category_id`` is a numpy
        int64 scalar. ``tag_id`` and ``frame_feature`` are dense numpy arrays
        (int64 and bytes objects respectively).
    """
    parsed = tf.io.parse_single_example(example_string, feature_description)

    def _sparse_to_numpy(key):
        # VarLenFeature values come back sparse; densify before converting.
        return tf.sparse.to_dense(parsed[key]).numpy()

    video_id = parsed['id'].numpy()
    return (
        video_id,
        _sparse_to_numpy('tag_id'),
        parsed['category_id'].numpy(),
        _sparse_to_numpy('frame_feature'),
        parsed['title'].numpy(),
        parsed['asr_text'].numpy(),
    )
import glob
def get_all_data(path):
    """Load every example matched by a glob pattern into a dict keyed by id.

    Args:
        path: glob pattern for the TFRecord file(s) to read,
            e.g. 'data/pairwise/pairwise.tfrecords'.

    Returns:
        dict mapping the utf-8-decoded id string to a dict with keys
        'tag_id', 'category_id', 'frame_feature', 'title', 'asr_text'
        (values exactly as returned by read_and_decode). If an id occurs more
        than once, the last record read wins.

    Raises:
        FileNotFoundError: if ``path`` matches no files. Without this check, an
            empty dict would be returned silently and the failure would only
            surface later as a KeyError on lookup.
    """
    # glob.glob's result order depends on the filesystem; sort it so the read
    # order (and therefore which duplicate id wins) is deterministic.
    filenames = sorted(glob.glob(path))
    print(filenames)
    if not filenames:
        raise FileNotFoundError('no TFRecord files match %r' % path)

    dataset = tf.data.TFRecordDataset(filenames)
    datas = {}
    for record in dataset:
        video_id, tag_id, category_id, frame_feature, title, asr_text = read_and_decode(record)
        datas[video_id.decode()] = {
            'tag_id': tag_id,
            'category_id': category_id,
            'frame_feature': frame_feature,
            'title': title,
            'asr_text': asr_text,
        }
    return datas
datas = get_all_data('data/pairwise/pairwise.tfrecords')


def _read_pair_labels(path):
    """Read a tab-separated label file whose lines are ``id_1<TAB>id_2<TAB>sim``.

    Args:
        path: path of the TSV label file.

    Returns:
        list of ``[id_1, id_2, sim]`` with ids as str and sim as float, in file
        order. Blank lines are skipped.

    Raises:
        ValueError: if a non-blank line does not have exactly three
            tab-separated fields, or if sim is not a number.
    """
    pairs = []
    # The with-block guarantees the file is closed; the original left it open.
    with open(path, encoding='utf-8') as label_file:
        for line in label_file:
            line = line.strip()
            if not line:  # tolerate blank / trailing lines instead of crashing
                continue
            id_1, id_2, sim = line.split('\t')
            pairs.append([id_1, id_2, float(sim)])
    return pairs


label_path = 'data/pairwise/label.tsv'
all_pair_data = _read_pair_labels(label_path)
# A supplementary label file can be loaded the same way if needed:
# all_pair_data_sup = _read_pair_labels('data/pairwise/label_sup_sam.tsv')
# Shuffle the labelled pairs once with a fixed seed, then build NUM_FOLDS
# cross-validation folds. Fold k uses pairs [k*FOLD_SIZE, (k+1)*FOLD_SIZE) of
# the shuffled list for validation and all remaining pairs for training.
# If there are fewer than NUM_FOLDS*FOLD_SIZE pairs, the last folds get fewer
# (possibly zero) validation pairs. Pairs past that range are only ever used
# for training.
import random

from tqdm import tqdm

NUM_FOLDS = 11
FOLD_SIZE = 6000

random.seed(42)
random.shuffle(all_pair_data)

# Output sub-directory per fold: {0: '0-5999val', 1: '6000-11999val', ...}.
save_path = {
    fold: '%d-%dval' % (fold * FOLD_SIZE, (fold + 1) * FOLD_SIZE - 1)
    for fold in range(NUM_FOLDS)
}


def _video_features(video_id, record, suffix):
    """Build the tf.train.Feature entries for one side of a pair.

    Args:
        video_id: id string of the video.
        record: its entry in ``datas`` (see get_all_data).
        suffix: appended to every feature name, '_1' or '_2'.
    """
    return {
        'id' + suffix: tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[video_id.encode()])),
        'tag_id' + suffix: tf.train.Feature(
            int64_list=tf.train.Int64List(value=list(record['tag_id']))),
        'frame_feature' + suffix: tf.train.Feature(
            bytes_list=tf.train.BytesList(value=record['frame_feature'].tolist())),
        'category_id' + suffix: tf.train.Feature(
            int64_list=tf.train.Int64List(value=[record['category_id']])),
        'title' + suffix: tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[record['title']])),
        'asr_text' + suffix: tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[record['asr_text']])),
    }


def write_tfrecord(pair_datas, split, fold=None):
    """Serialize labelled pairs to ``data/pairwise/<save_path[fold]>/<split>.tfrecord``.

    Args:
        pair_datas: iterable of ``[id_1, id_2, sim]`` (str, str, float). Both
            ids must be keys of the module-level ``datas`` dict.
        split: output file stem, e.g. 'train' or 'val'.
        fold: fold index into ``save_path``. When omitted, the module-level loop
            variable ``i`` is used, matching the original two-argument call.

    Raises:
        KeyError: if an id in ``pair_datas`` is missing from ``datas``.
    """
    if fold is None:
        fold = i
    out_dir = os.path.join('data/pairwise', save_path[fold])
    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs(out_dir, exist_ok=True)
    write_path = os.path.join(out_dir, split + '.tfrecord')
    # The context manager closes the writer even if a pair raises mid-way.
    with tf.io.TFRecordWriter(write_path) as writer:
        for id_1, id_2, sim in tqdm(pair_datas):
            feature = {
                **_video_features(id_1, datas[id_1], '_1'),
                **_video_features(id_2, datas[id_2], '_2'),
                'sim': tf.train.Feature(float_list=tf.train.FloatList(value=[sim])),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())
    print('write %d th fold.' % fold)


for i in range(NUM_FOLDS):
    start, end = i * FOLD_SIZE, (i + 1) * FOLD_SIZE
    val_pair_data = all_pair_data[start:end]
    train_pair_data = all_pair_data[:start] + all_pair_data[end:]
    # Same RNG call sequence as before (one shuffle per fold), so the fold
    # contents are reproducible with seed 42.
    random.shuffle(train_pair_data)
    write_tfrecord(val_pair_data, 'val', i)
    write_tfrecord(train_pair_data, 'train', i)