From 0a92c6d5a7601fe0b1d8d5bf53ef1774c15647cc Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Tue, 26 Mar 2019 00:48:12 +0100
Subject: [PATCH] Set attention norm method by config.json

---
 config_cluster.json   |  1 +
 layers/attention.py   | 12 +++++++++---
 layers/tacotron.py    |  5 +++--
 layers/tacotron2.py   | 16 +++++++++++-----
 models/tacotron.py    |  5 +++--
 models/tacotron2.py   |  4 ++--
 tests/layers_tests.py |  2 +-
 train.py              |  2 +-
 8 files changed, 31 insertions(+), 16 deletions(-)

diff --git a/config_cluster.json b/config_cluster.json
index 79a8e47f..96723b5c 100644
--- a/config_cluster.json
+++ b/config_cluster.json
@@ -39,6 +39,7 @@
     "warmup_steps": 4000,        // Noam decay steps to increase the learning rate from 0 to "lr"
     "windowing": false,          // Enables attention windowing. Used only in eval mode.
     "memory_size": 5,            // TO BE IMPLEMENTED -- memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
+    "attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
 
     "batch_size": 16,       // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,
diff --git a/layers/attention.py b/layers/attention.py
index c59ce406..08765e70 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -100,7 +100,7 @@ class LocationSensitiveAttention(nn.Module):
 
 
 class AttentionRNNCell(nn.Module):
-    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model, windowing=False):
+    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model, windowing=False, norm="sigmoid"):
         r"""
         General Attention RNN wrapper
 
@@ -112,6 +112,7 @@ class AttentionRNNCell(nn.Module):
             align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
             windowing (bool): attention windowing forcing monotonic attention.
                 It is only active in eval mode.
+            norm (str): norm method to compute alignment weights.
         """
         super(AttentionRNNCell, self).__init__()
         self.align_model = align_model
@@ -121,7 +122,7 @@ class AttentionRNNCell(nn.Module):
         self.win_back = 3
         self.win_front = 6
         self.win_idx = None
-        # pick bahdanau or location sensitive attention
+        self.norm = norm
         if align_model == 'b':
             self.alignment_model = BahdanauAttention(annot_dim, rnn_dim,
                                                      out_dim)
@@ -164,7 +165,12 @@ class AttentionRNNCell(nn.Module):
                 alignment[:, front_win:] = -float("inf")
             # Update the window
             self.win_idx = torch.argmax(alignment,1).long()[0].item()
-        alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
+        if self.norm == "softmax":
+            alignment = torch.softmax(alignment, dim=-1)
+        elif self.norm == "sigmoid":
+            alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
+        else:
+            raise RuntimeError("Unknown value for attention norm type")
         context = torch.bmm(alignment.unsqueeze(1), annots)
         context = context.squeeze(1)
         return rnn_output, context, alignment
diff --git a/layers/tacotron.py b/layers/tacotron.py
index b21a8882..3d505ddb 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -302,7 +302,7 @@ class Decoder(nn.Module):
     """
 
     def __init__(self, in_features, memory_dim, r, memory_size,
-                 attn_windowing):
+                 attn_windowing, attn_norm):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
@@ -319,7 +319,8 @@ class Decoder(nn.Module):
             annot_dim=in_features,
             memory_dim=128,
             align_model='ls',
-            windowing=attn_windowing)
+            windowing=attn_windowing,
+            norm=attn_norm)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index 8275788b..0935c3bb 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -112,7 +112,7 @@ class LocationLayer(nn.Module):
 
 class Attention(nn.Module):
     def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                  attention_location_n_filters, attention_location_kernel_size,
-                 windowing):
+                 windowing, norm):
         super(Attention, self).__init__()
         self.query_layer = Linear(
             attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
@@ -128,6 +128,7 @@ class Attention(nn.Module):
         self.win_back = 1
         self.win_front = 3
         self.win_idx = None
+        self.norm = norm
 
     def init_win_idx(self):
         self.win_idx = -1
@@ -163,8 +164,13 @@ class Attention(nn.Module):
             attention[:, 0] = attention.max()
             # Update the window
             self.win_idx = torch.argmax(attention, 1).long()[0].item()
-        alignment = torch.sigmoid(attention) / torch.sigmoid(
-            attention).sum(dim=1).unsqueeze(1)
+        if self.norm == "softmax":
+            alignment = torch.softmax(attention, dim=-1)
+        elif self.norm == "sigmoid":
+            alignment = torch.sigmoid(attention) / torch.sigmoid(
+                attention).sum(dim=1).unsqueeze(1)
+        else:
+            raise RuntimeError("Unknown value for attention norm type")
         context = torch.bmm(alignment.unsqueeze(1), inputs)
         context = context.squeeze(1)
         return context, alignment
@@ -237,7 +243,7 @@ class Encoder(nn.Module):
 
 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
-    def __init__(self, in_features, inputs_dim, r, attn_win):
+    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
         self.r = r
@@ -257,7 +263,7 @@ class Decoder(nn.Module):
                                         self.attention_rnn_dim)
 
         self.attention_layer = Attention(self.attention_rnn_dim, in_features,
-                                         128, 32, 31, attn_win)
+                                         128, 32, 31, attn_win, attn_norm)
 
         self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                        self.decoder_rnn_dim, 1)
diff --git a/models/tacotron.py b/models/tacotron.py
index 11ff8740..7bda5ea2 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -14,7 +14,8 @@ class Tacotron(nn.Module):
                  r=5,
                  padding_idx=None,
                  memory_size=5,
-                 attn_windowing=False):
+                 attn_windowing=False,
+                 attn_norm="sigmoid"):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
@@ -22,7 +23,7 @@ class Tacotron(nn.Module):
         self.embedding = nn.Embedding(num_chars, 256, padding_idx=padding_idx)
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(256)
-        self.decoder = Decoder(256, mel_dim, r, memory_size, attn_windowing)
+        self.decoder = Decoder(256, mel_dim, r, memory_size, attn_windowing, attn_norm)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
diff --git a/models/tacotron2.py b/models/tacotron2.py
index e4082848..671e6bb8 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -9,7 +9,7 @@ from utils.generic_utils import sequence_mask
 
 # TODO: match function arguments with tacotron
 class Tacotron2(nn.Module):
-    def __init__(self, num_chars, r, attn_win=False):
+    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax"):
         super(Tacotron2, self).__init__()
         self.n_mel_channels = 80
         self.n_frames_per_step = r
@@ -18,7 +18,7 @@ class Tacotron2(nn.Module):
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win)
+        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm)
         self.postnet = Postnet(self.n_mel_channels)
 
     def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
diff --git a/tests/layers_tests.py b/tests/layers_tests.py
index 4ac7c8fc..3ddd8f8e 100644
--- a/tests/layers_tests.py
+++ b/tests/layers_tests.py
@@ -38,7 +38,7 @@ class CBHGTests(unittest.TestCase):
 
 
 class DecoderTests(unittest.TestCase):
     def test_in_out(self):
-        layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False)
+        layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False, attn_norm="sigmoid")
         dummy_input = T.rand(4, 8, 256)
         dummy_memory = T.rand(4, 2, 80)
diff --git a/train.py b/train.py
index 09d39ead..e4e4a674 100644
--- a/train.py
+++ b/train.py
@@ -375,7 +375,7 @@ def main(args):
         init_distributed(args.rank, num_gpus, args.group_id,
                          c.distributed["backend"], c.distributed["url"])
     num_chars = len(phonemes) if c.use_phonemes else len(symbols)
-    model = MyModel(num_chars=num_chars, r=c.r)
+    model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm)
 
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)
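
For reference, a minimal sketch (not part of the patch itself) of what the two "attention_norm" options compute, assuming plain PyTorch and made-up attention scores. "softmax" normalizes competitively and tends to give peakier alignments, while "sigmoid" squashes each score independently and then renormalizes, giving flatter weights.

import torch

# Hypothetical raw attention scores for one decoder step over 4 encoder steps.
scores = torch.tensor([[2.0, 0.5, -1.0, 0.0]])  # shape: [batch, encoder_steps]

# "softmax" option: competitive normalization across encoder steps.
softmax_weights = torch.softmax(scores, dim=-1)
# approx. [[0.710, 0.158, 0.035, 0.096]]

# "sigmoid" option: squash each score independently, then renormalize to sum to 1.
sigmoid_weights = torch.sigmoid(scores) / torch.sigmoid(scores).sum(dim=1).unsqueeze(1)
# approx. [[0.388, 0.274, 0.118, 0.220]]

# Either result sums to 1 per row, so it can weight the encoder outputs, e.g.
# context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)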