mirror of https://github.com/coqui-ai/TTS.git
location sensitive attention fix
This commit is contained in:
parent b84a3ba0c8
commit 9566097c07
@@ -39,7 +39,7 @@ class LocationSensitiveAttention(nn.Module):
         self.kernel_size = kernel_size
         self.filters = filters
         padding = int((kernel_size - 1) / 2)
-        self.loc_conv = nn.Conv1d(1, filters,
+        self.loc_conv = nn.Conv1d(2, filters,
                                   kernel_size=kernel_size, stride=1,
                                   padding=padding, bias=False)
         self.loc_linear = nn.Linear(filters, attn_dim)
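
The 1 -> 2 change to the Conv1d in_channels implies the location filters now read a stack of two alignment maps (presumably the previous and the cumulative attention weights, as is usual for location-sensitive attention); the caller is not shown in this diff, so that stacking is inferred. A minimal shape-check sketch with illustrative sizes:

# Sketch (not part of the commit): shape check for the 2-channel location
# filter. All sizes here are illustrative assumptions; the stacked-alignments
# input is inferred from the in_channels change, not shown in this diff.
import torch
import torch.nn as nn

B, T, filters, kernel_size, attn_dim = 4, 50, 32, 31, 128
padding = int((kernel_size - 1) / 2)
loc_conv = nn.Conv1d(2, filters, kernel_size=kernel_size, stride=1,
                     padding=padding, bias=False)
loc_linear = nn.Linear(filters, attn_dim)

attn_cat = torch.zeros(B, 2, T)                   # stacked alignments (B, 2, T)
loc_feat = loc_conv(attn_cat)                     # -> (B, filters, T)
loc_feat = loc_linear(loc_feat.transpose(1, 2))   # -> (B, T, attn_dim)
print(loc_feat.shape)                             # torch.Size([4, 50, 128])
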
@@ -77,15 +77,15 @@ class AttentionRNNCell(nn.Module):
             out_dim (int): context vector feature dimension.
             rnn_dim (int): rnn hidden state dimension.
             annot_dim (int): annotation vector feature dimension.
-            memory_dim (int): memory vector (decoder autogression) feature dimension.
+            memory_dim (int): memory vector (decoder output) feature dimension.
             align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
         """
         super(AttentionRNNCell, self).__init__()
         self.align_model = align_model
-        self.rnn_cell = nn.GRUCell(out_dim + memory_dim, rnn_dim)
+        self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
         # pick bahdanau or location sensitive attention
         if align_model == 'b':
-            self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
+            self.alignment_model = BahdanauAttention(annot_dim, rnn_dim, out_dim)
         if align_model == 'ls':
             self.alignment_model = LocationSensitiveAttention(annot_dim, rnn_dim, out_dim)
         else:
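
Both changes in this hunk are dimension corrections: the attention RNN consumes the processed memory concatenated with a context vector that lives in the annotation space, so the GRUCell input width is annot_dim + memory_dim rather than out_dim + memory_dim; and the attention query is the attention RNN hidden state, so the query dimension passed to BahdanauAttention is rnn_dim. A minimal sketch with assumed sizes:

# Sketch (sizes are assumptions): the corrected input width of the
# attention RNN cell after this commit.
import torch
import torch.nn as nn

annot_dim, memory_dim, rnn_dim = 256, 128, 256
rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)

B = 4
memory = torch.zeros(B, memory_dim)     # processed decoder input
context = torch.zeros(B, annot_dim)     # context vector over annotations
hidden = torch.zeros(B, rnn_dim)
hidden = rnn_cell(torch.cat((memory, context), dim=-1), hidden)
print(hidden.shape)                     # torch.Size([4, 256])
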
@@ -198,6 +198,7 @@ class Decoder(nn.Module):
     def __init__(self, in_features, memory_dim, r):
         super(Decoder, self).__init__()
         self.r = r
+        self.in_features = in_features
         self.max_decoder_steps = 200
         self.memory_dim = memory_dim
         # memory -> |Prenet| -> processed_memory
@@ -249,7 +250,7 @@ class Decoder(nn.Module):
         attention_rnn_hidden = inputs.data.new(B, 256).zero_()
         decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
                                for _ in range(len(self.decoder_rnns))]
-        current_context_vec = inputs.data.new(B, 128).zero_()
+        current_context_vec = inputs.data.new(B, self.in_features).zero_()
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
         # attention states
         attention = inputs.data.new(B, T).zero_()
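
Together with the new self.in_features attribute stored in the previous hunk, this removes a hard-coded context width: the initial context vector must match the encoder output (annotation) size, which is in_features, not a fixed 128. A sketch, where in_features=256 is an assumed value:

# Sketch (in_features=256 is an assumption): the initial context vector
# now tracks the encoder output width instead of a hard-coded 128.
import torch

B, T, in_features = 4, 50, 256
inputs = torch.zeros(B, T, in_features)              # encoder outputs
current_context_vec = inputs.data.new(B, in_features).zero_()
print(current_context_vec.shape)                     # torch.Size([4, 256])
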