Refactor the code in train.py (#213)
Shengqiang-Li authored Mar 27, 2024
1 parent c6b6874 commit 4dd2794
Showing 1 changed file with 60 additions and 41 deletions.
101 changes: 60 additions & 41 deletions wetts/vits/train.py
@@ -1,10 +1,11 @@
import os

import torch
import torch.distributed as dist

from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler

@@ -30,30 +31,17 @@

def main():
hps = task.get_hparams()
# Set random seed
torch.manual_seed(hps.train.seed)
global global_step
# Initialize distributed
world_size = int(os.environ.get('WORLD_SIZE', 1))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
rank = int(os.environ.get('RANK', 0))
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")
if rank == 0:
logger = task.get_logger(hps.model_dir)
logger.info(hps)
writer = SummaryWriter(log_dir=hps.model_dir)
writer_eval = SummaryWriter(
log_dir=os.path.join(hps.model_dir, "eval"))
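The refactored startup reads the distributed topology from environment variables and binds each process to one GPU before creating the NCCL process group. A minimal sketch of the same pattern, assuming the script is launched with torchrun (which populates WORLD_SIZE, RANK and LOCAL_RANK for every worker):

    import os

    import torch
    import torch.distributed as dist

    # torchrun --nproc_per_node=<gpus> train.py sets these for each worker.
    world_size = int(os.environ.get("WORLD_SIZE", 1))  # total number of processes
    rank = int(os.environ.get("RANK", 0))              # global process index
    local_rank = int(os.environ.get("LOCAL_RANK", 0))  # GPU index on this node

    torch.cuda.set_device(local_rank)    # bind this process to its own GPU
    dist.init_process_group("nccl")      # NCCL backend for multi-GPU training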

if ("use_mel_posterior_encoder" in hps.model.keys()
and hps.model.use_mel_posterior_encoder):
print("Using mel posterior encoder for VITS2")
posterior_channels = hps.data.n_mel_channels # vits2
hps.data.use_mel_posterior_encoder = True
else:
print("Using lin posterior encoder for VITS1")
posterior_channels = hps.data.filter_length // 2 + 1
hps.data.use_mel_posterior_encoder = False

# Get the dataset and data loader
train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
train_sampler = DistributedBucketSampler(
train_dataset,
@@ -85,6 +73,17 @@ def main():
collate_fn=collate_fn,
)

# Get the tts model
if ("use_mel_posterior_encoder" in hps.model.keys()
and hps.model.use_mel_posterior_encoder):
print("Using mel posterior encoder for VITS2")
posterior_channels = hps.data.n_mel_channels # vits2
hps.data.use_mel_posterior_encoder = True
else:
print("Using lin posterior encoder for VITS1")
posterior_channels = hps.data.filter_length // 2 + 1
hps.data.use_mel_posterior_encoder = False
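The posterior encoder's input width follows directly from this switch: mel channels for the VITS2-style mel posterior encoder, or linear-spectrogram bins for VITS1. A small worked sketch with illustrative values standing in for hps.data:

    # Illustrative stand-ins for hps.data values.
    n_mel_channels = 80     # typical mel feature size
    filter_length = 1024    # STFT window size for the linear spectrogram

    use_mel_posterior_encoder = True
    if use_mel_posterior_encoder:
        posterior_channels = n_mel_channels          # VITS2: 80 mel channels
    else:
        posterior_channels = filter_length // 2 + 1  # VITS1: 513 linear bins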

# Some of these flags are not used in the code and are set directly in the
# hps json file; they are kept here for reference and prototyping.
if ("use_transformer_flows" in hps.model.keys()
@@ -144,7 +143,7 @@ def main():
0.1,
gin_channels=hps.model.gin_channels
if hps.data.n_speakers != 0 else 0,
).cuda(rank)
).cuda(local_rank)
elif duration_discriminator_type == "dur_disc_2":
net_dur_disc = DurationDiscriminatorV2(
hps.model.hidden_channels,
@@ -153,7 +152,7 @@
0.1,
gin_channels=hps.model.gin_channels
if hps.data.n_speakers != 0 else 0,
).cuda(rank)
).cuda(local_rank)
else:
print("NOT using any duration discriminator like VITS1")
net_dur_disc = None
@@ -164,15 +163,33 @@
n_speakers=hps.data.n_speakers,
mas_noise_scale_initial=mas_noise_scale_initial,
noise_scale_delta=noise_scale_delta,
**hps.model).cuda(rank)
**hps.model).cuda(local_rank)
if ("use_mrd_disc" in hps.model.keys()
and hps.model.use_mrd_disc):
print("Using MultiPeriodMultiResolutionDiscriminator")
net_d = MultiPeriodMultiResolutionDiscriminator(
hps.model.use_spectral_norm).cuda(rank)
hps.model.use_spectral_norm).cuda(local_rank)
else:
print("Using MPD")
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank)
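Throughout the refactor, modules are placed with .cuda(local_rank) rather than .cuda(rank): the global rank can exceed the number of GPUs on one machine, while local_rank always indexes a device on the current node. A worked example, assuming a hypothetical 2-node, 4-GPU-per-node job:

    # 2 nodes x 4 GPUs: global ranks 0..7, but each node only has cuda:0..cuda:3.
    world_size = 8
    gpus_per_node = 4

    for rank in range(world_size):
        local_rank = rank % gpus_per_node   # what torchrun reports as LOCAL_RANK
        # e.g. global rank 5 runs on the second node and must use cuda:1 there;
        # calling .cuda(5) on that node would fail.
        print(f"rank={rank} -> cuda:{local_rank}")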

# Dispatch the model from cpu to gpu
# comment - choihkk
# If unused parameters such as DurationDiscriminator's self.pre_out_norm1/2,
# self.norm_1/2 and ResidualCouplingTransformersLayer's self.post_transformer
# were commented out, find_unused_parameters=True would not be needed, but
# they are kept for compatibility with others' latest work.
net_g = DDP(net_g, device_ids=[local_rank], find_unused_parameters=True)
net_d = DDP(net_d, device_ids=[local_rank], find_unused_parameters=True)
if net_dur_disc:
net_dur_disc = DDP(
net_dur_disc,
device_ids=[local_rank],
find_unused_parameters=True
)
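As the comment explains, some submodules never receive gradients, so the DDP reducer must be told to tolerate unused parameters; otherwise it raises an error during backward. A minimal sketch of the wrapping pattern, with MyModel as a placeholder module:

    from torch.nn.parallel import DistributedDataParallel as DDP

    model = MyModel().cuda(local_rank)    # MyModel/local_rank as set up earlier
    model = DDP(
        model,
        device_ids=[local_rank],          # one replica per GPU
        find_unused_parameters=True,      # skip params with no gradient this step
    )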

# Get the optimizer
optim_g = torch.optim.AdamW(
net_g.parameters(),
hps.train.learning_rate,
@@ -195,17 +212,7 @@
else:
optim_dur_disc = None
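Each network gets its own AdamW optimizer built from the training hyperparameters. A sketch with illustrative values standing in for hps.train:

    import torch

    learning_rate = 2e-4        # hps.train.learning_rate (illustrative)
    betas = (0.8, 0.99)         # hps.train.betas (illustrative)
    eps = 1e-9                  # hps.train.eps (illustrative)

    optim_g = torch.optim.AdamW(net_g.parameters(), learning_rate,
                                betas=betas, eps=eps)
    optim_d = torch.optim.AdamW(net_d.parameters(), learning_rate,
                                betas=betas, eps=eps)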

# comment - choihkk
# If unused parameters such as DurationDiscriminator's self.pre_out_norm1/2,
# self.norm_1/2 and ResidualCouplingTransformersLayer's self.post_transformer
# were commented out, find_unused_parameters=True would not be needed, but
# they are kept for compatibility with others' latest work.
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
if net_dur_disc:
net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True)

# Load the checkpoint
try:
_, _, _, epoch_str = task.load_checkpoint(
task.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
@@ -224,6 +231,7 @@
epoch_str = 1
global_step = 0
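Resuming and starting fresh share one code path: if the newest G_*.pth checkpoint cannot be loaded, the run simply begins at epoch 1 with global_step 0. A sketch of the pattern, treating load_checkpoint and latest_checkpoint_path as the task helpers used above:

    try:
        # load_checkpoint is assumed to return (model, optimizer, lr, iteration).
        _, _, _, epoch_str = task.load_checkpoint(
            task.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g)
    except Exception:
        epoch_str = 1     # no usable checkpoint: train from scratch
        global_step = 0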

# Get the scheduler
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
@@ -234,12 +242,22 @@
else:
scheduler_dur_disc = None
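The ExponentialLR schedulers decay the learning rate by hps.train.lr_decay once per epoch; passing last_epoch = epoch_str - 2 keeps the decay in step with a resumed run, and for a fresh run (epoch_str == 1) it reduces to the default last_epoch = -1. A small sketch with illustrative values:

    import torch

    param = torch.zeros(1, requires_grad=True)
    optim_g = torch.optim.AdamW([param], lr=2e-4)

    epoch_str = 1          # fresh run: last_epoch == -1, i.e. no steps taken yet
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=0.999875, last_epoch=epoch_str - 2)  # gamma ~ hps.train.lr_decay

    # When resuming, epoch_str > 1 and the restored optimizer state carries the
    # 'initial_lr' entries the scheduler needs to continue the decay curve.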

# Get the tensorboard summary
writer = None
if rank == 0:
logger = task.get_logger(hps.model_dir)
logger.info(hps)
writer = SummaryWriter(log_dir=hps.model_dir)
writer_eval = SummaryWriter(
log_dir=os.path.join(hps.model_dir, "eval"))

scaler = GradScaler(enabled=hps.train.fp16_run)
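The GradScaler pairs with the autocast blocks in the training step: when hps.train.fp16_run is set, losses are scaled before backward so small fp16 gradients do not flush to zero, and the scale is removed again before the optimizer update. A minimal sketch of the pattern, with model, loader, criterion and optimizer as placeholders:

    from torch.cuda.amp import autocast, GradScaler

    scaler = GradScaler(enabled=True)        # enabled=hps.train.fp16_run in train.py

    for batch in loader:                     # loader/model/criterion are placeholders
        optimizer.zero_grad()
        with autocast(enabled=True):         # forward pass in mixed precision
            loss = criterion(model(batch))
        scaler.scale(loss).backward()        # scale the loss to avoid fp16 underflow
        scaler.step(optimizer)               # unscales grads, skips step on inf/nan
        scaler.update()                      # adjust the scale factor for next step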

for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0:
train_and_evaluate(
rank,
local_rank,
epoch,
hps,
[net_g, net_d, net_dur_disc],
@@ -253,6 +271,7 @@
else:
train_and_evaluate(
rank,
local_rank,
epoch,
hps,
[net_g, net_d, net_dur_disc],
@@ -269,7 +288,7 @@
scheduler_dur_disc.step()


def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler,
def train_and_evaluate(rank, local_rank, epoch, hps, nets, optims, schedulers, scaler,
loaders, logger, writers):
net_g, net_d, net_dur_disc = nets
optim_g, optim_d, optim_dur_disc = optims
@@ -301,14 +320,14 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler,
net_g.module.noise_scale_delta * global_step)
net_g.module.current_mas_noise_scale = max(current_mas_noise_scale,
0.0)
x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(
rank, non_blocking=True)
x, x_lengths = x.cuda(local_rank, non_blocking=True), x_lengths.cuda(
local_rank, non_blocking=True)
spec, spec_lengths = spec.cuda(
rank, non_blocking=True), spec_lengths.cuda(rank,
non_blocking=True)
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
rank, non_blocking=True)
speakers = speakers.cuda(rank, non_blocking=True)
local_rank, non_blocking=True), spec_lengths.cuda(local_rank,
non_blocking=True)
y, y_lengths = y.cuda(local_rank, non_blocking=True), y_lengths.cuda(
local_rank, non_blocking=True)
speakers = speakers.cuda(local_rank, non_blocking=True)
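Every batch tensor is moved to the GPU chosen by local_rank with non_blocking=True, which lets host-to-device copies overlap with compute; the overlap only materializes when the DataLoader keeps batches in pinned memory. A sketch of the combination, with illustrative loader arguments:

    from torch.utils.data import DataLoader

    # pin_memory=True keeps batches in page-locked host memory, which is what
    # makes the non_blocking copies below truly asynchronous.
    loader = DataLoader(train_dataset, batch_size=16, pin_memory=True,
                        collate_fn=collate_fn)   # illustrative arguments

    for x, x_lengths in loader:                  # simplified batch structure
        x = x.cuda(local_rank, non_blocking=True)
        x_lengths = x_lengths.cuda(local_rank, non_blocking=True)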

with autocast(enabled=hps.train.fp16_run):
(
