diff --git a/README.md b/README.md
index fecdb13..43ff289 100755
--- a/README.md
+++ b/README.md
@@ -73,6 +73,30 @@ with fused residual and skip connections.
    python3 inference.py -f mel_files.txt -w checkpoints/waveglow_10000 -o . --is_fp16 -s 0.6
    ```
 
+## Finetuning the official checkpoint with your own data
+
+The "official" checkpoint above was trained using an older version of the code.
+Therefore, you need to use `glow_old.py` to continue training from the official
+checkpoint:
+
+1. Download our [published model]
+2. Update the checkpoint to comply with recent code modifications:
+
+   `python convert_model.py waveglow_old.pt waveglow_old_updated.pt`
+
+3. Perform steps 1 and 2 from the section above
+
+4. Set `"checkpoint_path": "./waveglow_old_updated.pt"` in `config.json`
+
+5. Train your WaveGlow networks with `OLD_GLOW=1` (not yet tested with
+   `distributed.py`)
+
+   ```command
+   mkdir checkpoints
+   OLD_GLOW=1 python train.py -c config.json
+   ```
+
+
 [//]: # (TODO)
 [//]: # (PROVIDE INSTRUCTIONS FOR DOWNLOADING LJS)
 [pytorch 1.0]: https://github.com/pytorch/pytorch#installation
diff --git a/glow.py b/glow.py
index fc3a374..0516a0c 100644
--- a/glow.py
+++ b/glow.py
@@ -97,7 +97,11 @@ def forward(self, z, reverse=False):
             return z
         else:
             # Forward computation
-            log_det_W = batch_size * n_of_groups * torch.logdet(W)
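+            # NOTE: torch.logdet(W) returns NaN as soon as det(W) turns
+            # negative, which can easily happen for a learned weight matrix.
+            # torch.slogdet(W) returns the pair (sign, log|det(W)|), so
+            # indexing [1] keeps the log-magnitude, which stays finite for
+            # any invertible W.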
+            log_det_W = batch_size * n_of_groups * torch.slogdet(W)[1]
             z = self.conv(z)
             return z, log_det_W
 
diff --git a/glow_old.py b/glow_old.py
index 0de2375..199f232 100644
--- a/glow_old.py
+++ b/glow_old.py
@@ -116,11 +116,9 @@ def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
         self.n_remaining_channels = n_remaining_channels  # Useful during inference
 
     def forward(self, forward_input):
-        return None
-        """
-        forward_input[0] = audio: batch x time
-        forward_input[1] = upsamp_spectrogram: batch x n_cond_channels x time
         """
+        forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
+        forward_input[1] = audio: batch x time
         """
         spect, audio = forward_input
 
@@ -135,19 +133,19 @@ def forward(self, forward_input):
         audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
 
         output_audio = []
-        s_list = []
-        s_conv_list = []
+        log_s_list = []
+        log_det_W_list = []
 
         for k in range(self.n_flows):
-            if k%4 == 0 and k > 0:
-                output_audio.append(audio[:,:self.n_multi,:])
-                audio = audio[:,self.n_multi:,:]
+            if k % self.n_early_every == 0 and k > 0:
+                output_audio.append(audio[:,:self.n_early_size,:])
+                audio = audio[:,self.n_early_size:,:]
 
-            # project to new basis
-            audio, s = self.convinv[k](audio)
-            s_conv_list.append(s)
+            audio, log_det_W = self.convinv[k](audio)
+            log_det_W_list.append(log_det_W)
 
             n_half = int(audio.size(1)/2)
+
             if k%2 == 0:
                 audio_0 = audio[:,:n_half,:]
                 audio_1 = audio[:,n_half:,:]
@@ -155,19 +153,20 @@
                 audio_1 = audio[:,:n_half,:]
                 audio_0 = audio[:,n_half:,:]
 
-            output = self.nn[k]((audio_0, spect))
-            s = output[:, n_half:, :]
+            output = self.WN[k]((audio_0, spect))
+            log_s = output[:, n_half:, :]
             b = output[:, :n_half, :]
-            audio_1 = torch.exp(s)*audio_1 + b
-            s_list.append(s)
+            audio_1 = torch.exp(log_s)*audio_1 + b
+            log_s_list.append(log_s)
+
+            if k%2 != 0:
+                audio_0, audio_1 = audio_1, audio_0
+
+            audio = torch.cat([audio_0, audio_1],1)
 
-            if k%2 == 0:
-                audio = torch.cat([audio[:,:n_half,:], audio_1],1)
-            else:
-                audio = torch.cat([audio_1, audio[:,n_half:,:]], 1)
         output_audio.append(audio)
-        return torch.cat(output_audio,1), s_list, s_conv_list
-        """
+        return torch.cat(output_audio,1), log_s_list, log_det_W_list
+
     def infer(self, spect, sigma=1.0):
         spect = self.upsample(spect)
 
diff --git a/train.py b/train.py
index 5d50b9b..3df9c06 100644
--- a/train.py
+++ b/train.py
@@ -35,14 +35,27 @@
 #=====END: ADDED FOR DISTRIBUTED======
 from torch.utils.data import DataLoader
-from glow import WaveGlow, WaveGlowLoss
+if os.getenv('OLD_GLOW') == '1':
+    print("Warning! Using glow_old.py instead of glow.py for training")
+    from glow_old import WaveGlow
+else:
+    from glow import WaveGlow
+
+from glow import WaveGlowLoss
+
 from mel2samp import Mel2Samp
 
 def load_checkpoint(checkpoint_path, model, optimizer):
     assert os.path.isfile(checkpoint_path)
     checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
-    iteration = checkpoint_dict['iteration']
-    optimizer.load_state_dict(checkpoint_dict['optimizer'])
+
+    if 'iteration' in checkpoint_dict:
+        iteration = checkpoint_dict['iteration']
+    else:
+        iteration = 0
+
+    if 'optimizer' in checkpoint_dict:
+        optimizer.load_state_dict(checkpoint_dict['optimizer'])
+
     model_for_loading = checkpoint_dict['model']
     model.load_state_dict(model_for_loading.state_dict())
     print("Loaded checkpoint '{}' (iteration {})" .format(