Option to continue training from the official checkpoint #99

Open · wants to merge 3 commits into master
24 changes: 24 additions & 0 deletions README.md
@@ -73,6 +73,30 @@ with fused residual and skip connections.
python3 inference.py -f mel_files.txt -w checkpoints/waveglow_10000 -o . --is_fp16 -s 0.6
```

## Finetuning the official checkpoint with your own data

The "official" checkpoint above was trained using an older version of the code.
Therefore, you need to use `glow_old.py` to continue training from the official
checkpoint:

1. Download our [published model]
2. Update the checkpoint to comply with recent code modifications:

`python convert_model.py waveglow_old.pt waveglow_old_updated.pt`

3. Perform steps 1 and 2 from the section above

4. Set `"checkpoint_path": "./waveglow_old_updated.pt"` in `config.json` (a scripted version of this edit is sketched just after this list)

5. Train your WaveGlow networks with `OLD_GLOW=1` (not yet tested with
`distributed.py`)

```command
mkdir checkpoints
OLD_GLOW=1 python train.py -c config.json
```
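
For step 4, the edit can also be scripted. A minimal sketch, assuming `checkpoint_path` sits under the `train_config` section of the repository's `config.json` (adjust the key path if your config is laid out differently):

```python
import json

with open("config.json") as f:
    config = json.load(f)

# Point training at the converted official checkpoint (step 4 above).
config["train_config"]["checkpoint_path"] = "./waveglow_old_updated.pt"

with open("config.json", "w") as f:
    json.dump(config, f, indent=4)
```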


[//]: # (TODO)
[//]: # (PROVIDE INSTRUCTIONS FOR DOWNLOADING LJS)
[pytorch 1.0]: https://github.com/pytorch/pytorch#installation
2 changes: 1 addition & 1 deletion glow.py
@@ -97,7 +97,7 @@ def forward(self, z, reverse=False):
return z
else:
# Forward computation
log_det_W = batch_size * n_of_groups * torch.logdet(W)
log_det_W = batch_size * n_of_groups * torch.slogdet(W)[1]
z = self.conv(z)
return z, log_det_W
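
The switch from `torch.logdet` to `torch.slogdet(W)[1]` matters because `torch.logdet` returns NaN for a real matrix with a negative determinant, while `torch.slogdet` returns the pair `(sign, log|det|)` and stays finite. A minimal illustration, independent of the PR:

```python
import torch

# A permutation matrix with determinant -1: logdet is NaN, but the second
# element of slogdet still gives log|det W| = 0.
W = torch.tensor([[0., 1.],
                  [1., 0.]])
print(torch.logdet(W))       # tensor(nan)
print(torch.slogdet(W)[1])   # tensor(0.)
```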

43 changes: 21 additions & 22 deletions glow_old.py
@@ -116,11 +116,9 @@ def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
self.n_remaining_channels = n_remaining_channels # Useful during inference

def forward(self, forward_input):
return None
"""
forward_input[0] = audio: batch x time
forward_input[1] = upsamp_spectrogram: batch x n_cond_channels x time
"""
forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
forward_input[1] = audio: batch x time
"""
spect, audio = forward_input

@@ -135,39 +133,40 @@ def forward(self, forward_input):

audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
output_audio = []
s_list = []
s_conv_list = []
log_s_list = []
log_det_W_list = []

for k in range(self.n_flows):
if k%4 == 0 and k > 0:
output_audio.append(audio[:,:self.n_multi,:])
audio = audio[:,self.n_multi:,:]
if k % self.n_early_every == 0 and k > 0:
output_audio.append(audio[:,:self.n_early_size,:])
audio = audio[:,self.n_early_size:,:]

# project to new basis
audio, s = self.convinv[k](audio)
s_conv_list.append(s)
audio, log_det_W = self.convinv[k](audio)
log_det_W_list.append(log_det_W)

n_half = int(audio.size(1)/2)

if k%2 == 0:
audio_0 = audio[:,:n_half,:]
audio_1 = audio[:,n_half:,:]
else:
audio_1 = audio[:,:n_half,:]
audio_0 = audio[:,n_half:,:]

output = self.nn[k]((audio_0, spect))
s = output[:, n_half:, :]
output = self.WN[k]((audio_0, spect))
log_s = output[:, n_half:, :]
b = output[:, :n_half, :]
audio_1 = torch.exp(s)*audio_1 + b
s_list.append(s)
audio_1 = torch.exp(log_s)*audio_1 + b
log_s_list.append(log_s)

if k%2 != 0:
audio_0, audio_1 = audio_1, audio_0

audio = torch.cat([audio_0, audio_1],1)

if k%2 == 0:
audio = torch.cat([audio[:,:n_half,:], audio_1],1)
else:
audio = torch.cat([audio_1, audio[:,n_half:,:]], 1)
output_audio.append(audio)
return torch.cat(output_audio,1), s_list, s_conv_list
"""
return torch.cat(output_audio,1), log_s_list, log_det_W_list


def infer(self, spect, sigma=1.0):
spect = self.upsample(spect)
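
For context on the new return values: the triple `(output, log_s_list, log_det_W_list)` is exactly what a Glow-style negative log-likelihood consumes. Below is a minimal sketch of that loss under the standard WaveGlow formulation; the repository's own `WaveGlowLoss` in `glow.py` may differ in details, so treat it as illustrative only.

```python
import torch

def waveglow_nll(z, log_s_list, log_det_W_list, sigma=1.0):
    # Gaussian prior term plus the log-determinant terms contributed by the
    # affine coupling layers (log_s) and the invertible 1x1 convolutions.
    log_s_total = sum(torch.sum(log_s) for log_s in log_s_list)
    log_det_W_total = sum(log_det_W_list)
    loss = torch.sum(z * z) / (2 * sigma * sigma) - log_s_total - log_det_W_total
    return loss / (z.size(0) * z.size(1) * z.size(2))
```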
19 changes: 16 additions & 3 deletions train.py
@@ -35,14 +35,27 @@
#=====END: ADDED FOR DISTRIBUTED======

from torch.utils.data import DataLoader
from glow import WaveGlow, WaveGlowLoss
if os.getenv('OLD_GLOW') == '1':
print("Warning! Using old_glow.py instead of glow.py for training")
from glow_old import WaveGlow
else:
from glow import WaveGlow

from glow import WaveGlowLoss
from mel2samp import Mel2Samp

def load_checkpoint(checkpoint_path, model, optimizer):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
iteration = checkpoint_dict['iteration']
optimizer.load_state_dict(checkpoint_dict['optimizer'])

if 'iteration' in checkpoint_dict:
iteration = checkpoint_dict['iteration']
else:
iteration = 0

if 'optimizer' in checkpoint_dict:
optimizer.load_state_dict(checkpoint_dict['optimizer'])

model_for_loading = checkpoint_dict['model']
model.load_state_dict(model_for_loading.state_dict())
print("Loaded checkpoint '{}' (iteration {})" .format(