mirror of https://github.com/coqui-ai/TTS.git
A major bug fix for location sensitive attention.
This commit is contained in:
parent
6bcec24d13
commit
7acf4eab94
|
@ -69,10 +69,18 @@ class LocationSensitiveAttention(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class AttentionRNN(nn.Module):
|
class AttentionRNN(nn.Module):
|
||||||
def __init__(self, out_dim, annot_dim, memory_dim):
|
def __init__(self, out_dim, annot_dim, memory_dim, align_model):
|
||||||
super(AttentionRNN, self).__init__()
|
super(AttentionRNN, self).__init__()
|
||||||
self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
|
self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
|
||||||
|
# pick bahdanau or location sensitive attention
|
||||||
|
if align_model == 'b':
|
||||||
|
self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
|
||||||
|
if align_model == 'ls':
|
||||||
self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)
|
self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(" Wrong alignment model name: {}. Use\
|
||||||
|
'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
|
||||||
|
|
||||||
|
|
||||||
def forward(self, memory, context, rnn_state, annotations,
|
def forward(self, memory, context, rnn_state, annotations,
|
||||||
attention_vec, mask=None, annotations_lengths=None):
|
attention_vec, mask=None, annotations_lengths=None):
|
||||||
|
@ -88,6 +96,9 @@ class AttentionRNN(nn.Module):
|
||||||
# Alignment
|
# Alignment
|
||||||
# (batch, max_time)
|
# (batch, max_time)
|
||||||
# e_{ij} = a(s_{i-1}, h_j)
|
# e_{ij} = a(s_{i-1}, h_j)
|
||||||
|
if attnetion_vec is None:
|
||||||
|
alignment = self.alignment_model(annotations, rnn_output)
|
||||||
|
else:
|
||||||
alignment = self.alignment_model(annotations, rnn_output, attention_vec)
|
alignment = self.alignment_model(annotations, rnn_output, attention_vec)
|
||||||
|
|
||||||
# TODO: needs recheck.
|
# TODO: needs recheck.
|
||||||
|
|
|
@ -212,12 +212,10 @@ class Decoder(nn.Module):
|
||||||
eps (float): threshold for detecting the end of a sentence.
|
eps (float): threshold for detecting the end of a sentence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_features, memory_dim, r, eps=0, mode='train'):
|
def __init__(self, in_features, memory_dim, r):
|
||||||
super(Decoder, self).__init__()
|
super(Decoder, self).__init__()
|
||||||
self.mode = mode
|
|
||||||
self.max_decoder_steps = 200
|
self.max_decoder_steps = 200
|
||||||
self.memory_dim = memory_dim
|
self.memory_dim = memory_dim
|
||||||
self.eps = eps
|
|
||||||
self.r = r
|
self.r = r
|
||||||
# memory -> |Prenet| -> processed_memory
|
# memory -> |Prenet| -> processed_memory
|
||||||
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
||||||
|
@ -268,8 +266,8 @@ class Decoder(nn.Module):
|
||||||
for _ in range(len(self.decoder_rnns))]
|
for _ in range(len(self.decoder_rnns))]
|
||||||
current_context_vec = inputs.data.new(B, 256).zero_()
|
current_context_vec = inputs.data.new(B, 256).zero_()
|
||||||
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
|
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
|
||||||
attention_vec = inputs.data.new(B, T).zero_()
|
attention = inputs.data.new(B, T).zero_()
|
||||||
attention_vec_cum = inputs.data.new(B, T).zero_()
|
attention_cum = inputs.data.new(B, T).zero_()
|
||||||
# Time first (T_decoder, B, memory_dim)
|
# Time first (T_decoder, B, memory_dim)
|
||||||
if memory is not None:
|
if memory is not None:
|
||||||
memory = memory.transpose(0, 1)
|
memory = memory.transpose(0, 1)
|
||||||
|
@ -287,12 +285,12 @@ class Decoder(nn.Module):
|
||||||
# Prenet
|
# Prenet
|
||||||
processed_memory = self.prenet(memory_input)
|
processed_memory = self.prenet(memory_input)
|
||||||
# Attention RNN
|
# Attention RNN
|
||||||
attention_vec_cat = torch.cat((attention_vec.unsqueeze(1),
|
attention_cat = torch.cat((attention.unsqueeze(1),
|
||||||
attention_vec_cum.unsqueeze(1) / (t + 1)),
|
attention_cum.unsqueeze(1) / (t + 1)),
|
||||||
dim=1)
|
dim=1)
|
||||||
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
|
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
|
||||||
processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_vec_cat)
|
processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_cat)
|
||||||
attention_vec_cum += attention_vec
|
attention_cum += attention
|
||||||
# Concat RNN output and attention context vector
|
# Concat RNN output and attention context vector
|
||||||
decoder_input = self.project_to_decoder_in(
|
decoder_input = self.project_to_decoder_in(
|
||||||
torch.cat((attention_rnn_hidden, current_context_vec), -1))
|
torch.cat((attention_rnn_hidden, current_context_vec), -1))
|
||||||
|
@ -333,6 +331,7 @@ class Decoder(nn.Module):
|
||||||
class StopNet(nn.Module):
|
class StopNet(nn.Module):
|
||||||
|
|
||||||
def __init__(self, r, memory_dim):
|
def __init__(self, r, memory_dim):
|
||||||
|
"""Predicts the stop token to stop the decoder at testing time"""
|
||||||
super(StopNet, self).__init__()
|
super(StopNet, self).__init__()
|
||||||
self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
|
self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
|
||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
|
|
Loading…
Reference in New Issue