mirror of https://github.com/coqui-ai/TTS.git
Set attention norm method by config.json
This commit is contained in:
parent
786510cd6a
commit
0a92c6d5a7
@@ -39,6 +39,7 @@
     "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
     "windowing": false, // Enables attention windowing. Used only in eval mode.
     "memory_size": 5, // TO BE IMPLEMENTED -- memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
+    "attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "batch_size": 16, // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,
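The new key admits only "softmax" or "sigmoid"; anything else is rejected deep inside the attention layers (the RuntimeError in the hunks below), i.e. only once the first forward pass runs. A minimal fail-fast sketch at config-load time; the helper name and error message are illustrative and not part of this commit:

    ALLOWED_ATTENTION_NORMS = {"softmax", "sigmoid"}

    def check_attention_norm(config: dict) -> str:
        """Return config["attention_norm"], failing at load time rather than
        inside the attention layer's forward pass."""
        norm = config["attention_norm"]
        if norm not in ALLOWED_ATTENTION_NORMS:
            raise ValueError("attention_norm must be 'softmax' or 'sigmoid', got %r" % norm)
        return norm

    # e.g. check_attention_norm({"attention_norm": "softmax"}) -> "softmax"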
@@ -100,7 +100,7 @@ class LocationSensitiveAttention(nn.Module):


 class AttentionRNNCell(nn.Module):
-    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model, windowing=False):
+    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model, windowing=False, norm="sigmoid"):
         r"""
         General Attention RNN wrapper

@@ -112,6 +112,7 @@ class AttentionRNNCell(nn.Module):
             align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
             windowing (bool): attention windowing forcing monotonic attention.
                 It is only active in eval mode.
+            norm (str): norm method to compute alignment weights.
         """
         super(AttentionRNNCell, self).__init__()
         self.align_model = align_model
@@ -121,7 +122,7 @@ class AttentionRNNCell(nn.Module):
         self.win_back = 3
         self.win_front = 6
         self.win_idx = None
-        # pick bahdanau or location sensitive attention
+        self.norm = norm
         if align_model == 'b':
             self.alignment_model = BahdanauAttention(annot_dim, rnn_dim,
                                                      out_dim)
@@ -164,7 +165,12 @@ class AttentionRNNCell(nn.Module):
                 alignment[:, front_win:] = -float("inf")
             # Update the window
             self.win_idx = torch.argmax(alignment,1).long()[0].item()
-        alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
+        if self.norm == "softmax":
+            alignment = torch.softmax(alignment, dim=-1)
+        elif self.norm == "sigmoid":
+            alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
+        else:
+            raise RuntimeError("Unknown value for attention norm type")
         context = torch.bmm(alignment.unsqueeze(1), annots)
         context = context.squeeze(1)
         return rnn_output, context, alignment
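The branch above is the heart of the commit: raw alignment scores are turned into weights either with a plain softmax or with an element-wise sigmoid followed by renormalization over the encoder axis. A standalone sketch, assuming PyTorch; the helper name and toy scores are mine, not from the codebase:

    import torch

    def normalize_alignment(scores: torch.Tensor, norm: str = "sigmoid") -> torch.Tensor:
        """Turn raw attention scores of shape (batch, T_enc) into weights summing to 1,
        mirroring the softmax/sigmoid branch added in this commit."""
        if norm == "softmax":
            return torch.softmax(scores, dim=-1)
        if norm == "sigmoid":
            sig = torch.sigmoid(scores)
            return sig / sig.sum(dim=1).unsqueeze(1)
        raise RuntimeError("Unknown value for attention norm type")

    scores = torch.randn(2, 6)
    scores[:, 4:] = -float("inf")   # attention windowing masks out-of-window positions with -inf
    for norm in ("softmax", "sigmoid"):
        weights = normalize_alignment(scores, norm)
        assert torch.allclose(weights.sum(dim=1), torch.ones(2))  # both norms sum to 1 per row
        assert torch.all(weights[:, 4:] == 0)                     # masked positions get zero weight

Either way the windowed positions end up with exactly zero weight, since exp(-inf) and sigmoid(-inf) are both 0.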
@@ -302,7 +302,7 @@ class Decoder(nn.Module):
         """

     def __init__(self, in_features, memory_dim, r, memory_size,
-                 attn_windowing):
+                 attn_windowing, attn_norm):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
@@ -319,7 +319,8 @@ class Decoder(nn.Module):
             annot_dim=in_features,
             memory_dim=128,
             align_model='ls',
-            windowing=attn_windowing)
+            windowing=attn_windowing,
+            norm=attn_norm)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
@@ -112,7 +112,7 @@ class LocationLayer(nn.Module):
 class Attention(nn.Module):
     def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                  attention_location_n_filters, attention_location_kernel_size,
-                 windowing):
+                 windowing, norm):
         super(Attention, self).__init__()
         self.query_layer = Linear(
             attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
@@ -128,6 +128,7 @@ class Attention(nn.Module):
         self.win_back = 1
         self.win_front = 3
         self.win_idx = None
+        self.norm = norm

     def init_win_idx(self):
         self.win_idx = -1
@@ -163,8 +164,13 @@ class Attention(nn.Module):
                 attention[:, 0] = attention.max()
             # Update the window
             self.win_idx = torch.argmax(attention, 1).long()[0].item()
-        alignment = torch.sigmoid(attention) / torch.sigmoid(
-            attention).sum(dim=1).unsqueeze(1)
+        if self.norm == "softmax":
+            alignment = torch.softmax(attention, dim=-1)
+        elif self.norm == "sigmoid":
+            alignment = torch.sigmoid(attention) / torch.sigmoid(
+                attention).sum(dim=1).unsqueeze(1)
+        else:
+            raise RuntimeError("Unknown value for attention norm type")
         context = torch.bmm(alignment.unsqueeze(1), inputs)
         context = context.squeeze(1)
         return context, alignment
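Both normalizations produce valid weights, but they behave differently on the same scores: softmax exponentiates score differences and concentrates mass on the best position, while sigmoid-plus-renormalization saturates and spreads the weight out. A toy comparison (printed values are approximate, not from the codebase):

    import torch

    scores = torch.tensor([[5.0, 1.0, -3.0]])

    softmax_w = torch.softmax(scores, dim=-1)
    sig = torch.sigmoid(scores)
    sigmoid_w = sig / sig.sum(dim=1).unsqueeze(1)

    print(softmax_w)  # roughly [[0.98, 0.02, 0.00]] -- nearly all mass on one encoder step
    print(sigmoid_w)  # roughly [[0.56, 0.41, 0.03]] -- weight spread across steps

Which behaviour trains better is an empirical question; the commit simply exposes the choice through "attention_norm" instead of hard-coding one formula.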
@@ -237,7 +243,7 @@ class Encoder(nn.Module):

 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
-    def __init__(self, in_features, inputs_dim, r, attn_win):
+    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
         self.r = r
@@ -257,7 +263,7 @@ class Decoder(nn.Module):
                                            self.attention_rnn_dim)

         self.attention_layer = Attention(self.attention_rnn_dim, in_features,
-                                         128, 32, 31, attn_win)
+                                         128, 32, 31, attn_win, attn_norm)

         self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                        self.decoder_rnn_dim, 1)
@@ -14,7 +14,8 @@ class Tacotron(nn.Module):
                  r=5,
                  padding_idx=None,
                  memory_size=5,
-                 attn_windowing=False):
+                 attn_windowing=False,
+                 attn_norm="sigmoid"):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
@@ -22,7 +23,7 @@ class Tacotron(nn.Module):
         self.embedding = nn.Embedding(num_chars, 256, padding_idx=padding_idx)
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(256)
-        self.decoder = Decoder(256, mel_dim, r, memory_size, attn_windowing)
+        self.decoder = Decoder(256, mel_dim, r, memory_size, attn_windowing, attn_norm)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
@@ -9,7 +9,7 @@ from utils.generic_utils import sequence_mask

 # TODO: match function arguments with tacotron
 class Tacotron2(nn.Module):
-    def __init__(self, num_chars, r, attn_win=False):
+    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax"):
         super(Tacotron2, self).__init__()
         self.n_mel_channels = 80
         self.n_frames_per_step = r
@@ -18,7 +18,7 @@ class Tacotron2(nn.Module):
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win)
+        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm)
         self.postnet = Postnet(self.n_mel_channels)

     def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
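Note that the defaults diverge per model: Tacotron falls back to attn_norm="sigmoid" while Tacotron2 falls back to attn_norm="softmax", matching the suggestion in the config comment. A hypothetical usage sketch; the import path and the num_chars value are assumptions about the repository layout at this commit:

    # models/tacotron2.py is assumed to be importable as models.tacotron2
    from models.tacotron2 import Tacotron2

    model = Tacotron2(num_chars=61, r=5)                           # softmax by default
    model_sig = Tacotron2(num_chars=61, r=5, attn_norm="sigmoid")  # explicit override

    # An unsupported value (e.g. attn_norm="relu") is only rejected inside the
    # attention forward pass (the RuntimeError above), not at construction time.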
@@ -38,7 +38,7 @@ class CBHGTests(unittest.TestCase):

 class DecoderTests(unittest.TestCase):
     def test_in_out(self):
-        layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False)
+        layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False, attn_norm="sigmoid")
         dummy_input = T.rand(4, 8, 256)
         dummy_memory = T.rand(4, 2, 80)

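The updated test pins attn_norm="sigmoid". A small, hypothetical extension (not in the commit) that at least constructs the decoder with both supported values; it assumes it lives in the same test module as DecoderTests, so Decoder and unittest are already imported:

    class DecoderNormTests(unittest.TestCase):
        def test_construct_with_both_norms(self):
            # Construction should succeed for both supported normalizations.
            for norm in ("sigmoid", "softmax"):
                Decoder(in_features=256, memory_dim=80, r=2, memory_size=4,
                        attn_windowing=False, attn_norm=norm)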
train.py
@@ -375,7 +375,7 @@ def main(args):
         init_distributed(args.rank, num_gpus, args.group_id,
                          c.distributed["backend"], c.distributed["url"])
     num_chars = len(phonemes) if c.use_phonemes else len(symbols)
-    model = MyModel(num_chars=num_chars, r=c.r)
+    model = MyModel(num_chars=num_chars, r=c.r, attention_norm=c.attention_norm)

     print(" | > Num output units : {}".format(ap.num_freq), flush=True)
