# utils.py — MOT-style result writing, track filtering, and DTI (linear
# interpolation) post-processing utilities.
import glob
import os
import numpy as np
import shutil
def write_results_no_score(filename, results):
    """Write tracking results to *filename* in MOT challenge text format.

    Each row is ``frame,id,x1,y1,w,h,-1,-1,-1,-1``; coordinates are rounded
    to one decimal place and entries with a negative track id are skipped.

    Args:
        filename: destination path for the text file (overwritten).
        results: iterable of ``(frame_id, tlwhs, track_ids)`` triples, where
            each tlwh is a ``(x1, y1, w, h)`` box.
    """
    print("save results path:", filename)
    with open(filename, "w") as out:
        for frame_id, tlwhs, track_ids in results:
            for box, tid in zip(tlwhs, track_ids):
                # Negative ids mark invalid/unassigned tracks — skip them.
                if tid < 0:
                    continue
                left, top, width, height = box
                row = "{frame},{id},{x1},{y1},{w},{h},-1,-1,-1,-1\n".format(
                    frame=frame_id,
                    id=tid,
                    x1=round(left, 1),
                    y1=round(top, 1),
                    w=round(width, 1),
                    h=round(height, 1),
                )
                out.write(row)
def filter_targets(online_targets, aspect_ratio_thresh, min_box_area, dataset_):
    """Convert ``(x1, y1, x2, y2, id, ...)`` targets to tlwh boxes and filter.

    For the 'BEE23' and 'gmot' datasets every target is kept; otherwise a
    box is kept only if its area exceeds *min_box_area* and its width/height
    ratio does not exceed *aspect_ratio_thresh*.

    Returns:
        (tlwh boxes, track ids) as two parallel lists.
    """
    kept_boxes = []
    kept_ids = []
    keep_all = dataset_ in ['BEE23', 'gmot']
    for target in online_targets:
        left, top = target[0], target[1]
        width = target[2] - target[0]
        height = target[3] - target[1]
        tid = target[4]
        # Computed unconditionally (as in the original contract): a
        # zero-height box raises ZeroDivisionError for every dataset.
        too_wide = width / height > aspect_ratio_thresh
        if keep_all or (width * height > min_box_area and not too_wide):
            kept_boxes.append([left, top, width, height])
            kept_ids.append(tid)
    return kept_boxes, kept_ids
def dti(txt_path, save_path, n_min=30, n_dti=20):
    """Apply DTI (linear interpolation) post-processing to MOT result files.

    For every ``*.txt`` file in *txt_path*, each tracklet longer than *n_min*
    frames has its frame gaps shorter than *n_dti* filled by linearly
    interpolating the bounding box between the two surrounding detections.
    The frame-sorted result is written to *save_path* under the same file
    name.

    Args:
        txt_path: directory containing input MOT result files (10 columns:
            frame, id, x1, y1, w, h, score, -1, -1, -1).
        save_path: directory to write the interpolated result files to.
        n_min: only tracklets with more than this many frames are processed.
        n_dti: only gaps strictly shorter than this many frames are filled.
    """
    def dti_write_results(filename, results):
        # Score column is emitted as -1 for every row, matching MOT format.
        save_format = "{frame},{id},{x1},{y1},{w},{h},{s},-1,-1,-1\n"
        with open(filename, "w") as f:
            for row in results:
                frame_id = int(row[0])
                track_id = int(row[1])
                x1, y1, w, h = row[2:6]
                f.write(save_format.format(
                    frame=frame_id, id=track_id, x1=x1, y1=y1, w=w, h=h, s=-1))

    seq_txts = sorted(glob.glob(os.path.join(txt_path, "*.txt")))
    for seq_txt in seq_txts:
        # os.path.basename is portable; the original hand-split on "/".
        seq_name = os.path.basename(seq_txt)
        # ndmin=2 keeps a (1, 10) shape for single-row files, which would
        # otherwise load as 1-D and break the column indexing below.
        seq_data = np.loadtxt(seq_txt, dtype=np.float64, delimiter=",", ndmin=2)
        min_id = int(np.min(seq_data[:, 1]))
        max_id = int(np.max(seq_data[:, 1]))
        # Collect per-track results and stack once at the end instead of
        # repeatedly vstack-ing onto an accumulator (O(n^2) -> O(n)).
        per_track = []
        for track_id in range(min_id, max_id + 1):
            tracklet = seq_data[seq_data[:, 1] == track_id]
            if tracklet.shape[0] == 0:
                continue  # id not present in this sequence
            tracklet_dti = tracklet
            n_frame = tracklet.shape[0]
            if n_frame > n_min:
                frames = tracklet[:, 0]
                frames_dti = {}
                for i in range(n_frame):
                    right_frame = frames[i]
                    left_frame = frames[i - 1] if i > 0 else frames[i]
                    # Fill gaps of 2..n_dti-1 frames between consecutive
                    # detections; adjacent frames (gap == 1) need nothing.
                    if 1 < right_frame - left_frame < n_dti:
                        num_bi = int(right_frame - left_frame - 1)
                        right_bbox = tracklet[i, 2:6]
                        left_bbox = tracklet[i - 1, 2:6]
                        for j in range(1, num_bi + 1):
                            curr_frame = j + left_frame
                            # Linear interpolation between the two boxes.
                            curr_bbox = (curr_frame - left_frame) * (
                                right_bbox - left_bbox
                            ) / (right_frame - left_frame) + left_bbox
                            frames_dti[curr_frame] = curr_bbox
                if frames_dti:
                    data_dti = np.zeros((len(frames_dti), 10), dtype=np.float64)
                    # dict preserves insertion order; enumerate avoids the
                    # O(n) list(frames_dti.keys())[n] lookup per row.
                    for n, (frame, bbox) in enumerate(frames_dti.items()):
                        data_dti[n, 0] = frame
                        data_dti[n, 1] = track_id
                        data_dti[n, 2:6] = bbox
                        data_dti[n, 6:] = [1, -1, -1, -1]
                    tracklet_dti = np.vstack((tracklet, data_dti))
            per_track.append(tracklet_dti)
        seq_results = np.vstack(per_track)
        # Sort rows by frame number before writing.
        seq_results = seq_results[seq_results[:, 0].argsort()]
        dti_write_results(os.path.join(save_path, seq_name), seq_results)
if __name__ == "__main__":
    # Build the post-processed result folder from a clean copy of the raw
    # tracker output, then run DTI interpolation over its data directory.
    post_folder = "results/trackers/MOT17-val/1122_final_test_post"
    pre_folder = "results/trackers/MOT17-val/1122_final_test"

    # Remove any previous post-processing run before copying.
    if os.path.exists(post_folder):
        print(f"Overwriting previous results in {post_folder}")
        shutil.rmtree(post_folder)
    shutil.copytree(pre_folder, post_folder)

    post_folder_data = os.path.join(post_folder, "data")
    # In-place: interpolated files overwrite the copied ones.
    dti(post_folder_data, post_folder_data)
    print(
        f"Linear interpolation post-processing applied, saved to {post_folder_data}.")