-
Notifications
You must be signed in to change notification settings - Fork 0
/
mlvae.py
123 lines (90 loc) · 5.13 KB
/
mlvae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import torch.nn as nn
from torch.nn import init
from collections import OrderedDict
import torch.nn.functional as F
from utils import reparameterize, group_wise_reparameterize, accumulate_group_evidence
from model.blocks import LinearNorm
from torch.autograd import Variable
# MLVAE network
class MLVAENet(nn.Module):
def __init__(self,model_config):
super(MLVAENet,self).__init__()
self.linear_model = nn.Sequential(OrderedDict([
('linear_1', nn.Linear(in_features=256, out_features=256, bias=True)),
('tan_h_1', nn.Tanh())
]))
# style
self.style_mu = LinearNorm(
256,
model_config["speaker_encoder"]["z_dim"])
self.style_logvar = LinearNorm(
256,
model_config["speaker_encoder"]["z_dim"])
# class
self.class_mu = LinearNorm(
256,
model_config["accent_encoder"]["z_dim"])
self.class_logvar = LinearNorm(
256,
model_config["accent_encoder"]["z_dim"])
# self.accent_class = nn.Sequential(
# LinearNorm(
# model_config["speaker_encoder"]["z_dim"],
# 32),
# nn.ReLU(),
# LinearNorm(32,
# model_config["accent_encoder"]["n_classes"])
# )
# self.speaker_class = nn.Sequential(
# LinearNorm(
# model_config["accent_encoder"]["z_dim"],
# 32),
# nn.ReLU(),
# LinearNorm(32,
# model_config["speaker_encoder"]["n_classes"])
# )
# weight initialization
for m in self.modules():
if type(m) == nn.Linear or type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
torch.nn.init.xavier_normal_(m.weight)
if m.bias.data is not None:
init.constant_(m.bias, 0)
def forward(self, x, labels_batch):
x = x.view(x.size(0), -1)
x = self.linear_model(x)
style_latent_space_mu = self.style_mu(x)
style_latent_space_logvar = self.style_logvar(x)
class_latent_space_mu = self.class_mu(x)
class_latent_space_logvar = self.class_logvar(x)
accent_latent_space_mu = class_latent_space_mu.clone()
accent_latent_space_logvar = class_latent_space_logvar.clone()
grouped_mu, grouped_logvar = accumulate_group_evidence(class_latent_space_mu.data, class_latent_space_logvar.data, labels_batch,is_cuda=True)
style_latent_embeddings = reparameterize(training=True, mu=style_latent_space_mu, logvar=style_latent_space_logvar)
accent_latent_embeddings = reparameterize(training=True, mu=accent_latent_space_mu, logvar=accent_latent_space_logvar)
class_latent_embeddings = group_wise_reparameterize(training=True, mu=grouped_mu, logvar=grouped_logvar, labels_batch=labels_batch,is_cuda=True)
# acc_prob = self.accent_class(style_latent_embeddings) #acc classifier on spk embs, for adv loss
# spk_prob = self.speaker_class(class_latent_embeddings) #spk classifier on acc embs,
# z = torch.cat((style_latent_embeddings, class_latent_embeddings), dim=1)
return (class_latent_embeddings, style_latent_embeddings, accent_latent_embeddings, (grouped_mu, grouped_logvar, style_latent_space_mu, style_latent_space_logvar, accent_latent_space_mu, accent_latent_space_logvar))
def inference(self,x,labels_batch):
x = x.view(x.size(0),-1)
x = self.linear_model(x)
style_latent_space_mu = self.style_mu(x)
style_latent_space_logvar = self.style_logvar(x)
class_latent_space_mu = self.class_mu(x)
class_latent_space_logvar = self.class_logvar(x)
accent_latent_space_mu = class_latent_space_mu.clone()
accent_latent_space_logvar = class_latent_space_logvar.clone()
grouped_mu, grouped_logvar = accumulate_group_evidence(
class_latent_space_mu.data, class_latent_space_logvar.data, labels_batch,is_cuda=True)
style_latent_embeddings = reparameterize(
training=False, mu=style_latent_space_mu, logvar=style_latent_space_logvar)
accent_latent_embeddings = reparameterize(training=False, mu=accent_latent_space_mu, logvar=accent_latent_space_logvar)
class_latent_embeddings = group_wise_reparameterize(
training=False, mu=grouped_mu, logvar=grouped_logvar, labels_batch=labels_batch,is_cuda=True)
# cat_prob = self.categorical_layer(class_latent_embeddings)
# acc_prob = self.accent_class(style_latent_embeddings) #acc classifier on spk embs, for adv loss
# spk_prob = self.speaker_class(class_latent_embeddings) #spk classifier on acc embs,
# z = torch.cat((style_latent_embeddings, class_latent_embeddings), dim=1)
return (class_latent_embeddings, style_latent_embeddings, accent_latent_embeddings, (grouped_mu, grouped_logvar, style_latent_space_mu, style_latent_space_logvar, accent_latent_space_mu, accent_latent_space_logvar))