mirror of https://github.com/coqui-ai/TTS.git

Pass mask instead of length to model

commit e0bce1d2d1
parent 2196ce9eba
@@ -105,13 +105,7 @@ class AttentionRNNCell(nn.Module):
                 'b' (Bahdanau) or 'ls' (Location Sensitive)."
                 .format(align_model))
 
-    def forward(self,
-                memory,
-                context,
-                rnn_state,
-                annots,
-                atten,
-                annot_lens=None):
+    def forward(self, memory, context, rnn_state, annots, atten, mask):
         """
         Shapes:
             - memory: (batch, 1, dim) or (batch, dim)
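The signature change above is the heart of the commit: the attention cell no longer receives raw annotation lengths and builds the mask itself; the caller computes the mask once and passes it down. The old code called a sequence_mask helper for this. A minimal sketch of such a helper, assuming lengths arrive as a LongTensor (the repo's own utility may differ in detail):

    import torch

    def sequence_mask(sequence_length, max_len=None):
        """Build a boolean padding mask from per-sample lengths.

        Returns shape (batch, max_len); True marks valid (non-padded) steps.
        """
        if max_len is None:
            max_len = int(sequence_length.max().item())
        seq_range = torch.arange(max_len, device=sequence_length.device)
        # (1, max_len) < (batch, 1) broadcasts to (batch, max_len)
        return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)

    # e.g. lengths (3, 5) -> [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
    print(sequence_mask(torch.tensor([3, 5])).int())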
@@ -119,7 +113,7 @@ class AttentionRNNCell(nn.Module):
             - rnn_state: (batch, out_dim)
             - annots: (batch, max_time, annot_dim)
             - atten: (batch, 2, max_time)
-            - annot_lens: (batch,)
+            - mask: (batch,)
         """
         # Concat input query and previous context context
         rnn_input = torch.cat((memory, context), -1)
@@ -133,8 +127,7 @@ class AttentionRNNCell(nn.Module):
             alignment = self.alignment_model(annots, rnn_output)
         else:
             alignment = self.alignment_model(annots, rnn_output, atten)
-        if annot_lens is not None:
-            mask = sequence_mask(annot_lens)
+        if mask is not None:
             mask = mask.view(memory.size(0), -1)
             alignment.masked_fill_(1 - mask, -float("inf"))
         # Normalize context weight
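With the mask precomputed, the cell only reshapes and applies it. A small self-contained sketch of what the masking accomplishes, with hypothetical toy tensors; `~mask` stands in for the diff's `1 - mask`, which targeted the uint8 masks of older PyTorch versions:

    import torch
    import torch.nn.functional as F

    # Toy alignment scores for a batch of 2 over max_time 5.
    alignment = torch.randn(2, 5)
    mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 1, 1, 1]], dtype=torch.bool)

    # Padded steps get -inf so softmax assigns them zero attention weight.
    alignment = alignment.masked_fill(~mask, -float("inf"))
    weights = F.softmax(alignment, dim=1)
    print(weights)  # rows sum to 1; masked positions are exactly 0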
@@ -189,7 +189,7 @@ class CBHG(nn.Module):
             x = highway(x)
         # (B, T_in, hid_features*2)
         # TODO: replace GRU with convolution as in Deep Voice 3
-        # self.gru.flatten_parameters()
+        self.gru.flatten_parameters()
         outputs, _ = self.gru(x)
         return outputs
 
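Unrelated to masking, this hunk also re-enables flatten_parameters() before the GRU call. A standalone illustration with hypothetical sizes:

    import torch
    import torch.nn as nn

    gru = nn.GRU(128, 128, batch_first=True, bidirectional=True)
    x = torch.randn(4, 50, 128)  # (batch, time, features), hypothetical sizes

    # After weights are copied around (e.g. by nn.DataParallel), cuDNN's RNN
    # weights may no longer sit in one contiguous block; flattening them again
    # avoids the "weights are not part of a single contiguous chunk" warning.
    gru.flatten_parameters()
    outputs, _ = gru(x)  # (4, 50, 256): forward and backward states concatenated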
@@ -268,7 +268,7 @@ class Decoder(nn.Module):
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
         self.stopnet = StopNet(r, memory_dim)
 
-    def forward(self, inputs, memory=None, input_lens=None):
+    def forward(self, inputs, memory=None, mask=None):
         """
         Decoder forward step.
 
@@ -280,7 +280,7 @@ class Decoder(nn.Module):
             memory (None): Decoder memory (autoregression. If None (at eval-time),
                 decoder outputs are used as decoder inputs. If None, it uses the last
                 output as the input.
-            input_lens (None): Time length of each input in batch.
+            mask (None): Attention mask for sequence padding.
 
         Shapes:
             - inputs: batch x time x encoder_out_dim
@@ -332,7 +332,7 @@ class Decoder(nn.Module):
                 (attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
                 processed_memory, current_context_vec, attention_rnn_hidden,
-                inputs, attention_cat, input_lens)
+                inputs, attention_cat, mask)
             attention_cum += attention
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
@@ -25,14 +25,14 @@ class Tacotron(nn.Module):
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Linear(256, linear_dim)
 
-    def forward(self, characters, mel_specs=None, text_lens=None):
+    def forward(self, characters, mel_specs=None, mask=None):
         B = characters.size(0)
         inputs = self.embedding(characters)
         # batch x time x dim
         encoder_outputs = self.encoder(inputs)
         # batch x time x dim*r
         mel_outputs, alignments, stop_tokens = self.decoder(
-            encoder_outputs, mel_specs, text_lens)
+            encoder_outputs, mel_specs, mask)
         # Reshape
         # batch x time x dim
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
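Taken together, the training loop now owns the mask: it builds it once from the text lengths and threads it through Tacotron.forward, the Decoder, and the AttentionRNNCell. A hedged sketch of the call site; `model`, `text_lens`, and the sequence_mask helper sketched earlier are assumptions, not code from this commit:

    import torch

    # Hypothetical call site after this commit. `model` is an assumed,
    # already-constructed Tacotron instance.
    characters = torch.randint(0, 60, (8, 42))   # (batch, max_text_len)
    text_lens = torch.randint(10, 43, (8,))      # valid length of each sample
    mel_specs = torch.randn(8, 120, 80)          # (batch, frames, mel_dim)

    mask = sequence_mask(text_lens)              # (batch, max_text_len), built once
    outputs = model(characters, mel_specs, mask) # mask replaces text_lens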