mirror of https://github.com/coqui-ai/TTS.git
Enable attention windowing and make it configurable at the model level.
This commit is contained in:
parent 9927664f27
commit ed1f648b83
@@ -100,7 +100,7 @@ class LocationSensitiveAttention(nn.Module):
 
 
 class AttentionRNNCell(nn.Module):
-    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model):
+    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model, windowing=False):
         r"""
         General Attention RNN wrapper
 
@@ -110,10 +110,17 @@ class AttentionRNNCell(nn.Module):
             annot_dim (int): annotation vector feature dimension.
             memory_dim (int): memory vector (decoder output) feature dimension.
             align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
+            windowing (bool): attention windowing forcing monotonic attention.
+                It is only active in eval mode.
         """
         super(AttentionRNNCell, self).__init__()
         self.align_model = align_model
         self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
+        self.windowing = windowing
+        if self.windowing:
+            self.win_back = 1
+            self.win_front = 3
+            self.win_idx = None
         # pick bahdanau or location sensitive attention
         if align_model == 'b':
             self.alignment_model = BahdanauAttention(annot_dim, rnn_dim,
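For context, the window parameters added above define a narrow band of encoder positions that the attention may look at during inference. A minimal sketch of how those bounds work out, assuming a current window index win_idx (the value 5 below is an arbitrary example, not taken from the diff):

# Sketch of the window bounds implied by win_back / win_front.
# win_back=1 and win_front=3 mirror the defaults added in this commit;
# win_idx=5 is an illustrative position only.
win_back, win_front = 1, 3
win_idx = 5

back_win = win_idx - win_back    # first allowed encoder index (here: 4)
front_win = win_idx + win_front  # first disallowed encoder index (here: 8)
print(list(range(max(back_win, 0), front_win)))  # -> [4, 5, 6, 7]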
@@ -138,6 +145,7 @@ class AttentionRNNCell(nn.Module):
         """
         if t == 0:
             self.alignment_model.reset()
+            self.win_idx = 0
         # Feed it to RNN
         # s_i = f(y_{i-1}, c_{i}, s_{i-1})
         rnn_output = self.rnn_cell(torch.cat((memory, context), -1), rnn_state)
@@ -151,6 +159,17 @@ class AttentionRNNCell(nn.Module):
         if mask is not None:
             mask = mask.view(memory.size(0), -1)
             alignment.masked_fill_(1 - mask, -float("inf"))
+        # Windowing
+        if not self.training:
+            # print(" > Windowing active")
+            back_win = self.win_idx - self.win_back
+            front_win = self.win_idx + self.win_front
+            if back_win > 0:
+                alignment[:, :back_win] = -float("inf")
+            if front_win < memory.shape[1]:
+                alignment[:, front_win:] = -float("inf")
+            # Update the window
+            self.win_idx = torch.argmax(alignment,1).long()[0].item()
         # Normalize context weight
         # alignment = F.softmax(alignment, dim=-1)
         # alignment = 5 * alignment
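To make the masking logic above easier to follow, here is a self-contained sketch (not the repository's code) that reproduces the same eval-only behaviour on a dummy score tensor; windowed_alignment and the toy loop are illustrative only:

# Standalone sketch of the eval-time windowing step: scores outside
# [win_idx - win_back, win_idx + win_front) are masked to -inf before the
# softmax, and win_idx then follows the argmax of the masked scores.
import torch
import torch.nn.functional as F

def windowed_alignment(alignment, win_idx, win_back=1, win_front=3):
    """alignment: (batch, seq_len) raw attention scores for one decoder step."""
    back_win = win_idx - win_back
    front_win = win_idx + win_front
    if back_win > 0:
        alignment[:, :back_win] = -float("inf")
    if front_win < alignment.shape[1]:
        alignment[:, front_win:] = -float("inf")
    # track the window using the first batch item's strongest attention peak
    new_win_idx = torch.argmax(alignment, 1).long()[0].item()
    return F.softmax(alignment, dim=-1), new_win_idx

# Toy loop over random scores: the attended index can only move within a
# small band around the previous peak, which forces roughly monotonic attention.
scores = torch.randn(1, 20)
win_idx = 0
for _ in range(3):
    weights, win_idx = windowed_alignment(scores.clone(), win_idx)
    print(win_idx, weights.argmax(-1).item())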
@@ -304,7 +304,7 @@ class Decoder(nn.Module):
         r (int): number of outputs per time step.
     """
 
-    def __init__(self, in_features, memory_dim, r):
+    def __init__(self, in_features, memory_dim, r, attn_windowing):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
@@ -318,7 +318,8 @@ class Decoder(nn.Module):
             rnn_dim=256,
             annot_dim=in_features,
             memory_dim=128,
-            align_model='ls')
+            align_model='ls',
+            windowing=attn_windowing)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
@@ -12,7 +12,8 @@ class Tacotron(nn.Module):
                  linear_dim=1025,
                  mel_dim=80,
                  r=5,
-                 padding_idx=None):
+                 padding_idx=None,
+                 attn_windowing=False):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
@@ -24,7 +25,7 @@ class Tacotron(nn.Module):
         val = sqrt(3.0) * std # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(embedding_dim)
-        self.decoder = Decoder(256, mel_dim, r)
+        self.decoder = Decoder(256, mel_dim, r, attn_windowing)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
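Usage-wise, the new flag only needs to be set once at model construction and is forwarded down to Decoder and AttentionRNNCell by the changes above. A rough sketch follows; the import path and the embedding_dim keyword are assumptions based on this diff, not verified against the repository layout:

# Hypothetical usage sketch: enable windowing at the model level.
# The import path and embedding_dim are assumptions, not taken from this diff.
from models.tacotron import Tacotron  # path is an assumption

model = Tacotron(embedding_dim=256,
                 linear_dim=1025,
                 mel_dim=80,
                 r=5,
                 attn_windowing=True)
model.eval()  # windowing is only active outside training mode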