Set attention norm method by config.json

Eren Golge 2019-03-26 00:48:12 +01:00
parent 786510cd6a
commit 0a92c6d5a7
8 changed files with 31 additions and 16 deletions

View File

@@ -39,6 +39,7 @@
     "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
     "windowing": false, // Enables attention windowing. Used only in eval mode.
     "memory_size": 5, // TO BE IMPLEMENTED -- memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
+    "attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "batch_size": 16, // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,

View File

@@ -100,7 +100,7 @@ class LocationSensitiveAttention(nn.Module):
 class AttentionRNNCell(nn.Module):
-    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model, windowing=False):
+    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model, windowing=False, norm="sigmoid"):
         r"""
         General Attention RNN wrapper
@@ -112,6 +112,7 @@ class AttentionRNNCell(nn.Module):
             align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
             windowing (bool): attention windowing forcing monotonic attention.
                 It is only active in eval mode.
+            norm (str): norm method to compute alignment weights.
         """
         super(AttentionRNNCell, self).__init__()
         self.align_model = align_model
@@ -121,7 +122,7 @@
         self.win_back = 3
         self.win_front = 6
         self.win_idx = None
-        # pick bahdanau or location sensitive attention
+        self.norm = norm
         if align_model == 'b':
             self.alignment_model = BahdanauAttention(annot_dim, rnn_dim,
                                                      out_dim)
@@ -164,7 +165,12 @@
             alignment[:, front_win:] = -float("inf")
             # Update the window
             self.win_idx = torch.argmax(alignment,1).long()[0].item()
-        alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
+        if self.norm == "softmax":
+            alignment = torch.softmax(alignment, dim=-1)
+        elif self.norm == "sigmoid":
+            alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
+        else:
+            raise RuntimeError("Unknown value for attention norm type")
         context = torch.bmm(alignment.unsqueeze(1), annots)
         context = context.squeeze(1)
         return rnn_output, context, alignment
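
Both branches turn the raw alignment scores into weights over the encoder steps: softmax is the standard normalized exponential, while the sigmoid option (the previously hard-coded behaviour) renormalizes independent per-step sigmoids so they also sum to one. A standalone sketch of the two options on dummy scores (shapes and values are illustrative only):

import torch

# Dummy alignment scores: batch of 2, 8 encoder steps.
scores = torch.randn(2, 8)

# "softmax" norm: normalized exponential over the encoder axis.
softmax_weights = torch.softmax(scores, dim=-1)

# "sigmoid" norm: per-step sigmoids renormalized to sum to one (the old default).
sig = torch.sigmoid(scores)
sigmoid_weights = sig / sig.sum(dim=1, keepdim=True)

print(softmax_weights.sum(dim=1))  # ~1.0 per row
print(sigmoid_weights.sum(dim=1))  # ~1.0 per row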

View File

@@ -302,7 +302,7 @@ class Decoder(nn.Module):
     """
     def __init__(self, in_features, memory_dim, r, memory_size,
-                 attn_windowing):
+                 attn_windowing, attn_norm):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
@@ -319,7 +319,8 @@
             annot_dim=in_features,
             memory_dim=128,
             align_model='ls',
-            windowing=attn_windowing)
+            windowing=attn_windowing,
+            norm=attn_norm)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state

View File

@@ -112,7 +112,7 @@ class LocationLayer(nn.Module):
 class Attention(nn.Module):
     def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                  attention_location_n_filters, attention_location_kernel_size,
-                 windowing):
+                 windowing, norm):
         super(Attention, self).__init__()
         self.query_layer = Linear(
             attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
@@ -128,6 +128,7 @@ class Attention(nn.Module):
         self.win_back = 1
         self.win_front = 3
         self.win_idx = None
+        self.norm = norm

     def init_win_idx(self):
         self.win_idx = -1
@@ -163,8 +164,13 @@
             attention[:, 0] = attention.max()
             # Update the window
             self.win_idx = torch.argmax(attention, 1).long()[0].item()
-        alignment = torch.sigmoid(attention) / torch.sigmoid(
-            attention).sum(dim=1).unsqueeze(1)
+        if self.norm == "softmax":
+            alignment = torch.softmax(attention, dim=-1)
+        elif self.norm == "sigmoid":
+            alignment = torch.sigmoid(attention) / torch.sigmoid(
+                attention).sum(dim=1).unsqueeze(1)
+        else:
+            raise RuntimeError("Unknown value for attention norm type")
         context = torch.bmm(alignment.unsqueeze(1), inputs)
         context = context.squeeze(1)
         return context, alignment
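
This branch also interacts cleanly with the eval-time windowing above: positions pushed to -inf outside the window receive exactly zero weight under either norm, since exp(-inf) == 0 and sigmoid(-inf) == 0. A standalone sketch with illustrative values:

import torch

# Scores where the first and last positions were masked to -inf by windowing.
scores = torch.tensor([[-float("inf"), 0.5, 1.2, -float("inf")]])

softmax_weights = torch.softmax(scores, dim=-1)
sig = torch.sigmoid(scores)
sigmoid_weights = sig / sig.sum(dim=1, keepdim=True)

print(softmax_weights)  # approx [[0.000, 0.332, 0.668, 0.000]]
print(sigmoid_weights)  # approx [[0.000, 0.447, 0.553, 0.000]]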
@@ -237,7 +243,7 @@ class Encoder(nn.Module):
 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
-    def __init__(self, in_features, inputs_dim, r, attn_win):
+    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
         self.r = r
@@ -257,7 +263,7 @@
                                            self.attention_rnn_dim)
         self.attention_layer = Attention(self.attention_rnn_dim, in_features,
-                                         128, 32, 31, attn_win)
+                                         128, 32, 31, attn_win, attn_norm)
         self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                        self.decoder_rnn_dim, 1)

View File

@@ -14,7 +14,8 @@ class Tacotron(nn.Module):
                  r=5,
                  padding_idx=None,
                  memory_size=5,
-                 attn_windowing=False):
+                 attn_windowing=False,
+                 attn_norm="sigmoid"):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
@@ -22,7 +23,7 @@ class Tacotron(nn.Module):
         self.embedding = nn.Embedding(num_chars, 256, padding_idx=padding_idx)
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(256)
-        self.decoder = Decoder(256, mel_dim, r, memory_size, attn_windowing)
+        self.decoder = Decoder(256, mel_dim, r, memory_size, attn_windowing, attn_norm)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),

View File

@@ -9,7 +9,7 @@ from utils.generic_utils import sequence_mask
 # TODO: match function arguments with tacotron
 class Tacotron2(nn.Module):
-    def __init__(self, num_chars, r, attn_win=False):
+    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax"):
         super(Tacotron2, self).__init__()
         self.n_mel_channels = 80
         self.n_frames_per_step = r
@@ -18,7 +18,7 @@ class Tacotron2(nn.Module):
         val = sqrt(3.0) * std # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win)
+        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm)
         self.postnet = Postnet(self.n_mel_channels)

     def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
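
A minimal, hypothetical construction of Tacotron2 with the new keyword; only the constructor signature comes from this diff, while the import path and the num_chars value are assumptions:

# Assumed module path based on this repo's layout; adjust if it differs.
from models.tacotron2 import Tacotron2

# 61 is an illustrative symbol-set size; "softmax" is the suggested norm for Tacotron2.
model = Tacotron2(num_chars=61, r=1, attn_win=False, attn_norm="softmax")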

View File

@@ -38,7 +38,7 @@ class CBHGTests(unittest.TestCase):
 class DecoderTests(unittest.TestCase):
     def test_in_out(self):
-        layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False)
+        layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False, attn_norm="sigmoid")
         dummy_input = T.rand(4, 8, 256)
         dummy_memory = T.rand(4, 2, 80)
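
A hedged sketch of a companion test that constructs the decoder with both documented norm options; the import path is assumed from the repo layout and this test is not part of the commit:

import unittest

# Assumed import path for the CBHG-based decoder used in the test above.
from layers.tacotron import Decoder


class DecoderAttnNormTests(unittest.TestCase):
    def test_attn_norm_options(self):
        # Both documented options should be accepted at construction time;
        # a full forward pass (as in test_in_out above) would exercise the branch itself.
        for norm in ("sigmoid", "softmax"):
            layer = Decoder(in_features=256, memory_dim=80, r=2,
                            memory_size=4, attn_windowing=False, attn_norm=norm)
            self.assertIsNotNone(layer)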

View File

@@ -375,7 +375,7 @@ def main(args):
         init_distributed(args.rank, num_gpus, args.group_id,
                          c.distributed["backend"], c.distributed["url"])
     num_chars = len(phonemes) if c.use_phonemes else len(symbols)
-    model = MyModel(num_chars=num_chars, r=c.r)
+    model = MyModel(num_chars=num_chars, r=c.r, attention_norm=c.attention_norm)
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)