mirror of https://github.com/coqui-ai/TTS.git
update separate stopnet flow to make it faster.
parent 788c8100ba · commit e62659da94
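In the "separate stopnet" flow, the stop-token predictor (stopnet) is fed a detached copy of the decoder output, so its gradients never reach the decoder and the main loss needs only a single backward pass (no retain_graph). A minimal sketch of the pattern, with toy modules and shapes standing in for the real ones:

    import torch
    import torch.nn as nn

    decoder_output = torch.randn(4, 80, requires_grad=True)  # stands in for a decoder frame
    stopnet = nn.Linear(80, 1)                                # stands in for the real stopnet

    separate_stopnet = True
    stopnet_input = decoder_output.detach() if separate_stopnet else decoder_output
    stop_token = stopnet(stopnet_input)

    # With the detached input, stop-token gradients update only the stopnet's own
    # parameters; the spectrogram loss can then be backpropagated in one pass.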
@@ -302,13 +302,14 @@ class Decoder(nn.Module):
     """

     def __init__(self, in_features, memory_dim, r, memory_size,
-                 attn_windowing, attn_norm):
+                 attn_windowing, attn_norm, separate_stopnet):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
         self.max_decoder_steps = 500
         self.memory_size = memory_size if memory_size > 0 else r
         self.memory_dim = memory_dim
+        self.separate_stopnet = separate_stopnet
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(
             memory_dim * self.memory_size, out_features=[256, 128])
@@ -415,6 +416,9 @@ class Decoder(nn.Module):
         # predict stop token
         stopnet_input = torch.cat([decoder_output, output], -1)
         del decoder_output
+        if self.separate_stopnet:
+            stop_token = self.stopnet(stopnet_input.detach())
+        else:
             stop_token = self.stopnet(stopnet_input)
         return output, stop_token, self.attention
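A quick way to convince yourself that the .detach() above really isolates the stopnet from the decoder graph (toy tensors, purely illustrative):

    import torch

    x = torch.randn(3, requires_grad=True)   # pretend this is a decoder parameter
    y = x * 2.0                               # pretend this is the decoder output

    loss = y.sum() + y.detach().sum()         # second term mimics stopnet_input.detach()
    loss.backward()
    print(x.grad)                             # tensor([2., 2., 2.]) -- only the first term
                                              # contributed; the detached branch added nothing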
@@ -44,7 +44,7 @@ class LinearBN(nn.Module):

     def forward(self, x):
         out = self.linear_layer(x)
-        if len(out.shape)==3:
+        if len(out.shape) == 3:
             out = out.permute(1, 2, 0)
         out = self.bn(out)
         if len(out.shape) == 3:
@@ -53,7 +53,11 @@ class LinearBN(nn.Module):


 class Prenet(nn.Module):
-    def __init__(self, in_features, prenet_type, prenet_dropout, out_features=[256, 256]):
+    def __init__(self,
+                 in_features,
+                 prenet_type,
+                 prenet_dropout,
+                 out_features=[256, 256]):
         super(Prenet, self).__init__()
         self.prenet_type = prenet_type
         self.prenet_dropout = prenet_dropout
@@ -64,8 +68,8 @@ class Prenet(nn.Module):
                 for (in_size, out_size) in zip(in_features, out_features)
             ])
         elif prenet_type == "original":
-            self.layers = nn.ModuleList(
-                [Linear(in_size, out_size, bias=False)
+            self.layers = nn.ModuleList([
+                Linear(in_size, out_size, bias=False)
                 for (in_size, out_size) in zip(in_features, out_features)
             ])
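For orientation, the Prenet being reformatted here is a stack of fully connected layers applied to the stacked memory frames, selected by prenet_type ("original" uses plain linear layers, "bn" the LinearBN variant above). A simplified, self-contained stand-in using vanilla nn.Linear (the repo's own Linear/LinearBN wrappers add weight init and batch norm) might look like:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SimplePrenet(nn.Module):
        """Simplified stand-in for the repo's Prenet: Linear -> ReLU -> Dropout per layer."""

        def __init__(self, in_features, out_features=(256, 128), dropout=0.5):
            super().__init__()
            sizes = [in_features] + list(out_features)
            self.layers = nn.ModuleList(
                [nn.Linear(i, o, bias=False) for i, o in zip(sizes[:-1], sizes[1:])])
            self.dropout = dropout

        def forward(self, x):
            for layer in self.layers:
                x = F.dropout(F.relu(layer(x)), p=self.dropout, training=self.training)
            return x

    prenet = SimplePrenet(80 * 5)           # e.g. memory_dim * memory_size
    out = prenet(torch.randn(4, 80 * 5))    # -> shape (4, 128)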
@@ -121,9 +125,10 @@ class LocationLayer(nn.Module):


 class Attention(nn.Module):
-    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, location_attention,
-                 attention_location_n_filters, attention_location_kernel_size,
-                 windowing, norm, forward_attn, trans_agent):
+    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
+                 location_attention, attention_location_n_filters,
+                 attention_location_kernel_size, windowing, norm, forward_attn,
+                 trans_agent):
         super(Attention, self).__init__()
         self.query_layer = Linear(
             attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
@@ -131,10 +136,11 @@ class Attention(nn.Module):
             embedding_dim, attention_dim, bias=False, init_gain='tanh')
         self.v = Linear(attention_dim, 1, bias=True)
         if trans_agent:
-            self.ta = nn.Linear(attention_rnn_dim + embedding_dim, 1, bias=True)
+            self.ta = nn.Linear(
+                attention_rnn_dim + embedding_dim, 1, bias=True)
         if location_attention:
-            self.location_layer = LocationLayer(attention_location_n_filters,
-                                                attention_location_kernel_size,
+            self.location_layer = LocationLayer(
+                attention_location_n_filters, attention_location_kernel_size,
                 attention_dim)
         self._mask_value = -float("inf")
         self.windowing = windowing
@@ -152,7 +158,9 @@ class Attention(nn.Module):
     def init_forward_attn(self, inputs):
         B = inputs.shape[0]
         T = inputs.shape[1]
-        self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7 ], dim=1).to(inputs.device)
+        self.alpha = torch.cat(
+            [torch.ones([B, 1]),
+             torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device)
         self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)

     def init_location_attention(self, inputs):
@@ -188,8 +196,7 @@ class Attention(nn.Module):

     def get_attention(self, query, processed_inputs):
         processed_query = self.query_layer(query.unsqueeze(1))
-        energies = self.v(
-            torch.tanh(processed_query +processed_inputs))
+        energies = self.v(torch.tanh(processed_query + processed_inputs))
         energies = energies.squeeze(-1)
         return energies, processed_query
@@ -210,8 +217,10 @@ class Attention(nn.Module):

     def apply_forward_attention(self, inputs, alignment, query):
         # forward attention
-        prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
-        alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-8) * alignment
+        prev_alpha = F.pad(self.alpha[:, :-1].clone(),
+                           (1, 0, 0, 0)).to(inputs.device)
+        alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
+                  self.u * prev_alpha) + 1e-8) * alignment
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
         context = torch.bmm(self.alpha.unsqueeze(1), inputs)
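The reflowed lines implement the forward-attention recursion: the previous weights are shifted one step toward the end of the input, mixed with the unshifted weights by the transition probability u, multiplied by the current alignment and renormalized. The same update as standalone tensor ops (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    B, T = 2, 6
    alpha = F.softmax(torch.randn(B, T), dim=1)       # previous attention weights
    alignment = F.softmax(torch.randn(B, T), dim=1)   # current attention energies
    u = torch.full((B, 1), 0.5)                       # transition-agent probability

    prev_alpha = F.pad(alpha[:, :-1], (1, 0, 0, 0))   # shift weights one step forward
    alpha = (((1 - u) * alpha + u * prev_alpha) + 1e-8) * alignment
    alpha = alpha / alpha.sum(dim=1, keepdim=True)    # renormalize over the input axis
    context = torch.bmm(alpha.unsqueeze(1), torch.randn(B, T, 4)).squeeze(1)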
@@ -222,8 +231,7 @@ class Attention(nn.Module):
             self.u = torch.sigmoid(self.ta(ta_input))
         return context, self.alpha

-    def forward(self, attention_hidden_state, inputs, processed_inputs,
-                mask):
+    def forward(self, attention_hidden_state, inputs, processed_inputs, mask):
         if self.location_attention:
             attention, processed_query = self.get_location_attention(
                 attention_hidden_state, processed_inputs)
@@ -248,7 +256,8 @@ class Attention(nn.Module):
             self.update_location_attention(alignment)
         # apply forward attention if enabled
         if self.forward_attn:
-            context, self.attention_weights = self.apply_forward_attention(inputs, alignment, attention_hidden_state)
+            context, self.attention_weights = self.apply_forward_attention(
+                inputs, alignment, attention_hidden_state)
         else:
             context = torch.bmm(alignment.unsqueeze(1), inputs)
             context = context.squeeze(1)
@@ -321,13 +330,17 @@ class Encoder(nn.Module):
         outputs, self.rnn_state = self.lstm(x, self.rnn_state)
         return outputs


 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
-    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, location_attn):
+    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm,
+                 prenet_type, prenet_dropout, forward_attn, trans_agent,
+                 location_attn, separate_stopnet):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
         self.r = r
         self.encoder_embedding_dim = in_features
+        self.separate_stopnet = separate_stopnet
         self.attention_rnn_dim = 1024
         self.decoder_rnn_dim = 1024
         self.prenet_dim = 256
@@ -336,14 +349,16 @@ class Decoder(nn.Module):
         self.p_attention_dropout = 0.1
         self.p_decoder_dropout = 0.1

-        self.prenet = Prenet(self.mel_channels * r, prenet_type, prenet_dropout,
+        self.prenet = Prenet(self.mel_channels * r, prenet_type,
+                             prenet_dropout,
                              [self.prenet_dim, self.prenet_dim])

         self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                          self.attention_rnn_dim)

-        self.attention_layer = Attention(self.attention_rnn_dim, in_features, 128, location_attn,
-                                         32, 31, attn_win, attn_norm, forward_attn, trans_agent)
+        self.attention_layer = Attention(self.attention_rnn_dim, in_features,
+                                         128, location_attn, 32, 31, attn_win,
+                                         attn_norm, forward_attn, trans_agent)

         self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                        self.decoder_rnn_dim, 1)
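Reading the reflowed Attention(...) call against the Attention signature earlier in the diff, the bare numbers map to attention_dim=128, attention_location_n_filters=32 and attention_location_kernel_size=31. The same call annotated with keyword names purely for clarity (the repo passes them positionally; this is not additional code to run):

    self.attention_layer = Attention(
        attention_rnn_dim=self.attention_rnn_dim,
        embedding_dim=in_features,
        attention_dim=128,
        location_attention=location_attn,
        attention_location_n_filters=32,
        attention_location_kernel_size=31,
        windowing=attn_win,
        norm=attn_norm,
        forward_attn=forward_attn,
        trans_agent=trans_agent)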
@@ -353,7 +368,8 @@ class Decoder(nn.Module):

         self.stopnet = nn.Sequential(
             nn.Dropout(0.1),
-            Linear(self.decoder_rnn_dim + self.mel_channels * r,
+            Linear(
+                self.decoder_rnn_dim + self.mel_channels * r,
                 1,
                 bias=True,
                 init_gain='sigmoid'))
@@ -401,8 +417,7 @@ class Decoder(nn.Module):
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
         stop_tokens = stop_tokens.contiguous()
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        outputs = outputs.view(
-            outputs.size(0), -1, self.mel_channels)
+        outputs = outputs.view(outputs.size(0), -1, self.mel_channels)
         outputs = outputs.transpose(1, 2)
         return outputs, stop_tokens, alignments
@@ -415,12 +430,10 @@ class Decoder(nn.Module):
         self.attention_cell = F.dropout(
             self.attention_cell, self.p_attention_dropout, self.training)

-        self.context = self.attention_layer(
-            self.attention_hidden, self.inputs, self.processed_inputs,
-            self.mask)
+        self.context = self.attention_layer(self.attention_hidden, self.inputs,
+                                            self.processed_inputs, self.mask)

-        memory = torch.cat(
-            (self.attention_hidden, self.context), -1)
+        memory = torch.cat((self.attention_hidden, self.context), -1)
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
             memory, (self.decoder_hidden, self.decoder_cell))
         self.decoder_hidden = F.dropout(self.decoder_hidden,
@@ -428,16 +441,18 @@ class Decoder(nn.Module):
         self.decoder_cell = F.dropout(self.decoder_cell,
                                       self.p_decoder_dropout, self.training)

-        decoder_hidden_context = torch.cat(
-            (self.decoder_hidden, self.context), dim=1)
+        decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
+                                           dim=1)

-        decoder_output = self.linear_projection(
-            decoder_hidden_context)
+        decoder_output = self.linear_projection(decoder_hidden_context)

         stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)

-        gate_prediction = self.stopnet(stopnet_input)
-        return decoder_output, gate_prediction, self.attention_layer.attention_weights
+        if self.separate_stopnet:
+            stop_token = self.stopnet(stopnet_input.detach())
+        else:
+            stop_token = self.stopnet(stopnet_input)
+        return decoder_output, stop_token, self.attention_layer.attention_weights

     def forward(self, inputs, memories, mask):
         memory = self.get_go_frame(inputs).unsqueeze(0)
@@ -451,8 +466,7 @@ class Decoder(nn.Module):
         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
             memory = memories[len(outputs)]
-            mel_output, stop_token, attention_weights = self.decode(
-                memory)
+            mel_output, stop_token, attention_weights = self.decode(memory)
             outputs += [mel_output.squeeze(1)]
             stop_tokens += [stop_token.squeeze(1)]
             alignments += [attention_weights]
@@ -481,7 +495,8 @@ class Decoder(nn.Module):
             alignments += [alignment]

             stop_flags[0] = stop_flags[0] or stop_token > 0.5
-            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8 and t > inputs.shape[1])
+            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8
+                                              and t > inputs.shape[1])
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 stop_count += 1
@@ -523,7 +538,8 @@ class Decoder(nn.Module):
             alignments += [alignment]

             stop_flags[0] = stop_flags[0] or stop_token > 0.5
-            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5 and t > inputs.shape[1])
+            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5
+                                              and t > inputs.shape[1])
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 stop_count += 1
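Both inference loops above use the same sticky stopping flags: the stopnet has fired at least once, the attention mass has reached the last input positions after more than inputs.shape[1] decoder steps (threshold 0.8 in one loop, 0.5 in the truncated variant), and more than twice as many decoder steps as input tokens have been taken; only when all three hold does stop_count start accumulating toward termination. Condensed into a helper (the function itself is hypothetical, not part of the repo):

    def update_stop_flags(stop_flags, stop_token, alignment, t, in_len, attn_thresh=0.5):
        """Sticky stop criteria used by the decoder's inference loops (condensed)."""
        stop_flags[0] = stop_flags[0] or stop_token > 0.5                  # stopnet fired
        stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > attn_thresh
                                          and t > in_len)                  # attention at the end
        stop_flags[2] = t > in_len * 2                                     # enough decoder steps taken
        return all(stop_flags)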
@@ -541,7 +557,6 @@ class Decoder(nn.Module):

         return outputs, stop_tokens, alignments

-
     def inference_step(self, inputs, t, memory=None):
         """
         For debug purposes
@@ -15,7 +15,8 @@ class Tacotron(nn.Module):
                  padding_idx=None,
                  memory_size=5,
                  attn_win=False,
-                 attn_norm="sigmoid"):
+                 attn_norm="sigmoid",
+                 separate_stopnet=True):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
@@ -23,7 +24,8 @@ class Tacotron(nn.Module):
         self.embedding = nn.Embedding(num_chars, 256, padding_idx=padding_idx)
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(256)
-        self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win, attn_norm)
+        self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win,
+                               attn_norm, separate_stopnet)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
@@ -9,7 +9,17 @@ from utils.generic_utils import sequence_mask

 # TODO: match function arguments with tacotron
 class Tacotron2(nn.Module):
-    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", prenet_dropout=True, forward_attn=False, trans_agent=False, location_attn=True):
+    def __init__(self,
+                 num_chars,
+                 r,
+                 attn_win=False,
+                 attn_norm="softmax",
+                 prenet_type="original",
+                 prenet_dropout=True,
+                 forward_attn=False,
+                 trans_agent=False,
+                 location_attn=True,
+                 separate_stopnet=True):
         super(Tacotron2, self).__init__()
         self.n_mel_channels = 80
         self.n_frames_per_step = r
@@ -18,7 +28,10 @@ class Tacotron2(nn.Module):
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, location_attn)
+        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win,
+                               attn_norm, prenet_type, prenet_dropout,
+                               forward_attn, trans_agent, location_attn,
+                               separate_stopnet)
         self.postnet = Postnet(self.n_mel_channels)

     def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
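Both model constructors now accept separate_stopnet (default True). Assuming the module's imports are in scope, instantiating Tacotron2 with the new flag looks roughly like this (the num_chars value is illustrative):

    model = Tacotron2(
        num_chars=61,              # illustrative symbol-set size
        r=5,
        attn_win=False,
        attn_norm="softmax",
        prenet_type="original",
        prenet_dropout=True,
        forward_attn=False,
        trans_agent=False,
        location_attn=True,
        separate_stopnet=False)    # False: stop-token loss also trains the decoder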
@@ -50,14 +63,14 @@ class Tacotron2(nn.Module):
             mel_outputs, mel_outputs_postnet, alignments)
         return mel_outputs, mel_outputs_postnet, alignments, stop_tokens

-
     def inference_truncated(self, text):
         """
         Preserve model states for continuous inference
         """
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
-        mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(encoder_outputs)
+        mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(
+            encoder_outputs)
         mel_outputs_postnet = self.postnet(mel_outputs)
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
         mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
@@ -57,6 +57,15 @@ def test_phoneme_to_sequence():
     print(len(sequence))
     assert text_hat == gt

+    # padding char
+    text = "_Be a _voice, not an! echo_"
+    sequence = phoneme_to_sequence(text, text_cleaner, lang)
+    text_hat = sequence_to_phoneme(sequence)
+    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
+    print(text_hat)
+    print(len(sequence))
+    assert text_hat == gt
+

 def test_text2phone():
     text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"

train.py
@@ -141,10 +141,6 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if not c.separate_stopnet and c.stopnet:
             loss += stop_loss

-        # backpass and check the grad norm for spec losses
-        if c.separate_stopnet:
-            loss.backward(retain_graph=True)
-        else:
-            loss.backward()
+        loss.backward()
         optimizer, current_lr = weight_decay(optimizer, c.wd)
         grad_norm, _ = check_update(model, c.grad_clip)
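Because the decoder already detaches the stopnet input when c.separate_stopnet is set, the spectrogram loss and the stop-token loss live in disjoint parts of the graph, so the retain_graph=True branch above becomes unnecessary. A rough sketch of how a step can then be wired, assuming the two optimizers from the train() signature (optimizer for the model, optimizer_st for the stopnet); the exact surrounding code is not shown in this diff:

    # Illustrative training step, not the repo's exact code.
    optimizer.zero_grad()
    optimizer_st.zero_grad()

    decoder_output, postnet_output, alignments, stop_tokens = model(text, mel_specs)
    loss = criterion(decoder_output, mel_specs) + criterion(postnet_output, mel_specs)
    stop_loss = criterion_st(stop_tokens, stop_targets)

    if not c.separate_stopnet and c.stopnet:
        loss += stop_loss            # joint training: stop loss flows into the decoder

    loss.backward()                  # single pass; no retain_graph needed
    optimizer.step()

    if c.separate_stopnet:
        stop_loss.backward()         # reaches only the stopnet, thanks to .detach()
        optimizer_st.step()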
@@ -253,7 +253,8 @@ def setup_model(num_chars, c):
             r=c.r,
             attn_win=c.windowing,
             attn_norm=c.attention_norm,
-            memory_size=c.memory_size)
+            memory_size=c.memory_size,
+            separate_stopnet=c.separate_stopnet)
     elif c.model.lower() == "tacotron2":
         model = MyModel(
             num_chars=num_chars,
@@ -264,5 +265,6 @@ def setup_model(num_chars, c):
             prenet_dropout=c.prenet_dropout,
             forward_attn=c.use_forward_attn,
             trans_agent=c.transition_agent,
-            location_attn=c.location_attn)
+            location_attn=c.location_attn,
+            separate_stopnet=c.separate_stopnet)
     return model
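setup_model pulls every one of these constructor arguments off the config object c, so enabling or disabling the feature is a config change. A minimal illustrative stand-in covering the attributes visible in these two hunks (the real config has more fields, and the container type here is just for experimenting):

    from types import SimpleNamespace

    c = SimpleNamespace(
        model="Tacotron2",
        r=5,
        windowing=False,
        attention_norm="softmax",
        memory_size=5,
        prenet_dropout=True,
        use_forward_attn=False,
        transition_agent=False,
        location_attn=True,
        separate_stopnet=True)   # set False to let the stop loss train the decoder too

    # model = setup_model(num_chars, c)   # as in the hunks above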
@@ -136,8 +136,8 @@ def _arpabet_to_sequence(text):


 def _should_keep_symbol(s):
-    return s in _symbol_to_id and s is not '_' and s is not '~'
+    return s in _symbol_to_id and s not in ['~', '^', '_']


 def _should_keep_phoneme(p):
-    return p in _phonemes_to_id and p is not '_' and p is not '~'
+    return p in _phonemes_to_id and p not in ['~', '^', '_']
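The last two hunks also fix a latent bug: s is not '_' compares object identity, not value. It only appeared to work because CPython happens to cache one-character strings, and newer Python versions warn about is with a literal; the membership test compares by value and additionally filters the '^' symbol. For illustration:

    a = "ab"
    b = "".join(["a", "b"])            # equal value, but a freshly built object
    print(a == b)                      # True
    print(a is b)                      # False in CPython -- identity is not value equality
    print("^" not in ["~", "^", "_"])  # False: the new check also filters '^'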