mirror of https://github.com/coqui-ai/TTS.git
Enable attention windowing and make it configurable at model level.
parent 9927664f27
commit ed1f648b83
@@ -100,7 +100,7 @@ class LocationSensitiveAttention(nn.Module):
 
 
 class AttentionRNNCell(nn.Module):
-    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model):
+    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model, windowing=False):
         r"""
         General Attention RNN wrapper
 
@@ -110,10 +110,17 @@ class AttentionRNNCell(nn.Module):
             annot_dim (int): annotation vector feature dimension.
             memory_dim (int): memory vector (decoder output) feature dimension.
             align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
+            windowing (bool): attention windowing forcing monotonic attention.
+                It is only active in eval mode.
         """
         super(AttentionRNNCell, self).__init__()
         self.align_model = align_model
         self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
+        self.windowing = windowing
+        if self.windowing:
+            self.win_back = 1
+            self.win_front = 3
+            self.win_idx = None
         # pick bahdanau or location sensitive attention
         if align_model == 'b':
             self.alignment_model = BahdanauAttention(annot_dim, rnn_dim,
@@ -138,6 +145,7 @@ class AttentionRNNCell(nn.Module):
         """
         if t == 0:
             self.alignment_model.reset()
+            self.win_idx = 0
         # Feed it to RNN
         # s_i = f(y_{i-1}, c_{i}, s_{i-1})
         rnn_output = self.rnn_cell(torch.cat((memory, context), -1), rnn_state)
@@ -151,6 +159,17 @@ class AttentionRNNCell(nn.Module):
         if mask is not None:
             mask = mask.view(memory.size(0), -1)
             alignment.masked_fill_(1 - mask, -float("inf"))
+        # Windowing
+        if not self.training:
+            # print(" > Windowing active")
+            back_win = self.win_idx - self.win_back
+            front_win = self.win_idx + self.win_front
+            if back_win > 0:
+                alignment[:, :back_win] = -float("inf")
+            if front_win < memory.shape[1]:
+                alignment[:, front_win:] = -float("inf")
+            # Update the window
+            self.win_idx = torch.argmax(alignment,1).long()[0].item()
         # Normalize context weight
         # alignment = F.softmax(alignment, dim=-1)
         # alignment = 5 * alignment
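For reference, a minimal standalone sketch of the windowing step added above: scores outside [win_idx - win_back, win_idx + win_front) are pushed to -inf before normalization, so at inference time the attention peak can only move a few encoder steps per decoder step. The helper name, the (batch, seq_len) score shape, and the use of alignment.shape[1] in place of memory.shape[1] are illustrative assumptions, not part of this commit.

import torch

def apply_attention_window(alignment, win_idx, win_back=1, win_front=3):
    # Hypothetical helper mirroring the masking above.
    # alignment: raw (pre-softmax) attention scores of shape (batch, seq_len).
    back_win = win_idx - win_back
    front_win = win_idx + win_front
    if back_win > 0:
        alignment[:, :back_win] = -float("inf")
    if front_win < alignment.shape[1]:
        alignment[:, front_win:] = -float("inf")
    # track the most attended position of the first batch item, as the commit does
    new_win_idx = torch.argmax(alignment, 1)[0].item()
    return alignment, new_win_idx

# usage sketch: batch of 2, 10 encoder steps, window currently centered near step 4
scores = torch.randn(2, 10)
scores, win_idx = apply_attention_window(scores, win_idx=4)
weights = torch.softmax(scores, dim=-1)  # mass confined to encoder steps 3..6

Note that the masking in the diff is gated only on self.training, while win_back and win_front are created only when windowing=True, so the window effectively applies when the cell is built with windowing=True and run in eval mode.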
@@ -304,7 +304,7 @@ class Decoder(nn.Module):
         r (int): number of outputs per time step.
     """
 
-    def __init__(self, in_features, memory_dim, r):
+    def __init__(self, in_features, memory_dim, r, attn_windowing):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
@@ -318,7 +318,8 @@ class Decoder(nn.Module):
             rnn_dim=256,
             annot_dim=in_features,
             memory_dim=128,
-            align_model='ls')
+            align_model='ls',
+            windowing=attn_windowing)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
@@ -12,7 +12,8 @@ class Tacotron(nn.Module):
                  linear_dim=1025,
                  mel_dim=80,
                  r=5,
-                 padding_idx=None):
+                 padding_idx=None,
+                 attn_windowing=False):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
@@ -24,7 +25,7 @@ class Tacotron(nn.Module):
         val = sqrt(3.0) * std # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(embedding_dim)
-        self.decoder = Decoder(256, mel_dim, r)
+        self.decoder = Decoder(256, mel_dim, r, attn_windowing)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
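Taken together, these changes expose windowing as a single constructor flag that is threaded from Tacotron through Decoder into AttentionRNNCell. A minimal usage sketch; the models.tacotron module path and the first constructor argument (the embedding size, not shown in this diff, assumed to be 256) are guesses for illustration, and the default attn_windowing=False keeps the previous behaviour:

from models.tacotron import Tacotron  # module path assumed from this repo's layout

# Enable the new model-level switch; other constructor arguments keep their defaults.
model = Tacotron(256, attn_windowing=True)

# The window (win_back=1, win_front=3 around win_idx) is only applied once the
# module is in eval mode, i.e. when self.training is False inside AttentionRNNCell.
model.eval()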