mirror of https://github.com/coqui-ai/TTS.git
renamed attention_rnn to query_rnn
This commit is contained in:
parent
f08b31fd8d
commit
98edb7a4f8
|
@ -108,19 +108,19 @@ class LocationLayer(nn.Module):
|
|||
class Attention(nn.Module):
|
||||
# Pylint gets confused by PyTorch conventions here
|
||||
#pylint: disable=attribute-defined-outside-init
|
||||
def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
|
||||
def __init__(self, query_dim, embedding_dim, attention_dim,
|
||||
location_attention, attention_location_n_filters,
|
||||
attention_location_kernel_size, windowing, norm, forward_attn,
|
||||
trans_agent, forward_attn_mask):
|
||||
super(Attention, self).__init__()
|
||||
self.query_layer = Linear(
|
||||
attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
|
||||
query_dim, attention_dim, bias=False, init_gain='tanh')
|
||||
self.inputs_layer = Linear(
|
||||
embedding_dim, attention_dim, bias=False, init_gain='tanh')
|
||||
self.v = Linear(attention_dim, 1, bias=True)
|
||||
if trans_agent:
|
||||
self.ta = nn.Linear(
|
||||
attention_rnn_dim + embedding_dim, 1, bias=True)
|
||||
query_dim + embedding_dim, 1, bias=True)
|
||||
if location_attention:
|
||||
self.location_layer = LocationLayer(
|
||||
attention_dim,
|
||||
|
@ -203,11 +203,12 @@ class Attention(nn.Module):
|
|||
|
||||
def apply_forward_attention(self, inputs, alignment, query):
|
||||
# forward attention
|
||||
prev_alpha = F.pad(self.alpha[:, :-1].clone(),
|
||||
(1, 0, 0, 0)).to(inputs.device)
|
||||
prev_alpha = F.pad(self.alpha[:, :-1].clone().to(inputs.device),
|
||||
(1, 0, 0, 0))
|
||||
# compute transition potentials
|
||||
alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
|
||||
self.u * prev_alpha) + 1e-8) * alignment
|
||||
alpha = ((1 - self.u) * self.alpha
|
||||
+ self.u * prev_alpha
|
||||
+ 1e-8) * alignment
|
||||
# force incremental alignment
|
||||
if not self.training and self.forward_attn_mask:
|
||||
_, n = prev_alpha.max(1)
|
||||
|
@ -231,19 +232,20 @@ class Attention(nn.Module):
|
|||
self.u = torch.sigmoid(self.ta(ta_input))
|
||||
return context, self.alpha
|
||||
|
||||
def forward(self, attention_hidden_state, inputs, processed_inputs, mask):
|
||||
def forward(self, query, inputs, processed_inputs, mask):
|
||||
if self.location_attention:
|
||||
attention, processed_query = self.get_location_attention(
|
||||
attention_hidden_state, processed_inputs)
|
||||
query, processed_inputs)
|
||||
else:
|
||||
attention, processed_query = self.get_attention(
|
||||
attention_hidden_state, processed_inputs)
|
||||
query, processed_inputs)
|
||||
# apply masking
|
||||
if mask is not None:
|
||||
attention.data.masked_fill_(1 - mask, self._mask_value)
|
||||
# apply windowing - only in eval mode
|
||||
if not self.training and self.windowing:
|
||||
attention = self.apply_windowing(attention, inputs)
|
||||
|
||||
# normalize attention values
|
||||
if self.norm == "softmax":
|
||||
alignment = torch.softmax(attention, dim=-1)
|
||||
|
@ -258,7 +260,7 @@ class Attention(nn.Module):
|
|||
# apply forward attention if enabled
|
||||
if self.forward_attn:
|
||||
context, self.attention_weights = self.apply_forward_attention(
|
||||
inputs, alignment, attention_hidden_state)
|
||||
inputs, alignment, query)
|
||||
else:
|
||||
context = torch.bmm(alignment.unsqueeze(1), inputs)
|
||||
context = context.squeeze(1)
|
||||
|
|
|
@ -283,6 +283,7 @@ class Decoder(nn.Module):
|
|||
self.memory_size = memory_size if memory_size > 0 else r
|
||||
self.memory_dim = memory_dim
|
||||
self.separate_stopnet = separate_stopnet
|
||||
self.query_dim = 256
|
||||
# memory -> |Prenet| -> processed_memory
|
||||
self.prenet = Prenet(
|
||||
memory_dim * self.memory_size,
|
||||
|
@ -290,18 +291,18 @@ class Decoder(nn.Module):
|
|||
prenet_dropout,
|
||||
out_features=[256, 128])
|
||||
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
|
||||
self.attention_rnn = nn.GRUCell(in_features + 128, 256)
|
||||
self.attention_layer = Attention(attention_rnn_dim=256,
|
||||
embedding_dim=in_features,
|
||||
attention_dim=128,
|
||||
location_attention=location_attn,
|
||||
attention_location_n_filters=32,
|
||||
attention_location_kernel_size=31,
|
||||
windowing=attn_windowing,
|
||||
norm=attn_norm,
|
||||
forward_attn=forward_attn,
|
||||
trans_agent=trans_agent,
|
||||
forward_attn_mask=forward_attn_mask)
|
||||
self.query_rnn = nn.GRUCell(in_features + 128, self.query_dim)
|
||||
self.attention = Attention(query_dim=self.query_dim,
|
||||
embedding_dim=in_features,
|
||||
attention_dim=128,
|
||||
location_attention=location_attn,
|
||||
attention_location_n_filters=32,
|
||||
attention_location_kernel_size=31,
|
||||
windowing=attn_windowing,
|
||||
norm=attn_norm,
|
||||
forward_attn=forward_attn,
|
||||
trans_agent=trans_agent,
|
||||
forward_attn_mask=forward_attn_mask)
|
||||
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
||||
self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
|
||||
# decoder_RNN_input -> |RNN| -> RNN_state
|
||||
|
@ -310,7 +311,7 @@ class Decoder(nn.Module):
|
|||
# RNN_state -> |Linear| -> mel_spec
|
||||
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
||||
# learn init values instead of zero init.
|
||||
self.attention_rnn_init = nn.Embedding(1, 256)
|
||||
self.query_rnn_init = nn.Embedding(1, 256)
|
||||
self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
|
||||
self.decoder_rnn_inits = nn.Embedding(2, 256)
|
||||
self.stopnet = StopNet(256 + memory_dim * r)
|
||||
|
@ -347,18 +348,18 @@ class Decoder(nn.Module):
|
|||
self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())
|
||||
|
||||
# decoder states
|
||||
self.attention_rnn_hidden = self.attention_rnn_init(
|
||||
self.query = self.query_rnn_init(
|
||||
inputs.data.new_zeros(B).long())
|
||||
self.decoder_rnn_hiddens = [
|
||||
self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
|
||||
for idx in range(len(self.decoder_rnns))
|
||||
]
|
||||
self.current_context_vec = inputs.data.new(B, self.in_features).zero_()
|
||||
self.context_vec = inputs.data.new(B, self.in_features).zero_()
|
||||
# attention states
|
||||
self.attention = inputs.data.new(B, T).zero_()
|
||||
self.attention_cum = inputs.data.new(B, T).zero_()
|
||||
# cache attention inputs
|
||||
self.processed_inputs = self.attention_layer.inputs_layer(inputs)
|
||||
self.processed_inputs = self.attention.inputs_layer(inputs)
|
||||
|
||||
def _parse_outputs(self, outputs, attentions, stop_tokens):
|
||||
# Back to batch first
|
||||
|
@ -370,13 +371,15 @@ class Decoder(nn.Module):
|
|||
def decode(self, inputs, mask=None):
|
||||
# Prenet
|
||||
processed_memory = self.prenet(self.memory_input)
|
||||
|
||||
# Attention RNN
|
||||
self.attention_rnn_hidden = self.attention_rnn(torch.cat((processed_memory, self.current_context_vec), -1), self.attention_rnn_hidden)
|
||||
self.current_context_vec = self.attention_layer(self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
|
||||
# Concat RNN output and attention context vector
|
||||
self.query = self.query_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query)
|
||||
self.context_vec = self.attention(self.query, inputs, self.processed_inputs, mask)
|
||||
|
||||
# Concat query and attention context vector
|
||||
decoder_input = self.project_to_decoder_in(
|
||||
torch.cat((self.attention_rnn_hidden, self.current_context_vec),
|
||||
-1))
|
||||
torch.cat((self.query, self.context_vec), -1))
|
||||
|
||||
# Pass through the decoder RNNs
|
||||
for idx in range(len(self.decoder_rnns)):
|
||||
self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
|
||||
|
@ -384,18 +387,18 @@ class Decoder(nn.Module):
|
|||
# Residual connection
|
||||
decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
|
||||
decoder_output = decoder_input
|
||||
del decoder_input
|
||||
|
||||
# predict mel vectors from decoder vectors
|
||||
output = self.proj_to_mel(decoder_output)
|
||||
output = torch.sigmoid(output)
|
||||
|
||||
# predict stop token
|
||||
stopnet_input = torch.cat([decoder_output, output], -1)
|
||||
del decoder_output
|
||||
if self.separate_stopnet:
|
||||
stop_token = self.stopnet(stopnet_input.detach())
|
||||
else:
|
||||
stop_token = self.stopnet(stopnet_input)
|
||||
return output, stop_token, self.attention_layer.attention_weights
|
||||
return output, stop_token, self.attention.attention_weights
|
||||
|
||||
def _update_memory_queue(self, new_memory):
|
||||
if self.memory_size > 0 and new_memory.shape[-1] < self.memory_size:
|
||||
|
@ -427,7 +430,7 @@ class Decoder(nn.Module):
|
|||
stop_tokens = []
|
||||
t = 0
|
||||
self._init_states(inputs)
|
||||
self.attention_layer.init_states(inputs)
|
||||
self.attention.init_states(inputs)
|
||||
while len(outputs) < memory.size(0):
|
||||
if t > 0:
|
||||
new_memory = memory[t - 1]
|
||||
|
@ -453,8 +456,8 @@ class Decoder(nn.Module):
|
|||
stop_tokens = []
|
||||
t = 0
|
||||
self._init_states(inputs)
|
||||
self.attention_layer.init_win_idx()
|
||||
self.attention_layer.init_states(inputs)
|
||||
self.attention.init_win_idx()
|
||||
self.attention.init_states(inputs)
|
||||
while True:
|
||||
if t > 0:
|
||||
new_memory = outputs[-1]
|
||||
|
|
|
@ -104,7 +104,7 @@ class Decoder(nn.Module):
|
|||
self.r = r
|
||||
self.encoder_embedding_dim = in_features
|
||||
self.separate_stopnet = separate_stopnet
|
||||
self.attention_rnn_dim = 1024
|
||||
self.query_dim = 1024
|
||||
self.decoder_rnn_dim = 1024
|
||||
self.prenet_dim = 256
|
||||
self.max_decoder_steps = 1000
|
||||
|
@ -116,22 +116,22 @@ class Decoder(nn.Module):
|
|||
prenet_dropout,
|
||||
[self.prenet_dim, self.prenet_dim], bias=False)
|
||||
|
||||
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
|
||||
self.attention_rnn_dim)
|
||||
self.query_rnn = nn.LSTMCell(self.prenet_dim + in_features,
|
||||
self.query_dim)
|
||||
|
||||
self.attention_layer = Attention(attention_rnn_dim=self.attention_rnn_dim,
|
||||
embedding_dim=in_features,
|
||||
attention_dim=128,
|
||||
location_attention=location_attn,
|
||||
attention_location_n_filters=32,
|
||||
attention_location_kernel_size=31,
|
||||
windowing=attn_win,
|
||||
norm=attn_norm,
|
||||
forward_attn=forward_attn,
|
||||
trans_agent=trans_agent,
|
||||
forward_attn_mask=forward_attn_mask)
|
||||
self.attention = Attention(query_dim=self.query_dim,
|
||||
embedding_dim=in_features,
|
||||
attention_dim=128,
|
||||
location_attention=location_attn,
|
||||
attention_location_n_filters=32,
|
||||
attention_location_kernel_size=31,
|
||||
windowing=attn_win,
|
||||
norm=attn_norm,
|
||||
forward_attn=forward_attn,
|
||||
trans_agent=trans_agent,
|
||||
forward_attn_mask=forward_attn_mask)
|
||||
|
||||
self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
|
||||
self.decoder_rnn = nn.LSTMCell(self.query_dim + in_features,
|
||||
self.decoder_rnn_dim, 1)
|
||||
|
||||
self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
|
||||
|
@ -145,7 +145,7 @@ class Decoder(nn.Module):
|
|||
bias=True,
|
||||
init_gain='sigmoid'))
|
||||
|
||||
self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim)
|
||||
self.query_rnn_init = nn.Embedding(1, self.query_dim)
|
||||
self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
|
||||
self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
|
||||
self.memory_truncated = None
|
||||
|
@ -160,10 +160,10 @@ class Decoder(nn.Module):
|
|||
# T = inputs.size(1)
|
||||
|
||||
if not keep_states:
|
||||
self.attention_hidden = self.attention_rnn_init(
|
||||
self.query = self.query_rnn_init(
|
||||
inputs.data.new_zeros(B).long())
|
||||
self.attention_cell = Variable(
|
||||
inputs.data.new(B, self.attention_rnn_dim).zero_())
|
||||
self.query_rnn_cell_state = Variable(
|
||||
inputs.data.new(B, self.query_dim).zero_())
|
||||
|
||||
self.decoder_hidden = self.decoder_rnn_inits(
|
||||
inputs.data.new_zeros(B).long())
|
||||
|
@ -174,7 +174,7 @@ class Decoder(nn.Module):
|
|||
inputs.data.new(B, self.encoder_embedding_dim).zero_())
|
||||
|
||||
self.inputs = inputs
|
||||
self.processed_inputs = self.attention_layer.inputs_layer(inputs)
|
||||
self.processed_inputs = self.attention.inputs_layer(inputs)
|
||||
self.mask = mask
|
||||
|
||||
def _reshape_memory(self, memories):
|
||||
|
@ -193,18 +193,18 @@ class Decoder(nn.Module):
|
|||
return outputs, stop_tokens, alignments
|
||||
|
||||
def decode(self, memory):
|
||||
cell_input = torch.cat((memory, self.context), -1)
|
||||
self.attention_hidden, self.attention_cell = self.attention_rnn(
|
||||
cell_input, (self.attention_hidden, self.attention_cell))
|
||||
self.attention_hidden = F.dropout(
|
||||
self.attention_hidden, self.p_attention_dropout, self.training)
|
||||
self.attention_cell = F.dropout(
|
||||
self.attention_cell, self.p_attention_dropout, self.training)
|
||||
query_input = torch.cat((memory, self.context), -1)
|
||||
self.query, self.query_rnn_cell_state = self.query_rnn(
|
||||
query_input, (self.query, self.query_rnn_cell_state))
|
||||
self.query = F.dropout(
|
||||
self.query, self.p_attention_dropout, self.training)
|
||||
self.query_rnn_cell_state = F.dropout(
|
||||
self.query_rnn_cell_state, self.p_attention_dropout, self.training)
|
||||
|
||||
self.context = self.attention_layer(self.attention_hidden, self.inputs,
|
||||
self.processed_inputs, self.mask)
|
||||
self.context = self.attention(self.query, self.inputs,
|
||||
self.processed_inputs, self.mask)
|
||||
|
||||
memory = torch.cat((self.attention_hidden, self.context), -1)
|
||||
memory = torch.cat((self.query, self.context), -1)
|
||||
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
|
||||
memory, (self.decoder_hidden, self.decoder_cell))
|
||||
self.decoder_hidden = F.dropout(self.decoder_hidden,
|
||||
|
@ -223,7 +223,7 @@ class Decoder(nn.Module):
|
|||
stop_token = self.stopnet(stopnet_input.detach())
|
||||
else:
|
||||
stop_token = self.stopnet(stopnet_input)
|
||||
return decoder_output, stop_token, self.attention_layer.attention_weights
|
||||
return decoder_output, stop_token, self.attention.attention_weights
|
||||
|
||||
def forward(self, inputs, memories, mask):
|
||||
memory = self.get_go_frame(inputs).unsqueeze(0)
|
||||
|
@ -232,7 +232,7 @@ class Decoder(nn.Module):
|
|||
memories = self.prenet(memories)
|
||||
|
||||
self._init_states(inputs, mask=mask)
|
||||
self.attention_layer.init_states(inputs)
|
||||
self.attention.init_states(inputs)
|
||||
|
||||
outputs, stop_tokens, alignments = [], [], []
|
||||
while len(outputs) < memories.size(0) - 1:
|
||||
|
@ -251,8 +251,8 @@ class Decoder(nn.Module):
|
|||
memory = self.get_go_frame(inputs)
|
||||
self._init_states(inputs, mask=None)
|
||||
|
||||
self.attention_layer.init_win_idx()
|
||||
self.attention_layer.init_states(inputs)
|
||||
self.attention.init_win_idx()
|
||||
self.attention.init_states(inputs)
|
||||
|
||||
outputs, stop_tokens, alignments, t = [], [], [], 0
|
||||
stop_flags = [True, False, False]
|
||||
|
@ -295,8 +295,8 @@ class Decoder(nn.Module):
|
|||
else:
|
||||
self._init_states(inputs, mask=None, keep_states=True)
|
||||
|
||||
self.attention_layer.init_win_idx()
|
||||
self.attention_layer.init_states(inputs)
|
||||
self.attention.init_win_idx()
|
||||
self.attention.init_states(inputs)
|
||||
outputs, stop_tokens, alignments, t = [], [], [], 0
|
||||
stop_flags = [True, False, False]
|
||||
stop_count = 0
|
||||
|
|
Loading…
Reference in New Issue