Skip to content

Commit

Permalink
[dataloader] Fix text filtering bug and speed up spectrum length calc (
Browse files Browse the repository at this point in the history
…#216)

* [dataloader] Fix text filtering bug and speed up spectrum length calculation

* [fix] Fix code style check

---------

Co-authored-by: lsrami <[email protected]>
  • Loading branch information
lsrami and lsrami authored May 22, 2024
1 parent a410d66 commit 97b83e8
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 14 deletions.
7 changes: 7 additions & 0 deletions examples/aishell-3/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
$dataset_dir/data_aishell3 \
$data/all.txt

# Compute spec length (optional, but recommended)
python tools/compute_spec_length.py \
$data/all.txt \
$config \
$data/all_spec_length.txt
mv $data/all_spec_length.txt $data/all.txt

cat $data/all.txt | awk -F '|' '{print $2}' | \
sort | uniq | awk '{print $0, NR-1}' > $data/speaker.txt
echo 'sil 0' > $data/phones.txt
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ torchvision
tqdm
transformers
huggingface_hub
soundfile
72 changes: 72 additions & 0 deletions tools/compute_spec_length.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/usr/bin/env python3
# author: @lsrami

import os
import sys
import json
from tqdm import tqdm
import soundfile as sf
from concurrent.futures import ThreadPoolExecutor


def load_filepaths_and_text(filename, split="|"):
with open(filename, encoding="utf-8") as f:
filepaths_and_text = [line.strip().split(split) for line in f]
return filepaths_and_text


def process_item(item):
audiopath = item[0]
src_sampling_rate = sf.info(audiopath).samplerate
text = item[2]
text = text.strip().split()
if min_text_len <= len(text) and len(text) <= max_text_len:
length = int(os.path.getsize(audiopath) * sampling_rate /
src_sampling_rate) // (2 * hop_length)
item.append(length)
return item
else:
return None


def main(in_file, out_file):
"""
Filter text & store spec lengths
"""

audiopaths_sid_text = load_filepaths_and_text(in_file, split="|")

with ThreadPoolExecutor(max_workers=32) as executor:
results = list(
tqdm(
executor.map(process_item, audiopaths_sid_text),
total=len(audiopaths_sid_text),
)
)

# Filter out None results
results = [result for result in results if result is not None]

with open(out_file, "w", encoding="utf-8") as f:
for item in results:
f.write("|".join([str(i) for i in item]) + "\n")


if __name__ == "__main__":
if len(sys.argv) != 4:
print(f"Usage: {sys.argv[0]} <in_file> <config_file> <out_file>")
sys.exit(1)
in_file, config_file, out_file = sys.argv[1:4]

with open(config_file, "r", encoding="utf8") as f:
data = f.read()
config = json.loads(data)
hparams = config["data"]

min_text_len = hparams.get("min_text_len", 1)
max_text_len = hparams.get("max_text_len", 190)
sampling_rate = hparams.get("sampling_rate", 22050)
hop_length = hparams.get("hop_length", 256)
print(min_text_len, max_text_len, sampling_rate, hop_length)

main(in_file, out_file)
36 changes: 22 additions & 14 deletions wetts/vits/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import torch
import torchaudio
import torch.utils.data
from tqdm import tqdm
import soundfile as sf

from utils.mel_processing import spectrogram_torch, mel_spectrogram_torch
from utils.task import load_filepaths_and_text
Expand Down Expand Up @@ -60,20 +62,26 @@ def _filter(self):
"""
Filter text & store spec lengths
"""
audiopaths_sid_text_new = []
lengths = []
for item in self.audiopaths_sid_text:
audiopath = item[0]
src_sampling_rate = torchaudio.info(audiopath).sample_rate
# filename|speaker|text
text = item[2]
if self.min_text_len <= len(text) and len(
text) <= self.max_text_len:
audiopaths_sid_text_new.append(item)
lengths.append(
int(
os.path.getsize(audiopath) * self.sampling_rate /
src_sampling_rate) // (2 * self.hop_length))
if len(self.audiopaths_sid_text[0]) > 3:
# spec length is provided
audiopaths_sid_text_new = [item[:3] for item in self.audiopaths_sid_text]
lengths = [int(item[3]) for item in self.audiopaths_sid_text]
else:
audiopaths_sid_text_new = []
lengths = []
for item in tqdm(self.audiopaths_sid_text, desc="Filtering data"):
audiopath = item[0]
src_sampling_rate = sf.info(audiopath).samplerate
# filename|speaker|text
text = item[2]
text = text.strip().split()
if self.min_text_len <= len(text) and len(
text) <= self.max_text_len:
audiopaths_sid_text_new.append(item)
lengths.append(
int(
os.path.getsize(audiopath) * self.sampling_rate /
src_sampling_rate) // (2 * self.hop_length))
self.audiopaths_sid_text = audiopaths_sid_text_new
self.lengths = lengths

Expand Down

0 comments on commit 97b83e8

Please sign in to comment.