-
Notifications
You must be signed in to change notification settings - Fork 0
/
gg.py
executable file
·220 lines (196 loc) · 6.94 KB
/
gg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
#!/usr/bin/env python3
import glob
import gzip
import os
import pathlib
import shutil
import struct
import subprocess
import json
import sys
from os.path import splitext, basename, exists, dirname
from collections import OrderedDict
# Directory containing this script, with symlinks resolved. Problem models are
# looked up under repo_path/'problemsF', and the submission/ and rank/ outputs
# are written beneath it.
repo_path = pathlib.Path(__file__).resolve().parent
def _read_model_r(prob_src_path, prob_tgt_path):
    """Return the first byte of the source model (else target model) as an int.

    The value is stored under the record key ``'r'`` (presumably the model
    resolution - confirm). Returns 0 when neither model file exists, instead of
    the former ``open(None)`` TypeError.
    """
    model_path = prob_src_path or prob_tgt_path
    if model_path is None:
        return 0
    with open(model_path, 'rb') as f:
        return int.from_bytes(f.read(1), 'little')


def _parse_validate(validate_path):
    """Parse a ``.validate`` report into ``(valid, step, cost)``.

    ``valid`` becomes 1 on a line starting with 'Success' and 0 on one starting
    with 'Failure' (the last such line wins); it stays None if neither appears.
    ``step`` and ``cost`` are the last whitespace-separated token of the 'Time'
    and 'Energy' lines, defaulting to 0. A missing file yields the defaults.
    """
    valid, step, cost = None, 0, 0
    if not exists(validate_path):
        return valid, step, cost
    with open(validate_path, 'r') as f:
        for line in f:
            if line.startswith('Failure'):
                valid = 0
            elif line.startswith('Success'):
                valid = 1
            elif line.startswith('Time'):
                # split() (not split(' ')) tolerates tabs and trailing spaces.
                step = int(line.split()[-1])
            elif line.startswith('Energy'):
                cost = int(line.split()[-1])
    return valid, step, cost


def _parse_javalidate(javalidate_path):
    """Return the energy from a JSON ``.javalidate`` report.

    Returns None if the file is absent. Returns 0 if the report is malformed or
    its ``result`` is not 'success'.
    """
    if not exists(javalidate_path):
        return None
    with open(javalidate_path, 'r') as f:
        try:
            report = json.load(f)
            if report['result'] == 'success':
                return report['energy']
            return 0
        # Narrowed from a bare ``except:``. It covers bad JSON or encoding
        # (ValueError), a missing key (KeyError) and a non-dict top level
        # (TypeError), without swallowing KeyboardInterrupt and the like.
        except (ValueError, KeyError, TypeError):
            return 0


def _read_sc6_cost(sc6_path):
    """Return the integer stored in a ``.nbt.sc6`` file, or 0 if absent/non-numeric."""
    if not exists(sc6_path):
        return 0
    with open(sc6_path, 'r') as f:
        text = f.read().strip()
    # isdecimal() accepts exactly the characters int() can parse; isdigit() also
    # accepts e.g. superscripts, which made int() raise.
    return int(text) if text.isdecimal() else 0


def collect_nbts(out_path):
    """Scan *out_path* recursively for ``*.nbt.gz`` traces and describe each one.

    For a trace ``<dir>/<prob_id>.<...>.nbt.gz``, sibling files sharing the
    prefix ``<dir>/<prob_id>`` are read: ``.validate``, ``.javalidate`` and
    ``.nbt.sc6``. The problem's models are looked up as
    ``repo_path/problemsF/<prob_id>_src.mdl`` and ``..._tgt.mdl``.

    :param out_path: root directory to search.
    :returns: a list of dicts with keys ``path``, ``step``, ``prefix``,
        ``prob_id``, ``ai_name`` (name of the directory holding the trace),
        ``prob_src_path``/``prob_tgt_path`` (None when missing),
        ``validate_path``, ``javalidate_path``, ``r``, ``cost``, ``sc6_cost``,
        ``valid`` (1/0/None) and ``javalid`` (energy, 0 on failure, None when
        no report exists).
    """
    nbts = []
    for path in glob.glob(out_path + '/**/*.nbt.gz', recursive=True):
        prob_id = basename(path).split('.')[0]
        # Cut at the first dot of the *file name* only. Splitting the whole path
        # broke on roots like './out' or directories containing dots.
        prefix = os.path.join(dirname(path), prob_id)
        ai_name = basename(dirname(path))
        prob_src_path = str(repo_path / 'problemsF' / prob_id) + '_src.mdl'
        prob_tgt_path = str(repo_path / 'problemsF' / prob_id) + '_tgt.mdl'
        if not exists(prob_src_path):
            prob_src_path = None
        if not exists(prob_tgt_path):
            prob_tgt_path = None
        validate_path = prefix + '.validate'
        javalidate_path = prefix + '.javalidate'
        valid, step, cost = _parse_validate(validate_path)
        nbts.append({
            "path": path,
            "step": step,
            "prefix": prefix,
            "prob_id": prob_id,
            "ai_name": ai_name,
            "prob_src_path": prob_src_path,
            "prob_tgt_path": prob_tgt_path,
            "validate_path": validate_path,
            "javalidate_path": javalidate_path,
            "r": _read_model_r(prob_src_path, prob_tgt_path),
            "cost": cost,
            "sc6_cost": _read_sc6_cost(prefix + '.nbt.sc6'),
            "valid": valid,
            "javalid": _parse_javalidate(javalidate_path),
        })
    return nbts
def by_prob_map(nbts):
    """Group records by their ``'prob_id'`` value.

    :param nbts: iterable of dicts that each carry a ``'prob_id'`` key.
    :returns: a new dict mapping each prob_id to a new list of its records.
        Keys appear in first-seen order, and records keep their input order.
    """
    grouped = {}
    for record in nbts:
        grouped.setdefault(record['prob_id'], []).append(record)
    return grouped
def gen_rank_tsv(nbts, output_path):
    """Write one ranking TSV per problem into *output_path*.

    Each file is ``<output_path>/<prob_id>.tsv``, produced in sorted prob_id
    order. It starts with a header row. The data rows are the records whose
    ``valid`` flag is truthy, ordered by ascending ``cost`` (ties keep their
    input order) and numbered from 1.

    :param nbts: records as produced by ``collect_nbts``.
    :param output_path: destination directory; created if missing.
    """
    os.makedirs(output_path, exist_ok=True)
    columns = ["rank", "ai_name", "cost", "valid", "r", "path"]
    grouped = by_prob_map(nbts)
    for prob_id in sorted(grouped):
        tsv_path = output_path + '/' + prob_id + '.tsv'
        with open(tsv_path, 'w') as out:
            out.write('\t'.join(columns) + '\n')
            by_cost = sorted(grouped[prob_id], key=lambda rec: rec['cost'])
            validated = (rec for rec in by_cost if rec["valid"])
            for rank, rec in enumerate(validated, start=1):
                fields = (rank, rec["ai_name"], rec["cost"], rec["valid"], rec["r"], rec["path"])
                out.write('\t'.join(str(v) for v in fields) + '\n')
def find_bests(nbts):
    """Return ``{prob_id: record}`` holding the cheapest valid record per problem.

    Each problem's records are scanned in ascending ``cost``. Every record
    inspected is printed, up to and including the first one with a truthy
    ``valid`` flag, which is kept. Problems with no valid record are left out.
    """
    grouped = by_prob_map(nbts)
    bests = {}
    for prob_id in sorted(grouped):
        for candidate in sorted(grouped[prob_id], key=lambda rec: rec['cost']):
            print(candidate)
            if not candidate['valid']:
                continue
            bests[prob_id] = candidate
            break
    return bests
def find_java_bests(nbts, bests):
    """Fill *bests* for problems it lacks, using ``javalid`` values.

    For each prob_id (in sorted order) not already in *bests*, records are
    ordered by ``javalid``, with falsy values sorting as 0. The first record
    with a truthy ``javalid`` is printed and stored.

    :returns: *bests*, which is mutated in place.
    """
    grouped = by_prob_map(nbts)
    for prob_id in sorted(grouped):
        if prob_id in bests:
            continue
        ordered = sorted(grouped[prob_id], key=lambda rec: rec['javalid'] or 0)
        for candidate in ordered:
            if candidate['javalid']:
                print(candidate)
                bests[prob_id] = candidate
                break
    return bests
def run_tracer(nbts):
    """Run ``./bin/tracer4 <r> <path> <ascii_path>`` for records whose ascii file is missing.

    NOTE(review): the original author marked this "not working". ``collect_nbts``
    never stores an ``'ascii_path'`` key, so on its records this raises KeyError
    at the first element. The tracer path is resolved against the current
    working directory. No visible caller uses this function.
    """
    pending = (rec for rec in nbts if not exists(rec['ascii_path']))
    for rec in pending:
        subprocess.run(['./bin/tracer4', str(rec['r']), rec['path'], rec['ascii_path']])
def update_submission(nbts, mode=None):
    """Pick the best trace per problem and copy it into ``repo_path/submission``.

    Selection:
    - ``find_bests`` gives the cheapest valid trace per problem.
    - ``find_java_bests`` fills the remaining problems.
    - When *mode* is 'hasi21' or 'shioshiota19', FR114 and FR115 are forced to
      that AI's trace. The last matching record wins.

    Each chosen ``.nbt.gz`` is decompressed to ``submission/nbt/<prob_id>.nbt``.
    One row per problem (prob_id, ai_name, cost, valid, javalid, trace path
    relative to repo_path) is written to ``submission/list.tsv``.

    :param nbts: records as produced by ``collect_nbts``.
    :param mode: optional override name. It defaults to None (no override);
        previously it was required, which made the one-argument call in
        ``main`` raise TypeError.
    """
    override_modes = ('hasi21', 'shioshiota19')
    override_probs = ('FR114', 'FR115')
    nbt_dir = repo_path / 'submission' / 'nbt'
    os.makedirs(str(nbt_dir), exist_ok=True)
    bests = find_java_bests(nbts, find_bests(nbts))
    if mode in override_modes:
        for nbt in nbts:
            if nbt['ai_name'] == mode and nbt['prob_id'] in override_probs:
                bests[nbt['prob_id']] = nbt
    with open(str(repo_path / 'submission' / 'list.tsv'), 'w') as f:
        f.write('\t'.join(["prob_id", "ai_name", "cost", "valid", "javalid", "nbt_path"]) + '\n')
        for nbt in sorted(bests.values(), key=lambda x: x['prob_id']):
            print(nbt)
            dst_path = str(nbt_dir / (nbt['prob_id'] + '.nbt'))
            print('decompress', nbt['path'], dst_path)
            # Decompress in-process instead of shelling out to
            # `gzip -d --keep --force` and moving the result. The exit status
            # of that call was ignored, so a gzip failure only surfaced later
            # as a confusing shutil.move error. It also clobbered any existing
            # .nbt file next to the source.
            with gzip.open(nbt['path'], 'rb') as src, open(dst_path, 'wb') as dst:
                shutil.copyfileobj(src, dst)
            nbt_path = os.path.relpath(nbt['path'], str(repo_path))
            f.write('\t'.join(map(str, [nbt['prob_id'], nbt['ai_name'], nbt['cost'],
                                        nbt['valid'], nbt['javalid'], nbt_path])) + '\n')
def find_no_javalid(nbts, limit=100):
    """Print the paths of records whose ``javalid`` is None, at most *limit* of them.

    In records built by ``collect_nbts``, ``javalid`` is None exactly when no
    ``.javalidate`` report was found next to the trace.

    :param nbts: records carrying ``'path'`` and ``'javalid'`` keys.
    :param limit: maximum number of paths to print. The old loop broke only
        once its counter exceeded 100, so it printed 101 paths.
    """
    missing = (nbt['path'] for nbt in nbts if nbt['javalid'] is None)
    for shown, path in enumerate(missing):
        if shown >= limit:
            break
        print(path)
def main():
    """Command-line entry point: ``gg.py OP OUT_PATH [MODE]``.

    OP is one of:
    - ``update_submission``: MODE is optional.
    - ``gen_rank_tsv``: writes to ``repo_path/rank``.
    - ``find_no_javalid``.

    Missing arguments print a usage line to stderr and exit with status 2. An
    unknown OP prints "no such command" and exits with status 1, before the
    recursive scan of OUT_PATH is started.
    """
    commands = ('update_submission', 'gen_rank_tsv', 'find_no_javalid')
    if len(sys.argv) < 3:
        prog = basename(sys.argv[0]) if sys.argv else 'gg.py'
        print('usage: {} {{{}}} OUT_PATH [MODE]'.format(prog, '|'.join(commands)),
              file=sys.stderr)
        sys.exit(2)
    op = sys.argv[1]
    out_path = sys.argv[2]
    if op not in commands:
        print("no such command")
        sys.exit(1)
    nbts = collect_nbts(out_path)
    if op == 'update_submission':
        # Pass mode explicitly (None when absent). The former one-argument
        # call failed against update_submission's required `mode` parameter.
        mode = sys.argv[3] if len(sys.argv) >= 4 else None
        update_submission(nbts, mode)
    elif op == 'gen_rank_tsv':
        gen_rank_tsv(nbts, str(repo_path / 'rank'))
    else:
        find_no_javalid(nbts)


if __name__ == '__main__':
    main()