mirror of https://github.com/coqui-ai/TTS.git
Update decoder accepting location sensitive attention
This commit is contained in:
parent
3348d462d2
commit
d43b4ccc1e
|
@ -223,7 +223,7 @@ class Decoder(nn.Module):
|
|||
self.r = r
|
||||
# memory -> |Prenet| -> processed_memory
|
||||
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
||||
# processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
|
||||
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
|
||||
self.attention_rnn = AttentionRNN(256, in_features, 128)
|
||||
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
||||
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
|
||||
|
@ -252,6 +252,7 @@ class Decoder(nn.Module):
|
|||
- memory: batch x #mels_pecs x mel_spec_dim
|
||||
"""
|
||||
B = inputs.size(0)
|
||||
T = inputs.size(1)
|
||||
# Run greedy decoding if memory is None
|
||||
greedy = not self.training
|
||||
if memory is not None:
|
||||
|
@ -269,11 +270,13 @@ class Decoder(nn.Module):
|
|||
for _ in range(len(self.decoder_rnns))]
|
||||
current_context_vec = inputs.data.new(B, 256).zero_()
|
||||
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
|
||||
attention_vec = memory.data.new(B, T).zero_()
|
||||
attention_vec_cum = memory.data.new(B, T).zero_()
|
||||
# Time first (T_decoder, B, memory_dim)
|
||||
if memory is not None:
|
||||
memory = memory.transpose(0, 1)
|
||||
outputs = []
|
||||
alignments = []
|
||||
attentions = []
|
||||
stop_tokens = []
|
||||
t = 0
|
||||
memory_input = initial_memory
|
||||
|
@ -286,8 +289,13 @@ class Decoder(nn.Module):
|
|||
# Prenet
|
||||
processed_memory = self.prenet(memory_input)
|
||||
# Attention RNN
|
||||
attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
|
||||
processed_memory, current_context_vec, attention_rnn_hidden, inputs)
|
||||
attention_vec_cat = torch.cat((attention_vec.unsqueeze(1),
|
||||
attention_vec_cum.unsqueeze(1)),
|
||||
dim=1)
|
||||
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
|
||||
processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_vec_cat)
|
||||
attention_vec_cum += attention_vec
|
||||
attention_vec_cum /= (t + 1)
|
||||
# Concat RNN output and attention context vector
|
||||
decoder_input = self.project_to_decoder_in(
|
||||
torch.cat((attention_rnn_hidden, current_context_vec), -1))
|
||||
|
@ -304,14 +312,14 @@ class Decoder(nn.Module):
|
|||
# predict stop token
|
||||
stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden)
|
||||
outputs += [output]
|
||||
alignments += [alignment]
|
||||
attentions += [attention]
|
||||
stop_tokens += [stop_token]
|
||||
t += 1
|
||||
if (not greedy and self.training) or (greedy and memory is not None):
|
||||
if t >= T_decoder:
|
||||
break
|
||||
else:
|
||||
if t > inputs.shape[1]/2 and stop_token > 0.8:
|
||||
if t > inputs.shape[1]/2 and stop_token > 0.6:
|
||||
break
|
||||
elif t > self.max_decoder_steps:
|
||||
print(" !! Decoder stopped with 'max_decoder_steps'. \
|
||||
|
@ -319,10 +327,10 @@ class Decoder(nn.Module):
|
|||
break
|
||||
assert greedy or len(outputs) == T_decoder
|
||||
# Back to batch first
|
||||
alignments = torch.stack(alignments).transpose(0, 1)
|
||||
attentions = torch.stack(attentions).transpose(0, 1)
|
||||
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
||||
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
||||
return outputs, alignments, stop_tokens
|
||||
return outputs, attentions, stop_tokens
|
||||
|
||||
|
||||
class StopNet(nn.Module):
|
||||
|
|
Loading…
Reference in New Issue