diff --git a/wetts/vits/train.py b/wetts/vits/train.py
index d59a81a..3e5e6b8 100644
--- a/wetts/vits/train.py
+++ b/wetts/vits/train.py
@@ -1,10 +1,11 @@
 import os

 import torch
+import torch.distributed as dist
+
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
-import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.cuda.amp import autocast, GradScaler
@@ -30,30 +31,17 @@ def main():
     hps = task.get_hparams()

+    # Set random seed
     torch.manual_seed(hps.train.seed)

     global global_step

+    # Initialize distributed
     world_size = int(os.environ.get('WORLD_SIZE', 1))
     local_rank = int(os.environ.get('LOCAL_RANK', 0))
     rank = int(os.environ.get('RANK', 0))
     torch.cuda.set_device(local_rank)
     dist.init_process_group("nccl")
-    if rank == 0:
-        logger = task.get_logger(hps.model_dir)
-        logger.info(hps)
-        writer = SummaryWriter(log_dir=hps.model_dir)
-        writer_eval = SummaryWriter(
-            log_dir=os.path.join(hps.model_dir, "eval"))
-
-    if ("use_mel_posterior_encoder" in hps.model.keys()
-            and hps.model.use_mel_posterior_encoder):
-        print("Using mel posterior encoder for VITS2")
-        posterior_channels = hps.data.n_mel_channels  # vits2
-        hps.data.use_mel_posterior_encoder = True
-    else:
-        print("Using lin posterior encoder for VITS1")
-        posterior_channels = hps.data.filter_length // 2 + 1
-        hps.data.use_mel_posterior_encoder = False
+    # Get the dataset and data loader
     train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
     train_sampler = DistributedBucketSampler(
         train_dataset,
@@ -85,6 +73,17 @@ def main():
         collate_fn=collate_fn,
     )

+    # Get the tts model
+    if ("use_mel_posterior_encoder" in hps.model.keys()
+            and hps.model.use_mel_posterior_encoder):
+        print("Using mel posterior encoder for VITS2")
+        posterior_channels = hps.data.n_mel_channels  # vits2
+        hps.data.use_mel_posterior_encoder = True
+    else:
+        print("Using lin posterior encoder for VITS1")
+        posterior_channels = hps.data.filter_length // 2 + 1
+        hps.data.use_mel_posterior_encoder = False
+
     # some of these flags are not being used in the code and directly set in hps
     # json file. they are kept here for reference and prototyping.
     if ("use_transformer_flows" in hps.model.keys()
@@ -144,7 +143,7 @@ def main():
             0.1,
             gin_channels=hps.model.gin_channels
             if hps.data.n_speakers != 0 else 0,
-        ).cuda(rank)
+        ).cuda(local_rank)
     elif duration_discriminator_type == "dur_disc_2":
         net_dur_disc = DurationDiscriminatorV2(
             hps.model.hidden_channels,
@@ -153,7 +152,7 @@ def main():
             0.1,
             gin_channels=hps.model.gin_channels
             if hps.data.n_speakers != 0 else 0,
-        ).cuda(rank)
+        ).cuda(local_rank)
     else:
         print("NOT using any duration discriminator like VITS1")
         net_dur_disc = None
@@ -164,15 +163,33 @@ def main():
         n_speakers=hps.data.n_speakers,
         mas_noise_scale_initial=mas_noise_scale_initial,
         noise_scale_delta=noise_scale_delta,
-        **hps.model).cuda(rank)
+        **hps.model).cuda(local_rank)
     if ("use_mrd_disc" in hps.model.keys()
             and hps.model.use_mrd_disc):
         print("Using MultiPeriodMultiResolutionDiscriminator")
         net_d = MultiPeriodMultiResolutionDiscriminator(
-            hps.model.use_spectral_norm).cuda(rank)
+            hps.model.use_spectral_norm).cuda(local_rank)
     else:
         print("Using MPD")
-        net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
+        net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank)
+
+    # Dispatch the model from cpu to gpu
+    # comment - choihkk
+    # if we comment out unused parameter like DurationDiscriminator's
+    # self.pre_out_norm1,2 self.norm_1,2 and ResidualCouplingTransformersLayer's
+    # self.post_transformer we don't have to set find_unused_parameters=True
+    # but I will not proceed with commenting out for compatibility with the
+    # latest work for others
+    net_g = DDP(net_g, device_ids=[local_rank], find_unused_parameters=True)
+    net_d = DDP(net_d, device_ids=[local_rank], find_unused_parameters=True)
+    if net_dur_disc:
+        net_dur_disc = DDP(
+            net_dur_disc,
+            device_ids=[local_rank],
+            find_unused_parameters=True
+        )
+
+    # Get the optimizer
     optim_g = torch.optim.AdamW(
         net_g.parameters(),
         hps.train.learning_rate,
@@ -195,17 +212,7 @@ def main():
     else:
         optim_dur_disc = None

-    # comment - choihkk
-    # if we comment out unused parameter like DurationDiscriminator's
-    # self.pre_out_norm1,2 self.norm_1,2 and ResidualCouplingTransformersLayer's
-    # self.post_transformer we don't have to set find_unused_parameters=True
-    # but I will not proceed with commenting out for compatibility with the
-    # latest work for others
-    net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
-    net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
-    if net_dur_disc:
-        net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True)
-
+    # Load the checkpoint
     try:
         _, _, _, epoch_str = task.load_checkpoint(
             task.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
@@ -224,6 +231,7 @@ def main():
         epoch_str = 1
         global_step = 0

+    # Get the scheduler
     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
         optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
     scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
@@ -234,12 +242,22 @@ def main():
     else:
         scheduler_dur_disc = None

+    # Get the tensorboard summary
+    writer = None
+    if rank == 0:
+        logger = task.get_logger(hps.model_dir)
+        logger.info(hps)
+        writer = SummaryWriter(log_dir=hps.model_dir)
+        writer_eval = SummaryWriter(
+            log_dir=os.path.join(hps.model_dir, "eval"))
+
     scaler = GradScaler(enabled=hps.train.fp16_run)

     for epoch in range(epoch_str, hps.train.epochs + 1):
         if rank == 0:
             train_and_evaluate(
                 rank,
+                local_rank,
                 epoch,
                 hps,
                 [net_g, net_d, net_dur_disc],
@@ -253,6 +271,7 @@ def main():
         else:
             train_and_evaluate(
                 rank,
+                local_rank,
                 epoch,
                 hps,
                 [net_g, net_d, net_dur_disc],
@@ -269,7 +288,7 @@ def main():
             scheduler_dur_disc.step()


-def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler,
+def train_and_evaluate(rank, local_rank, epoch, hps, nets, optims, schedulers, scaler,
                        loaders, logger, writers):
    net_g, net_d, net_dur_disc = nets
    optim_g, optim_d, optim_dur_disc = optims
@@ -301,14 +320,14 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler,
                 net_g.module.noise_scale_delta * global_step)
             net_g.module.current_mas_noise_scale = max(
                 current_mas_noise_scale, 0.0)
-        x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(
-            rank, non_blocking=True)
+        x, x_lengths = x.cuda(local_rank, non_blocking=True), x_lengths.cuda(
+            local_rank, non_blocking=True)
         spec, spec_lengths = spec.cuda(
-            rank, non_blocking=True), spec_lengths.cuda(rank,
-                                                        non_blocking=True)
-        y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
-            rank, non_blocking=True)
-        speakers = speakers.cuda(rank, non_blocking=True)
+            local_rank, non_blocking=True), spec_lengths.cuda(local_rank,
+                                                              non_blocking=True)
+        y, y_lengths = y.cuda(local_rank, non_blocking=True), y_lengths.cuda(
+            local_rank, non_blocking=True)
+        speakers = speakers.cuda(local_rank, non_blocking=True)

         with autocast(enabled=hps.train.fp16_run):
             (