mirror of https://github.com/coqui-ai/TTS.git
Update decoder to accept location-sensitive attention
parent 3348d462d2
commit d43b4ccc1e
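This commit replaces the decoder's content-only attention bookkeeping (the `alignment`/`alignments` variables) with location-sensitive attention (`attention`/`attentions`): at every decoder step, the previous attention weights and their running aggregate are fed back into the attention module, following Chorowski et al. (2015) as used in Tacotron 2. Below is a minimal, self-contained sketch of what such a scoring module looks like; the class and parameter names (LocationAttention, loc_conv, the filter sizes) are illustrative assumptions, not this repository's actual AttentionRNN internals.

    import torch
    import torch.nn as nn

    class LocationAttention(nn.Module):
        """Illustrative location-sensitive scorer; NOT the repo's actual module."""
        def __init__(self, query_dim, annot_dim, attn_dim, loc_filters=32, loc_kernel=31):
            super().__init__()
            self.query_layer = nn.Linear(query_dim, attn_dim, bias=False)
            self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=False)
            # 2 input channels: previous attention weights + their running aggregate.
            self.loc_conv = nn.Conv1d(2, loc_filters, kernel_size=loc_kernel,
                                      padding=(loc_kernel - 1) // 2, bias=False)
            self.loc_layer = nn.Linear(loc_filters, attn_dim, bias=False)
            self.v = nn.Linear(attn_dim, 1, bias=False)

        def forward(self, query, annots, attention_cat):
            # query: (B, query_dim), annots: (B, T, annot_dim), attention_cat: (B, 2, T)
            loc_feats = self.loc_conv(attention_cat).transpose(1, 2)  # (B, T, loc_filters)
            scores = self.v(torch.tanh(self.query_layer(query).unsqueeze(1)
                                       + self.annot_layer(annots)
                                       + self.loc_layer(loc_feats))).squeeze(-1)
            return torch.softmax(scores, dim=-1)  # (B, T) attention weights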
@@ -223,7 +223,7 @@ class Decoder(nn.Module):
         self.r = r
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
-        # processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
+        # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         self.attention_rnn = AttentionRNN(256, in_features, 128)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256+in_features, 256)
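The only change in this hunk is the comment: the attention block is now documented as producing soft attention weights rather than an Alignment. For the call sites further down to work, AttentionRNN.forward must also accept the stacked location features as an extra argument. A hedged sketch of that adjusted wrapper, reusing the LocationAttention sketch above; argument names and internals are assumptions, not the repo's code:

    import torch
    import torch.nn as nn

    class AttentionRNNSketch(nn.Module):
        """Assumed shape of the adjusted wrapper; the repo's AttentionRNN may differ."""
        def __init__(self, out_dim, annot_dim, memory_dim, alignment_model):
            super().__init__()
            self.rnn_cell = nn.GRUCell(memory_dim + annot_dim, out_dim)
            self.alignment_model = alignment_model  # e.g. the LocationAttention sketch

        def forward(self, memory, context, rnn_state, annots, attention_cat):
            # One recurrent step over [prenet output | previous context vector].
            rnn_state = self.rnn_cell(torch.cat((memory, context), -1), rnn_state)
            # Location-sensitive weights over the encoder annotations.
            attention = self.alignment_model(rnn_state, annots, attention_cat)  # (B, T)
            # New context vector: attention-weighted sum of the annotations.
            context = torch.bmm(attention.unsqueeze(1), annots).squeeze(1)
            return rnn_state, context, attention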
@@ -252,6 +252,7 @@ class Decoder(nn.Module):
             - memory: batch x #mels_pecs x mel_spec_dim
         """
         B = inputs.size(0)
+        T = inputs.size(1)
         # Run greedy decoding if memory is None
         greedy = not self.training
         if memory is not None:
@@ -269,11 +270,13 @@ class Decoder(nn.Module):
                                for _ in range(len(self.decoder_rnns))]
         current_context_vec = inputs.data.new(B, 256).zero_()
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
+        attention_vec = memory.data.new(B, T).zero_()
+        attention_vec_cum = memory.data.new(B, T).zero_()
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
             memory = memory.transpose(0, 1)
         outputs = []
-        alignments = []
+        attentions = []
         stop_tokens = []
         t = 0
         memory_input = initial_memory
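Two new (B, T) buffers are allocated alongside the rest of the decoder state: attention_vec holds the previous step's attention weights and attention_vec_cum their running aggregate, with T the number of encoder timesteps recorded by the new `T = inputs.size(1)` line above. Both start at zero, so the first decoder step sees an empty location history. Note they are allocated from `memory.data`, which presumes memory is not None at this point. A runnable shape check, with hypothetical sizes:

    import torch

    B, T = 2, 37                    # batch size, encoder timesteps (T = inputs.size(1))
    memory = torch.zeros(B, 5, 80)  # stand-in for the mel-spectrogram memory tensor
    attention_vec = memory.data.new(B, T).zero_()      # previous step's weights
    attention_vec_cum = memory.data.new(B, T).zero_()  # running aggregate
    attention_vec_cat = torch.cat((attention_vec.unsqueeze(1),
                                   attention_vec_cum.unsqueeze(1)), dim=1)
    print(attention_vec_cat.shape)  # torch.Size([2, 2, 37]) -> (B, 2 channels, T)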
@@ -286,8 +289,13 @@ class Decoder(nn.Module):
             # Prenet
             processed_memory = self.prenet(memory_input)
             # Attention RNN
-            attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
-                processed_memory, current_context_vec, attention_rnn_hidden, inputs)
+            attention_vec_cat = torch.cat((attention_vec.unsqueeze(1),
+                                           attention_vec_cum.unsqueeze(1)),
+                                          dim=1)
+            attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
+                processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_vec_cat)
+            attention_vec_cum += attention_vec
+            attention_vec_cum /= (t + 1)
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
                 torch.cat((attention_rnn_hidden, current_context_vec), -1))
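This hunk is the heart of the change: the previous and aggregated weights are stacked into a (B, 2, T) tensor and passed to the attention RNN as a fifth argument, and the returned weights update the aggregate. Two details are worth flagging: as shown, attention_vec is never reassigned from the returned attention inside this hunk, and dividing attention_vec_cum by (t + 1) every step turns the Tacotron-2-style cumulative sum into a repeatedly renormalized average. A minimal sketch of the update as it would conventionally read, with the assumed reassignment marked; decoder_step_location_update is a hypothetical helper, not a repo function:

    import torch

    def decoder_step_location_update(attention_rnn, processed_memory, context,
                                     rnn_hidden, inputs, attention_vec,
                                     attention_vec_cum):
        # Stack previous weights and their aggregate as two channels: (B, 2, T).
        attention_cat = torch.cat((attention_vec.unsqueeze(1),
                                   attention_vec_cum.unsqueeze(1)), dim=1)
        rnn_hidden, context, attention = attention_rnn(
            processed_memory, context, rnn_hidden, inputs, attention_cat)
        attention_vec = attention                          # assumed reassignment (not shown in the hunk)
        attention_vec_cum = attention_vec_cum + attention  # plain cumulative sum, Tacotron-2 style
        return rnn_hidden, context, attention, attention_vec, attention_vec_cum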
@@ -304,14 +312,14 @@ class Decoder(nn.Module):
             # predict stop token
             stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden)
             outputs += [output]
-            alignments += [alignment]
+            attentions += [attention]
             stop_tokens += [stop_token]
             t += 1
             if (not greedy and self.training) or (greedy and memory is not None):
                 if t >= T_decoder:
                     break
             else:
-                if t > inputs.shape[1]/2 and stop_token > 0.8:
+                if t > inputs.shape[1]/2 and stop_token > 0.6:
                     break
                 elif t > self.max_decoder_steps:
                     print(" !! Decoder stopped with 'max_decoder_steps'. \
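Besides renaming alignment to attention, this hunk lowers the greedy stopping threshold from 0.8 to 0.6, so inference halts on a less confident stop prediction; the `t > inputs.shape[1]/2` guard still prevents stopping before roughly half as many decoder steps as encoder steps have run. A small illustration of the rule, assuming stop_token is already a probability in [0, 1]:

    def should_stop(t, n_encoder_steps, stop_token, threshold=0.6):
        # Stop only after enough steps AND a confident stop prediction.
        return t > n_encoder_steps / 2 and stop_token > threshold

    assert not should_stop(3, 20, 0.9)   # too early, however confident
    assert not should_stop(15, 20, 0.5)  # late enough, but below threshold
    assert should_stop(15, 20, 0.7)      # late enough and confident -> stop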
@@ -319,10 +327,10 @@ class Decoder(nn.Module):
                     break
         assert greedy or len(outputs) == T_decoder
         # Back to batch first
-        alignments = torch.stack(alignments).transpose(0, 1)
+        attentions = torch.stack(attentions).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
-        return outputs, alignments, stop_tokens
+        return outputs, attentions, stop_tokens


 class StopNet(nn.Module):