-
Notifications
You must be signed in to change notification settings - Fork 0
/
subs.py
299 lines (257 loc) · 9.49 KB
/
subs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
import os
import re
import collections
from typing import TextIO
import pysubs2
from pathvalidate import sanitize_filename
import dicts
class SubEventWord:
    """One word inside a subtitle event, with its own timing.

    :param start: word start time.
    :param end: word end time.
    :param word: the word text.
    :param score: numeric score attached to the word (0.0 when not given).
    """

    def __init__(self, start=0.0, end=0.0, word='', score=0.0):
        self.start, self.end = start, end
        self.word = word
        self.score = score
class SubEvent:
    """A single subtitle cue: a time span plus one line of text.

    :param start: start time in seconds.
    :param end: end time in seconds.
    :param text: cue text; CRLF/LF line breaks are collapsed to spaces and the
        result is stripped.
    :param words: optional per-word timing entries, or ``None``.
    """

    # Any chunk of 2+ characters repeated 3+ times in a row (e.g. "hahaha").
    CLEAN_RE = re.compile(r"(.{2,}?)\1{2,}")
    # One of ん, う, あ repeated 4+ times in a row.
    CLEAN_JA_RE = re.compile(r"([んうあ])\1{3,}")
    # One of 啊, 嗯, 唔, 咦, 哦 repeated 3+ times in a row.
    CLEAN_ZH_RE = re.compile(r"([啊嗯唔咦哦])\1{2,}")

    def __init__(self, start=0.0, end=0.0, text='', words: "list[SubEventWord] | None" = None):
        self.start = start
        self.end = end
        # Normalise line breaks to spaces so every event is a single line.
        self.text = text.replace("\r\n", "\n").replace("\n", " ").strip()
        self.line_count = 1
        self.words = words

    def __repr__(self):
        return f"[{self.start}, {self.end}] {self.text}"

    # str() and repr() render the same "[start, end] text" form.
    __str__ = __repr__

    def clean(self):
        """Collapse each run matched by CLEAN_RE to a single copy of the chunk.

        :returns: self, so calls can be chained.
        """
        self.text = self.CLEAN_RE.sub(r"\1", self.text)
        return self

    def clean_ja(self):
        """Apply :meth:`clean`, squeeze CLEAN_JA_RE runs, then the Japanese fix-up tables.

        ``dicts.whisper_ja_replace`` is applied as ``(old, new)`` literal
        replacements, then ``dicts.whisper_ja_regex`` as ``(pattern, replacement)``
        regex substitutions, each in table order.

        :returns: self, so calls can be chained.
        """
        self.clean()
        self.text = self.CLEAN_JA_RE.sub(r"\1", self.text)
        for old, new in dicts.whisper_ja_replace:
            self.text = self.text.replace(old, new)
        for old, new in dicts.whisper_ja_regex:
            self.text = re.sub(old, new, self.text)
        return self

    def clean_zh(self, transcript):
        """Apply :meth:`clean`, squeeze CLEAN_ZH_RE runs, then the Chinese fix-up tables.

        :param transcript: text used to decide which table entries apply
            (presumably the source-language transcript — confirm with callers).
        :returns: self, so calls can be chained.
        """
        self.clean()
        self.text = self.CLEAN_ZH_RE.sub(r"\1", self.text)
        # Literal replacements apply only when `src` is empty or occurs in the transcript.
        for src, old, new in dicts.translate_zh_replace:
            if src != "" and src not in transcript:
                continue
            self.text = self.text.replace(old, new)
        # NOTE(review): this loop has the opposite sense of the one above — it SKIPS
        # the substitution when `src` matches at the start of the transcript.
        # Confirm the inversion is intended.
        for src, old, new in dicts.translate_zh_regex:
            if src is not None and re.match(src, transcript):
                continue
            self.text = re.sub(old, new, self.text)
        return self
class Sub(collections.UserList):
def __init__(self, events: [SubEvent] = None):
if events is None:
events = []
super().__init__(events)
@staticmethod
def load_file(file):
ext = os.path.splitext(file)[1].lstrip(".").lower()
match ext:
case 'txt':
return Sub.load_txt(file)
case 'lrc':
return Sub.load_lrc(file)
case 'srt':
return Sub.load_pysubs2(file, ext)
case 'vtt':
return Sub.load_pysubs2(file, ext)
case _:
raise ValueError(
f"Unsupported file ext: {ext}"
)
@staticmethod
def load_pysubs2(file, format_: str = None):
ssa = pysubs2.load(file, format_=format_)
return Sub([SubEvent(start=event.start / 1000.0, end=event.end / 1000.0, text=event.text) for event in ssa])
@staticmethod
def load_lrc(file):
sub = Sub()
with open(file, mode='r', encoding='utf8') as f:
lines = f.readlines()
half_events = []
for i, segment in enumerate(lines):
start, text = segment.split("]", 1)
text = text.strip()
start = parse_lrc_timestamp(start.lstrip("["))
half_events.append((start, text))
half_events.sort(key=lambda x: x[0])
for i, segment in enumerate(half_events):
start, text = segment
if text == "":
continue
end = start + 10 # fallback
if i+1 < len(half_events):
end = half_events[i+1][0]
words = []
# TODO: parse A2 extension
sub.append(SubEvent(start=start, end=end, text=text, words=words))
return sub
@staticmethod
def load_txt(file):
with open(file, mode='r', encoding='utf8') as f:
lines = f.readlines()
return Sub([SubEvent(text=line) for line in lines if len(line.strip()) > 0])
@staticmethod
def from_transformer_whisper(whisper_result):
sub = Sub()
for idx, chunk in enumerate(whisper_result["chunks"]):
begin, end = chunk["timestamp"]
if begin is None:
if idx == 0:
begin = 0
else:
begin = whisper_result["chunks"][idx - 1]["timestamp"][1]
if end is None:
if idx != len(whisper_result["chunks"]) - 1:
end = whisper_result["chunks"][idx + 1]["timestamp"][0]
else:
end = begin + 10.0
sub.append(SubEvent(start=begin, end=end, text=chunk["text"]))
return merge_sub(sub)
@staticmethod
def from_fast_whisper(whisper_result):
sub = Sub([])
for idx, seg in enumerate(whisper_result["segments"]):
begin, end = seg["start"], seg["end"]
if begin is None:
if idx == 0:
begin = 0
else:
begin = whisper_result["segments"][idx - 1]["end"]
if end is None:
if idx != len(whisper_result["segments"]) - 1:
end = whisper_result["segments"][idx + 1]["begin"]
else:
end = begin + 10.0
words = []
if seg["words"]:
words = list(SubEventWord(**word) for word in seg["words"])
sub.append(SubEvent(start=begin, end=end, text=seg["text"], words=words))
return merge_sub(sub)
def merge_sub(sub: Sub) -> Sub:
    """Merge runs of consecutive events whose text grows by prefix extension.

    Starting from a non-blank event, each following event whose text starts
    with the run's current text is folded in. The merged event keeps the first
    event's start and takes the last folded event's end, text and words.
    Blank events are copied through unchanged and end any run. (Presumably this
    undoes recognisers that re-emit a growing partial result — confirm.)

    :param sub: events in time order.
    :returns: a new Sub with merged events.
    """
    merged = Sub([])
    i = 0
    while i < len(sub):
        first = sub[i]
        if first.text.strip() == '':
            merged.append(first)
            i += 1
            continue
        # `last` is the event whose text (and words) end up in the merged cue.
        last = first
        j = i + 1
        while j < len(sub) and sub[j].text.startswith(last.text):
            last = sub[j]
            j += 1
        # Carry `words` over; rebuilding without them discarded word timings.
        merged.append(SubEvent(start=first.start, end=last.end, text=last.text, words=last.words))
        i = j
    return merged
def write_all(sub: Sub, base_dir, filename, formats) -> list[str]:
    """Write ``sub`` into ``base_dir`` once per requested format.

    :param sub: events to write.
    :param base_dir: output directory.
    :param filename: base file name without extension; passed through
        sanitize_filename before use.
    :param formats: sequence of extensions among ``lrc``, ``srt``, ``vtt``,
        ``txt``; empty or ``None`` means ``['lrc']``.
    :returns: paths of the written files, in ``formats`` order.
    :raises KeyError: for an unsupported format, before any file is written.
    """
    writers = {
        'lrc': write_lrc,
        'srt': write_srt,
        'vtt': write_vtt,
        'txt': write_txt,
    }
    if not formats:
        formats = ['lrc']
    # Resolve every writer first so an unknown format fails before any file is created.
    jobs = [(fmt, writers[fmt]) for fmt in formats]
    filename = sanitize_filename(filename)
    files = []
    for fmt, writer in jobs:
        filepath = os.path.join(base_dir, f'{filename}.{fmt}')
        with open(filepath, "w", encoding='utf-8') as f:
            writer(sub, f)
        files.append(filepath)
    return files
def write_vtt(sub: "Sub", f: TextIO):
    """Write ``sub`` to ``f`` in WebVTT format.

    The output starts with the mandatory ``WEBVTT`` signature (uppercase; other
    spellings are rejected by parsers). Each cue gets a 1-based numeric
    identifier, a timing line and its text, followed by a blank line.
    """
    lines = ["WEBVTT\n\n"]
    for number, event in enumerate(sub, start=1):
        lines.append(f"{number}\n")
        lines.append(f"{format_vtt_timestamp(event.start)} --> {format_vtt_timestamp(event.end)}\n")
        lines.append(f"{event.text}\n\n")
    f.writelines(lines)
def write_srt(sub: Sub, f: TextIO):
    """Write ``sub`` to ``f`` as SubRip: numbered cues, each followed by a blank line."""
    def render_cue(number, event):
        timing = f"{format_srt_timestamp(event.start)} --> {format_srt_timestamp(event.end)}"
        return f"{number}\n{timing}\n{event.text}\n\n"

    # Build everything before writing so nothing is emitted if formatting fails.
    chunks = [render_cue(number, event) for number, event in enumerate(sub, start=1)]
    f.writelines(chunks)
def format_vtt_timestamp(seconds: float):
    """Format ``seconds`` via format_timestamp with ``.`` before the milliseconds (WebVTT style)."""
    vtt_separator = '.'
    return format_timestamp(seconds, vtt_separator)
def format_srt_timestamp(seconds: float):
    """Format ``seconds`` via format_timestamp with ``,`` before the milliseconds (SRT style)."""
    srt_separator = ','
    return format_timestamp(seconds, srt_separator)
def format_timestamp(seconds: float, delim: str):
    """Format ``seconds`` as ``HH:MM:SS<delim>mmm``.

    SRT uses ``,`` as ``delim`` and WebVTT uses ``.``. The value is rounded to
    the nearest millisecond. Hours are zero-padded to at least two digits and
    printed in full beyond that.

    :param seconds: non-negative time in seconds.
    :param delim: separator placed before the milliseconds field.
    """
    assert seconds >= 0, "non-negative timestamp expected"
    total_ms = round(seconds * 1000.0)
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, ms = divmod(rem, 1_000)
    # The ':' between hours and minutes (and two-digit hours) were missing, yielding e.g. "005:00,000".
    return f"{hours:02d}:{minutes:02d}:{secs:02d}{delim}{ms:03d}"
def format_lrc_timestamp(seconds: float):
    """Format ``seconds`` as an LRC ``mm:ss.xx`` tag body.

    The value is rounded to whole milliseconds, then ``xx`` keeps the
    hundredths (truncated). Minutes are never carried into hours, so they can
    exceed 59.
    """
    assert seconds >= 0, "non-negative timestamp expected"
    total_ms = round(seconds * 1000.0)
    minutes, rem_ms = divmod(total_ms, 60_000)
    secs, ms = divmod(rem_ms, 1_000)
    hundredths = ms // 10
    return f"{minutes:02d}:{secs:02d}.{hundredths:02d}"
def parse_lrc_timestamp(s: str):
    """Parse the body of an LRC time tag (without brackets) into seconds.

    Accepts ``ss``, ``mm:ss`` or ``hh:mm:ss``, optionally followed by ``.`` and
    a decimal fraction of any length. For example ``.5``, ``.50`` (hundredths)
    and ``.500`` (milliseconds) all mean 0.5 s. Empty fields count as 0.

    :returns: the time in seconds, as a float.
    :raises ValueError: if ``s`` is blank or a field is not an integer (e.g. an
        ``ar:Artist`` metadata tag).
    """
    def parse_int(v: str) -> int:
        v = v.strip()
        return int(v) if v else 0

    if not s.strip():
        raise ValueError("empty LRC timestamp")
    clock, _, fraction = s.partition(".")
    # Fold colon-separated fields base 60: [[h:]m:]s -> whole seconds.
    whole = 0
    for field in clock.split(":"):
        whole = whole * 60 + parse_int(field)
    # Scale the fraction by its own length instead of assuming hundredths.
    fraction = fraction.strip()
    frac_seconds = parse_int(fraction) / 10 ** len(fraction) if fraction else 0.0
    return float(whole) + frac_seconds
def write_lrc(sub: Sub, f: TextIO):
    """Write ``sub`` to ``f`` as LRC.

    Each event becomes a ``[start]text`` line followed by an empty ``[end]``
    line. The end line is omitted when the next event's formatted start equals
    this event's formatted end. The last event always gets its end line.
    """
    out = []
    final_idx = len(sub) - 1
    for idx, event in enumerate(sub):
        start_tag = format_lrc_timestamp(event.start)
        end_tag = format_lrc_timestamp(event.end)
        out.append(f"[{start_tag}]{event.text}\n")
        # No explicit end marker when the following cue begins at that exact tag.
        if idx < final_idx:
            upcoming = sub[idx + 1].start
            if upcoming is not None and format_lrc_timestamp(upcoming) == end_tag:
                continue
        out.append(f"[{end_tag}]\n")
    f.writelines(out)
def write_txt(sub: Sub, f: TextIO):
    """Write each event's text on its own line, without any timing information."""
    f.writelines([f"{event.text}\n" for event in sub])