mirror of https://github.com/coqui-ai/TTS.git
A major bug fix for location sensitive attention.
This commit is contained in:
parent 6bcec24d13
commit 7acf4eab94
@@ -69,10 +69,18 @@ class LocationSensitiveAttention(nn.Module):
 
 
 class AttentionRNN(nn.Module):
-    def __init__(self, out_dim, annot_dim, memory_dim):
+    def __init__(self, out_dim, annot_dim, memory_dim, align_model):
         super(AttentionRNN, self).__init__()
         self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
-        self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)
+        # pick Bahdanau or location-sensitive attention
+        if align_model == 'b':
+            self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
+        elif align_model == 'ls':
+            self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)
+        else:
+            raise RuntimeError("Wrong alignment model name: {}. Use "
+                               "'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
 
     def forward(self, memory, context, rnn_state, annotations,
                 attention_vec, mask=None, annotations_lengths=None):
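For reference, a minimal usage sketch of the new constructor; the dimension values below are illustrative assumptions, not values taken from this commit:

# Illustrative only: the dimensions are assumed, not taken from this commit.
attn_rnn_ls = AttentionRNN(out_dim=256, annot_dim=256, memory_dim=128, align_model='ls')
attn_rnn_b = AttentionRNN(out_dim=256, annot_dim=256, memory_dim=128, align_model='b')
# Any other name raises RuntimeError instead of silently defaulting to one model.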
@@ -88,7 +96,10 @@ class AttentionRNN(nn.Module):
         # Alignment
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
-        alignment = self.alignment_model(annotations, rnn_output, attention_vec)
+        if attention_vec is None:
+            alignment = self.alignment_model(annotations, rnn_output)
+        else:
+            alignment = self.alignment_model(annotations, rnn_output, attention_vec)
 
         # TODO: needs recheck.
         if mask is not None:
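For context on what the extra attention_vec argument feeds, here is a minimal, self-contained sketch of location-sensitive scoring. It is not the repository's LocationSensitiveAttention; the class name, layer names, filter count, and kernel size are assumptions for illustration only:

import torch
import torch.nn as nn


class LocationAwareScore(nn.Module):
    """Hypothetical scorer: adds location features (previous + cumulative
    attention) to a Bahdanau-style energy, as in location-sensitive attention."""

    def __init__(self, annot_dim, query_dim, hidden_dim, filters=32, kernel_size=31):
        super().__init__()
        self.query_layer = nn.Linear(query_dim, hidden_dim, bias=False)
        self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=False)
        self.loc_conv = nn.Conv1d(2, filters, kernel_size,
                                  padding=(kernel_size - 1) // 2, bias=False)
        self.loc_layer = nn.Linear(filters, hidden_dim, bias=False)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, annots, query, attention_cat=None):
        # annots: (B, T, annot_dim), query: (B, query_dim)
        # attention_cat: (B, 2, T) = [previous attention; cumulative attention]
        processed = self.annot_layer(annots) + self.query_layer(query).unsqueeze(1)
        if attention_cat is not None:
            loc_feats = self.loc_conv(attention_cat).transpose(1, 2)  # (B, T, filters)
            processed = processed + self.loc_layer(loc_feats)
        energies = self.v(torch.tanh(processed)).squeeze(-1)          # (B, T)
        return torch.softmax(energies, dim=-1)


# Usage with dummy tensors: returns a (2, 7) alignment whose rows sum to 1.
scorer = LocationAwareScore(annot_dim=256, query_dim=256, hidden_dim=128)
align = scorer(torch.zeros(2, 7, 256), torch.zeros(2, 256),
               attention_cat=torch.zeros(2, 2, 7))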
@@ -212,12 +212,10 @@ class Decoder(nn.Module):
         eps (float): threshold for detecting the end of a sentence.
     """
 
-    def __init__(self, in_features, memory_dim, r, eps=0, mode='train'):
+    def __init__(self, in_features, memory_dim, r):
         super(Decoder, self).__init__()
-        self.mode = mode
         self.max_decoder_steps = 200
         self.memory_dim = memory_dim
-        self.eps = eps
         self.r = r
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
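A construction sketch with the simplified signature; the in_features, memory_dim, and r values below are illustrative assumptions:

# Illustrative only: the values are assumptions, not taken from this commit.
decoder = Decoder(in_features=256, memory_dim=80, r=5)
# eps and mode are no longer constructor arguments after this change.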
@@ -268,8 +266,8 @@ class Decoder(nn.Module):
             for _ in range(len(self.decoder_rnns))]
         current_context_vec = inputs.data.new(B, 256).zero_()
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
-        attention_vec = inputs.data.new(B, T).zero_()
-        attention_vec_cum = inputs.data.new(B, T).zero_()
+        attention = inputs.data.new(B, T).zero_()
+        attention_cum = inputs.data.new(B, T).zero_()
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
             memory = memory.transpose(0, 1)
@@ -287,12 +285,12 @@ class Decoder(nn.Module):
             # Prenet
             processed_memory = self.prenet(memory_input)
             # Attention RNN
-            attention_vec_cat = torch.cat((attention_vec.unsqueeze(1),
-                                           attention_vec_cum.unsqueeze(1) / (t + 1)),
+            attention_cat = torch.cat((attention.unsqueeze(1),
+                                       attention_cum.unsqueeze(1) / (t + 1)),
                                       dim=1)
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
-                processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_vec_cat)
-            attention_vec_cum += attention_vec
+                processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_cat)
+            attention_cum += attention
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
                 torch.cat((attention_rnn_hidden, current_context_vec), -1))
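A standalone sketch of the attention bookkeeping done in this loop, with the alignment returned by the attention RNN replaced by a uniform dummy so the snippet runs on its own; B and T are illustrative:

import torch

B, T = 2, 7                        # batch size, encoder timesteps (illustrative)
attention = torch.zeros(B, T)      # alignment from the previous decoder step
attention_cum = torch.zeros(B, T)  # running sum of all past alignments

for t in range(3):                 # a few decoder steps
    # Stack previous and averaged cumulative attention as two channels: (B, 2, T)
    attention_cat = torch.cat((attention.unsqueeze(1),
                               attention_cum.unsqueeze(1) / (t + 1)), dim=1)
    # In the decoder this tensor is passed to self.attention_rnn, which returns
    # a fresh alignment; a uniform dummy stands in for that call here.
    attention = torch.full((B, T), 1.0 / T)
    attention_cum += attention

print(attention_cat.shape)  # torch.Size([2, 2, 7])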
@@ -333,6 +331,7 @@ class Decoder(nn.Module):
 
 class StopNet(nn.Module):
 
     def __init__(self, r, memory_dim):
         """Predicts the stop token to stop the decoder at testing time"""
         super(StopNet, self).__init__()
         self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
         self.relu = nn.ReLU()
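The hunk above shows only part of StopNet's constructor. Below is a minimal sketch of how such a GRUCell-based stop-token predictor might be completed and stepped; the linear projection, sigmoid, threshold idea, and the example r/memory_dim values are assumptions, not code from this commit:

import torch
import torch.nn as nn


class StopNetSketch(nn.Module):
    """Hypothetical completion of a GRUCell-based stop-token predictor."""

    def __init__(self, r, memory_dim):
        super().__init__()
        self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
        self.relu = nn.ReLU()
        self.linear = nn.Linear(memory_dim * r, 1)  # assumed scalar projection
        self.sigmoid = nn.Sigmoid()

    def forward(self, frame, rnn_hidden):
        # frame: (B, memory_dim * r) -- the r output frames emitted at this step
        rnn_hidden = self.rnn(frame, rnn_hidden)
        stop_prob = self.sigmoid(self.linear(self.relu(rnn_hidden)))
        return stop_prob, rnn_hidden


# Usage: stop decoding once the predicted probability crosses a chosen threshold.
net = StopNetSketch(r=5, memory_dim=80)
hidden = torch.zeros(1, 5 * 80)
prob, hidden = net(torch.zeros(1, 5 * 80), hidden)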