Set up the model externally via setup_model based on the model selection. Make forward attention and prenet type configurable in config.json

Eren Golge 2019-04-05 17:49:18 +02:00
parent 043e49f367
commit 961af0f5cd
6 changed files with 78 additions and 28 deletions

View File

@@ -38,8 +38,10 @@
     "lr_decay": false,          // if true, Noam learning rate decaying is applied through training.
     "warmup_steps": 4000,       // Noam decay steps to increase the learning rate from 0 to "lr"
     "windowing": false,         // Enables attention windowing. Used only in eval mode.
-    "memory_size": 5,           // TO BE IMPLEMENTED -- memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
+    "memory_size": 5,           // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
     "attention_norm": "softmax",  // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
+    "prenet_type": "original",  // ONLY TACOTRON2 - "original" or "bn".
+    "use_forward_attn": false,  // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
     "batch_size": 32,           // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size": 16,

@@ -49,7 +51,7 @@
     "save_step": 1000,          // Number of training steps expected to save traning stats and checkpoints.
     "print_step": 100,          // Number of steps to log traning on console.
     "tb_model_param_stats": true,  // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
-    "batch_group_size": 8,      //Number of batches to shuffle after bucketing.
+    "batch_group_size": 8,      // Number of batches to shuffle after bucketing.
     "run_eval": true,
     "test_delay_epochs": 2,     //Until attention is aligned, testing only wastes computation time.

View File

@@ -38,8 +38,10 @@
     "lr_decay": false,          // if true, Noam learning rate decaying is applied through training.
     "warmup_steps": 4000,       // Noam decay steps to increase the learning rate from 0 to "lr"
     "windowing": false,         // Enables attention windowing. Used only in eval mode.
-    "memory_size": 5,           // TO BE IMPLEMENTED -- memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
+    "memory_size": 5,           // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
     "attention_norm": "softmax",  // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
+    "prenet_type": "original",  // ONLY TACOTRON2 - "original" or "bn".
+    "use_forward_attn": false,  // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
     "batch_size": 32,           // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size": 16,

View File

@@ -53,17 +53,27 @@ class LinearBN(nn.Module):
 class Prenet(nn.Module):
-    def __init__(self, in_features, out_features=[256, 256]):
+    def __init__(self, in_features, prenet_type, out_features=[256, 256]):
         super(Prenet, self).__init__()
+        self.prenet_type = prenet_type
         in_features = [in_features] + out_features[:-1]
-        self.layers = nn.ModuleList([
-            LinearBN(in_size, out_size, bias=False)
-            for (in_size, out_size) in zip(in_features, out_features)
-        ])
+        if prenet_type == "original":
+            self.layers = nn.ModuleList([
+                LinearBN(in_size, out_size, bias=False)
+                for (in_size, out_size) in zip(in_features, out_features)
+            ])
+        elif prenet_type == "bn":
+            self.layers = nn.ModuleList(
+                [Linear(in_size, out_size, bias=False)
+                 for (in_size, out_size) in zip(in_features, out_features)
+                ])

     def forward(self, x):
         for linear in self.layers:
-            x = F.relu(linear(x))
+            if self.prenet_type == "original":
+                x = F.relu(linear(x))
+            elif self.prenet_type == "bn":
+                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
         return x
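For clarity, here is a minimal standalone sketch of the behaviour the prenet_type switch selects in this hunk: the "original" branch applies only ReLU per layer, while the "bn" branch adds dropout(0.5) after ReLU. The sketch uses plain torch.nn layers in place of the repo's Linear/LinearBN wrappers; ToyPrenet and the 80-dim input are illustrative assumptions, not part of the commit.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyPrenet(nn.Module):
    """Toy version of the configurable Prenet: 'original' -> ReLU only,
    'bn' -> ReLU followed by dropout(0.5), mirroring the branch above."""
    def __init__(self, in_features, prenet_type="original", out_features=[256, 256]):
        super().__init__()
        self.prenet_type = prenet_type
        in_sizes = [in_features] + out_features[:-1]
        # plain nn.Linear stands in for the repo's Linear / LinearBN wrappers
        self.layers = nn.ModuleList(
            [nn.Linear(i, o, bias=False) for i, o in zip(in_sizes, out_features)])

    def forward(self, x):
        for linear in self.layers:
            x = F.relu(linear(x))
            if self.prenet_type == "bn":
                x = F.dropout(x, p=0.5, training=self.training)
        return x

# usage: a batch of 4 frames with 80 mel channels
y = ToyPrenet(80, prenet_type="bn")(torch.randn(4, 80))
print(y.shape)  # torch.Size([4, 256])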
@@ -112,7 +122,7 @@ class LocationLayer(nn.Module):
 class Attention(nn.Module):
     def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                  attention_location_n_filters, attention_location_kernel_size,
-                 windowing, norm):
+                 windowing, norm, forward_attn):
         super(Attention, self).__init__()
         self.query_layer = Linear(
             attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
@@ -126,12 +136,22 @@ class Attention(nn.Module):
         self.windowing = windowing
         self.win_idx = None
         self.norm = norm
+        self.forward_attn = forward_attn

     def init_win_idx(self):
         self.win_idx = -1
         self.win_back = 1
         self.win_front = 3

+    def init_forward_attn_state(self, inputs):
+        """
+        Init forward attention states
+        """
+        B = inputs.shape[0]
+        T = inputs.shape[1]
+        self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1]], dim=1).to(inputs.device)
+        self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)
+
     def get_attention(self, query, processed_inputs, attention_cat):
         processed_query = self.query_layer(query.unsqueeze(1))
         processed_attention_weights = self.location_layer(attention_cat)
@@ -170,9 +190,19 @@ class Attention(nn.Module):
                 attention).sum(dim=1).unsqueeze(1)
         else:
             raise RuntimeError("Unknown value for attention norm type")
-        context = torch.bmm(alignment.unsqueeze(1), inputs)
-        context = context.squeeze(1)
-        return context, alignment
+        if self.forward_attn:
+            # forward attention
+            prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
+            self.alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-7) * alignment
+            alpha_norm = self.alpha / self.alpha.sum(dim=1).unsqueeze(1)
+            # compute context
+            context = torch.bmm(alpha_norm.unsqueeze(1), inputs)
+            context = context.squeeze(1)
+            return context, alpha_norm, alignment
+        else:
+            context = torch.bmm(alignment.unsqueeze(1), inputs)
+            context = context.squeeze(1)
+            return context, alignment, alignment


 class Postnet(nn.Module):
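The forward_attn branch added above resembles the forward attention recursion of Zhang et al. (2018): at each decoder step the previous attention mass either stays on the same encoder position or advances by one, with the transition probability u initialized to 0.5 in this commit (no learned transition agent appears in the diff); the mixed weights are then reweighted by the regular alignment and renormalized. Below is a minimal sketch of that recursion under those assumptions; forward_attn_step is an illustrative helper, not part of the repo.

import torch
import torch.nn.functional as F

def forward_attn_step(alpha, alignment, u=0.5, eps=1e-7):
    """One step of the recursion used above: attention mass either stays at
    its current encoder position or advances by one, then is reweighted by
    the content/location alignment and renormalized."""
    prev_alpha = F.pad(alpha[:, :-1], (1, 0))   # shift weights one position to the right
    alpha = ((1 - u) * alpha + u * prev_alpha + eps) * alignment
    return alpha / alpha.sum(dim=1, keepdim=True)

# toy check: B=1, T=4; start with all mass on the first encoder position
alpha = torch.tensor([[1., 0., 0., 0.]])
alignment = torch.full((1, 4), 0.25)            # uninformative alignment
for _ in range(3):
    alpha = forward_attn_step(alpha, alignment)
    print(alpha)  # mass drifts forward one position per step, never backwards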
@@ -242,7 +272,7 @@ class Encoder(nn.Module):
 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
-    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm):
+    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
         self.r = r
@@ -255,14 +285,14 @@ class Decoder(nn.Module):
         self.p_attention_dropout = 0.1
         self.p_decoder_dropout = 0.1

-        self.prenet = Prenet(self.mel_channels * r,
+        self.prenet = Prenet(self.mel_channels * r, prenet_type,
                              [self.prenet_dim, self.prenet_dim])

         self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                          self.attention_rnn_dim)

         self.attention_layer = Attention(self.attention_rnn_dim, in_features,
-                                         128, 32, 31, attn_win, attn_norm)
+                                         128, 32, 31, attn_win, attn_norm, forward_attn)

         self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                        self.decoder_rnn_dim, 1)
@@ -340,11 +370,11 @@ class Decoder(nn.Module):
         attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
                                    self.attention_weights_cum.unsqueeze(1)),
                                   dim=1)
-        self.context, self.attention_weights = self.attention_layer(
+        self.context, self.attention_weights, alignments = self.attention_layer(
            self.attention_hidden, self.inputs, self.processed_inputs,
            attention_cat, self.mask)
-        self.attention_weights_cum += self.attention_weights
+        self.attention_weights_cum += alignments
         memory = torch.cat(
             (self.attention_hidden, self.context), -1)
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
@@ -372,6 +402,8 @@ class Decoder(nn.Module):
         memories = self.prenet(memories)

         self._init_states(inputs, mask=mask)
+        if self.attention_layer.forward_attn:
+            self.attention_layer.init_forward_attn_state(inputs)

         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
@@ -392,6 +424,9 @@ class Decoder(nn.Module):
         self._init_states(inputs, mask=None)

         self.attention_layer.init_win_idx()
+        if self.attention_layer.forward_attn:
+            self.attention_layer.init_forward_attn_state(inputs)
+
         outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [True, False, False]
         stop_count = 0
@@ -433,8 +468,9 @@ class Decoder(nn.Module):
         self._init_states(inputs, mask=None, keep_states=True)

         self.attention_layer.init_win_idx()
+        self.attention_layer.init_forward_attn_state()
         outputs, gate_outputs, alignments, t = [], [], [], 0
-        stop_flags = [False, False]
+        stop_flags = [False, False, False]
         stop_count = 0
         while True:
             memory = self.prenet(self.memory_truncated)
@@ -454,6 +490,7 @@ class Decoder(nn.Module):
             elif len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break
+            self.memory_truncated = mel_output
             t += 1

View File

@@ -9,7 +9,7 @@ from utils.generic_utils import sequence_mask
 # TODO: match function arguments with tacotron
 class Tacotron2(nn.Module):
-    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax"):
+    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False):
         super(Tacotron2, self).__init__()
         self.n_mel_channels = 80
         self.n_frames_per_step = r
@@ -18,7 +18,7 @@ class Tacotron2(nn.Module):
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm)
+        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn)
         self.postnet = Postnet(self.n_mel_channels)

     def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):

View File

@@ -23,7 +23,7 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
                                  load_config, lr_decay,
                                  remove_experiment_folder, save_best_model,
                                  save_checkpoint, sequence_mask, weight_decay,
-                                 set_init_dict, copy_config_file)
+                                 set_init_dict, copy_config_file, setup_model)
 from utils.logger import Logger
 from utils.synthesis import synthesis
 from utils.text.symbols import phonemes, symbols
@@ -375,7 +375,7 @@ def main(args):
         init_distributed(args.rank, num_gpus, args.group_id,
                          c.distributed["backend"], c.distributed["url"])
     num_chars = len(phonemes) if c.use_phonemes else len(symbols)
-    model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm)
+    model = setup_model(num_chars, c)

     print(" | > Num output units : {}".format(ap.num_freq), flush=True)
@@ -528,9 +528,6 @@ if __name__ == '__main__':
     # Conditional imports
     preprocessor = importlib.import_module('datasets.preprocess')
     preprocessor = getattr(preprocessor, c.dataset.lower())
-    print(" > Using model: {}".format(c.model))
-    MyModel = importlib.import_module('models.'+c.model.lower())
-    MyModel = getattr(MyModel, c.model)

     # Audio processor
     ap = AudioProcessor(**c.audio)

View File

@@ -8,6 +8,7 @@ import datetime
 import json
 import torch
 import subprocess
+import importlib
 import numpy as np
 from collections import OrderedDict
 from torch.autograd import Variable
@@ -237,3 +238,14 @@ def set_init_dict(model_dict, checkpoint, c):
     model_dict.update(pretrained_dict)
     print(" | > {} / {} layers are initialized".format(len(pretrained_dict), len(model_dict)))
     return model_dict
+
+
+def setup_model(num_chars, c):
+    print(" > Using model: {}".format(c.model))
+    MyModel = importlib.import_module('models.'+c.model.lower())
+    MyModel = getattr(MyModel, c.model)
+    if c.model.lower() == "tacotron":
+        model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm, memory_size=c.memory_size)
+    elif c.model.lower() == "tacotron2":
+        model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm, prenet_type=c.prenet_type, forward_attn=c.use_forward_attn)
+    return model
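For reference, a minimal sketch of how the new helper would be driven from a parsed config, mirroring the train.py call in this diff. The config path "config.json" is illustrative; load_config, setup_model, and the c.* fields shown are the ones already used in the changes above.

# illustrative usage of setup_model with the config keys added in this commit
from utils.generic_utils import load_config, setup_model
from utils.text.symbols import phonemes, symbols

c = load_config("config.json")   # attribute-style config, as used throughout train.py
num_chars = len(phonemes) if c.use_phonemes else len(symbols)

# resolves models.<c.model.lower()>.<c.model> (e.g. models.tacotron2.Tacotron2)
# and forwards prenet_type / use_forward_attn only for the Tacotron2 branch
model = setup_model(num_chars, c)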