mirror of https://github.com/coqui-ai/TTS.git
Use Attention and Prenet from common file
parent 0dbed8fef7
commit 35b76556e4
@@ -0,0 +1,79 @@
{
    "run_name": "mozilla-tacotron-tagent",
    "run_description": "using forward attention with transition agent, with original prenet, loss masking, separate stopnet, sigmoid norm. Compare this with 4841",

    "audio":{
        // Audio processing parameters
        "num_mels": 80,          // size of the mel spec frame.
        "num_freq": 1025,        // number of stft frequency levels. Size of the linear spectrogram frame.
        "sample_rate": 22050,    // DATASET-RELATED: wav sample rate. If different from the original data, it is resampled.
        "frame_length_ms": 50,   // stft window length in ms.
        "frame_shift_ms": 12.5,  // stft window hop-length in ms.
        "preemphasis": 0.98,     // pre-emphasis to reduce spec noise and make it more structured. If 0.0, pre-emphasis is disabled.
        "min_level_db": -100,    // normalization range
        "ref_level_db": 20,      // reference level db, theoretically 20db is the sound of air.
        "power": 1.5,            // value to sharpen wav signals after GL algorithm.
        "griffin_lim_iters": 60, // number of griffin-lim iterations. 30-60 is a good range. The larger the value, the slower the generation.
        // Normalization parameters
        "signal_norm": true,     // normalize the spec values in range [0, 1]
        "symmetric_norm": false, // move normalization to range [-1, 1]
        "max_norm": 1,           // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
        "clip_norm": true,       // clip normalized values into the range.
        "mel_fmin": 0.0,         // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
        "mel_fmax": 8000.0,      // maximum freq level for mel-spec. Tune for dataset!!
        "do_trim_silence": true  // enable trimming of silence from audio as you load it. LJSpeech (false), TWEB (false), Nancy (true)
    },

    "distributed":{
        "backend": "nccl",
        "url": "tcp:\/\/localhost:54321"
    },

    "reinit_layers": [],

    "model": "Tacotron",            // one of the models in models/
    "grad_clip": 1,                 // upper limit for gradients for clipping.
    "epochs": 1000,                 // total number of epochs to train.
    "lr": 0.0001,                   // Initial learning rate. If Noam decay is active, maximum learning rate.
    "lr_decay": false,              // if true, Noam learning rate decaying is applied through training.
    "warmup_steps": 4000,           // Noam decay steps to increase the learning rate from 0 to "lr"
    "windowing": false,             // Enables attention windowing. Used only in eval mode.
    "memory_size": 5,               // ONLY TACOTRON - memory queue size used to queue network predictions to feed the autoregressive connection. Useful if r < 5.
    "attention_norm": "sigmoid",    // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
    "prenet_type": "original",      // ONLY TACOTRON2 - "original" or "bn".
    "prenet_dropout": true,         // ONLY TACOTRON2 - enable/disable dropout at prenet.
    "use_forward_attn": true,       // ONLY TACOTRON2 - enable/disable forward attention. In general, it aligns faster.
    "transition_agent": true,       // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
    "location_attn": false,         // ONLY TACOTRON2 - enable/disable location sensitive attention. It is enabled for TACOTRON by default.
    "loss_masking": true,           // enable/disable loss masking against the sequence padding.
    "enable_eos_bos_chars": false,  // enable/disable beginning of sentence and end of sentence chars.
    "stopnet": true,                // Train stopnet predicting the end of synthesis.
    "separate_stopnet": true,       // Train stopnet separately if 'stopnet==true'. It prevents the stopnet loss from influencing the rest of the model. It gives a better model, but it trains SLOWER.
    "tb_model_param_stats": false,  // if true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.

    "batch_size": 32,        // Batch size for training. Values lower than 32 might cause hard-to-learn attention.
    "eval_batch_size": 16,
    "r": 5,                  // Number of frames to predict per step.
    "wd": 0.000001,          // Weight decay weight.
    "checkpoint": true,      // If true, it saves checkpoints per "save_step"
    "save_step": 1000,       // Number of training steps expected to save training stats and checkpoints.
    "print_step": 10,        // Number of steps to log training on console.
    "batch_group_size": 0,   // Number of batches to shuffle after bucketing.

    "run_eval": true,
    "test_delay_epochs": 5,  // Until attention is aligned, testing only wastes computation time.
    "data_path": "/media/erogol/data_ssd/Data/Mozilla/",  // DATASET-RELATED: can be overwritten from the command line.
    "meta_file_train": "metadata_train.txt",  // DATASET-RELATED: metafile for the training dataloader.
    "meta_file_val": "metadata_val.txt",      // DATASET-RELATED: metafile for the evaluation dataloader.
    "dataset": "mozilla",    // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for a dataset pre-computed by extract_features.py
    "min_seq_len": 0,        // DATASET-RELATED: minimum text length to use in training
    "max_seq_len": 150,      // DATASET-RELATED: maximum text length
    "output_path": "../keep/",   // DATASET-RELATED: output path for all training outputs.
    "num_loader_workers": 4,     // number of training data loader processes. Don't set it too big. 4-8 are good values.
    "num_val_loader_workers": 4, // number of evaluation data loader processes.
    "phoneme_cache_path": "mozilla_us_phonemes",  // phoneme computation is slow, therefore it caches results in the given folder.
    "use_phonemes": true,        // use phonemes instead of raw characters. It is suggested for better pronunciation.
    "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
    "text_cleaner": "phoneme_cleaners"
}
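Note: the config above is JSON annotated with // comments, so it is not valid input for a plain json.loads(). The repository ships its own config loader (not shown in this diff); the snippet below is only a minimal, hypothetical sketch of parsing such a file by stripping the comments first, and it assumes no literal "//" occurs inside string values.

import json
import re

def load_commented_json(path):
    # Hypothetical helper, not part of the repo: drop // comments, then parse.
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    text = re.sub(r"//.*$", "", text, flags=re.MULTILINE)
    return json.loads(text)

# Example usage (sketch):
# config = load_commented_json("config.json")
# config["audio"]["sample_rate"]  # -> 22050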
@@ -1,176 +0,0 @@
import torch
from torch import nn
from torch.nn import functional as F
from utils.generic_utils import sequence_mask


class BahdanauAttention(nn.Module):
    def __init__(self, annot_dim, query_dim, attn_dim):
        super(BahdanauAttention, self).__init__()
        self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
        self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, annots, query):
        """
        Shapes:
            - annots: (batch, max_time, dim)
            - query: (batch, 1, dim) or (batch, dim)
        """
        if query.dim() == 2:
            # insert time-axis for broadcasting
            query = query.unsqueeze(1)
        # (batch, 1, dim)
        processed_query = self.query_layer(query)
        processed_annots = self.annot_layer(annots)
        # (batch, max_time, 1)
        alignment = self.v(torch.tanh(processed_query + processed_annots))
        # (batch, max_time)
        return alignment.squeeze(-1)


class LocationSensitiveAttention(nn.Module):
    """Location sensitive attention following
    https://arxiv.org/pdf/1506.07503.pdf"""

    def __init__(self,
                 annot_dim,
                 query_dim,
                 attn_dim,
                 kernel_size=31,
                 filters=32):
        super(LocationSensitiveAttention, self).__init__()
        self.kernel_size = kernel_size
        self.filters = filters
        padding = [(kernel_size - 1) // 2, (kernel_size - 1) // 2]
        self.loc_conv = nn.Sequential(
            nn.ConstantPad1d(padding, 0),
            nn.Conv1d(
                2,
                filters,
                kernel_size=kernel_size,
                stride=1,
                padding=0,
                bias=False))
        self.loc_linear = nn.Linear(filters, attn_dim, bias=True)
        self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
        self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
        self.v = nn.Linear(attn_dim, 1, bias=False)
        self.processed_annots = None
        # self.init_layers()

    def init_layers(self):
        torch.nn.init.xavier_uniform_(
            self.loc_linear.weight,
            gain=torch.nn.init.calculate_gain('tanh'))
        torch.nn.init.xavier_uniform_(
            self.query_layer.weight,
            gain=torch.nn.init.calculate_gain('tanh'))
        torch.nn.init.xavier_uniform_(
            self.annot_layer.weight,
            gain=torch.nn.init.calculate_gain('tanh'))
        torch.nn.init.xavier_uniform_(
            self.v.weight,
            gain=torch.nn.init.calculate_gain('linear'))

    def reset(self):
        self.processed_annots = None

    def forward(self, annot, query, loc):
        """
        Shapes:
            - annot: (batch, max_time, dim)
            - query: (batch, 1, dim) or (batch, dim)
            - loc: (batch, 2, max_time)
        """
        if query.dim() == 2:
            # insert time-axis for broadcasting
            query = query.unsqueeze(1)
        processed_loc = self.loc_linear(self.loc_conv(loc).transpose(1, 2))
        processed_query = self.query_layer(query)
        # cache annots
        if self.processed_annots is None:
            self.processed_annots = self.annot_layer(annot)
        alignment = self.v(
            torch.tanh(processed_query + self.processed_annots + processed_loc))
        del processed_loc
        del processed_query
        # (batch, max_time)
        return alignment.squeeze(-1)


class AttentionRNNCell(nn.Module):
    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model,
                 windowing=False, norm="sigmoid"):
        r"""
        General Attention RNN wrapper

        Args:
            out_dim (int): context vector feature dimension.
            rnn_dim (int): rnn hidden state dimension.
            annot_dim (int): annotation vector feature dimension.
            memory_dim (int): memory vector (decoder output) feature dimension.
            align_model (str): 'b' for Bahdanau, 'ls' for Location Sensitive alignment.
            windowing (bool): attention windowing forcing monotonic attention.
                It is only active in eval mode.
            norm (str): norm method to compute alignment weights.
        """
        super(AttentionRNNCell, self).__init__()
        self.align_model = align_model
        self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
        self.windowing = windowing
        if self.windowing:
            self.win_back = 3
            self.win_front = 6
            self.win_idx = None
        self.norm = norm
        if align_model == 'b':
            self.alignment_model = BahdanauAttention(annot_dim, rnn_dim,
                                                     out_dim)
        elif align_model == 'ls':
            self.alignment_model = LocationSensitiveAttention(
                annot_dim, rnn_dim, out_dim)
        else:
            raise RuntimeError(
                "Wrong alignment model name: {}. Use 'b' (Bahdanau) or"
                " 'ls' (Location Sensitive).".format(align_model))

    def forward(self, memory, context, rnn_state, annots, atten, mask, t):
        """
        Shapes:
            - memory: (batch, 1, dim) or (batch, dim)
            - context: (batch, dim)
            - rnn_state: (batch, out_dim)
            - annots: (batch, max_time, annot_dim)
            - atten: (batch, 2, max_time)
            - mask: (batch,)
        """
        if t == 0:
            self.alignment_model.reset()
            self.win_idx = 0
        rnn_output = self.rnn_cell(torch.cat((memory, context), -1), rnn_state)
        if self.align_model == 'b':
            alignment = self.alignment_model(annots, rnn_output)
        else:
            alignment = self.alignment_model(annots, rnn_output, atten)
        if mask is not None:
            mask = mask.view(memory.size(0), -1)
            alignment.masked_fill_(1 - mask, -float("inf"))
        # Windowing
        if not self.training and self.windowing:
            back_win = self.win_idx - self.win_back
            front_win = self.win_idx + self.win_front
            if back_win > 0:
                alignment[:, :back_win] = -float("inf")
            if front_win < annots.shape[1]:
                alignment[:, front_win:] = -float("inf")
            # Update the window
            self.win_idx = torch.argmax(alignment, 1).long()[0].item()
        if self.norm == "softmax":
            alignment = torch.softmax(alignment, dim=-1)
        elif self.norm == "sigmoid":
            alignment = torch.sigmoid(alignment) / torch.sigmoid(
                alignment).sum(dim=1).unsqueeze(1)
        else:
            raise RuntimeError("Unknown value for attention norm type")
        context = torch.bmm(alignment.unsqueeze(1), annots)
        context = context.squeeze(1)
        return rnn_output, context, alignment
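Aside: the norm argument above (and the attention_norm key in the config) only changes how the raw alignment energies become weights that sum to one. A small, self-contained sketch of the two branches, with made-up energies:

import torch

energies = torch.tensor([[2.0, 0.5, -1.0]])  # (batch=1, max_time=3), toy values

# "softmax": standard normalized exponentials over the time axis
softmax_weights = torch.softmax(energies, dim=-1)

# "sigmoid": squash each energy independently, then rescale so the row sums to 1
sig = torch.sigmoid(energies)
sigmoid_weights = sig / sig.sum(dim=1, keepdim=True)

print(softmax_weights.sum(dim=1))  # tensor([1.])
print(sigmoid_weights.sum(dim=1))  # tensor([1.])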
@@ -80,4 +80,179 @@ class Prenet(nn.Module):
                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
            else:
                x = F.relu(linear(x))
        return x


class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        self.location_conv = nn.Conv1d(
            in_channels=2,
            out_channels=attention_n_filters,
            kernel_size=31,
            stride=1,
            padding=(31 - 1) // 2,
            bias=False)
        self.location_dense = Linear(
            attention_n_filters, attention_dim, bias=False, init_gain='tanh')

    def forward(self, attention_cat):
        processed_attention = self.location_conv(attention_cat)
        processed_attention = self.location_dense(
            processed_attention.transpose(1, 2))
        return processed_attention


class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 location_attention, attention_location_n_filters,
                 attention_location_kernel_size, windowing, norm, forward_attn,
                 trans_agent):
        super(Attention, self).__init__()
        self.query_layer = Linear(
            attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
        self.inputs_layer = Linear(
            embedding_dim, attention_dim, bias=False, init_gain='tanh')
        self.v = Linear(attention_dim, 1, bias=True)
        if trans_agent:
            self.ta = nn.Linear(
                attention_rnn_dim + embedding_dim, 1, bias=True)
        if location_attention:
            self.location_layer = LocationLayer(
                attention_location_n_filters, attention_location_kernel_size,
                attention_dim)
        self._mask_value = -float("inf")
        self.windowing = windowing
        self.win_idx = None
        self.norm = norm
        self.forward_attn = forward_attn
        self.trans_agent = trans_agent
        self.location_attention = location_attention

    def init_win_idx(self):
        self.win_idx = -1
        self.win_back = 2
        self.win_front = 6

    def init_forward_attn(self, inputs):
        B = inputs.shape[0]
        T = inputs.shape[1]
        self.alpha = torch.cat(
            [torch.ones([B, 1]),
             torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device)
        self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)

    def init_location_attention(self, inputs):
        B = inputs.shape[0]
        T = inputs.shape[1]
        self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())

    def init_states(self, inputs):
        B = inputs.shape[0]
        T = inputs.shape[1]
        self.attention_weights = Variable(inputs.data.new(B, T).zero_())
        if self.location_attention:
            self.init_location_attention(inputs)
        if self.forward_attn:
            self.init_forward_attn(inputs)
        if self.windowing:
            self.init_win_idx()

    def update_location_attention(self, alignments):
        self.attention_weights_cum += alignments

    def get_location_attention(self, query, processed_inputs):
        attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
                                   self.attention_weights_cum.unsqueeze(1)),
                                  dim=1)
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_cat)
        energies = self.v(
            torch.tanh(processed_query + processed_attention_weights +
                       processed_inputs))
        energies = energies.squeeze(-1)
        return energies, processed_query

    def get_attention(self, query, processed_inputs):
        processed_query = self.query_layer(query.unsqueeze(1))
        energies = self.v(torch.tanh(processed_query + processed_inputs))
        energies = energies.squeeze(-1)
        return energies, processed_query

    def apply_windowing(self, attention, inputs):
        back_win = self.win_idx - self.win_back
        front_win = self.win_idx + self.win_front
        if back_win > 0:
            attention[:, :back_win] = -float("inf")
        if front_win < inputs.shape[1]:
            attention[:, front_win:] = -float("inf")
        # this is a trick to solve a special problem.
        # but it does not hurt.
        if self.win_idx == -1:
            attention[:, 0] = attention.max()
        # Update the window
        self.win_idx = torch.argmax(attention, 1).long()[0].item()
        return attention

    def apply_forward_attention(self, inputs, alignment, query):
        # forward attention
        prev_alpha = F.pad(self.alpha[:, :-1].clone(),
                           (1, 0, 0, 0)).to(inputs.device)
        # compute transition potentials
        alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
                  self.u * prev_alpha) + 1e-8) * alignment
        # force incremental alignment - TODO: make configurable
        if not self.training and alignment.shape[0] == 1:
            _, n = prev_alpha.max(1)
            val, n2 = alpha.max(1)
            for b in range(alignment.shape[0]):
                alpha[b, n + 2:] = 0
                alpha[b, :(n - 1)] = 0  # ignore all previous states to prevent repetition.
                alpha[b, (n - 2)] = 0.01 * val  # smoothing factor for the prev step
        # compute attention weights
        self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
        # compute context
        context = torch.bmm(self.alpha.unsqueeze(1), inputs)
        context = context.squeeze(1)
        # compute transition agent
        if self.trans_agent:
            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
            self.u = torch.sigmoid(self.ta(ta_input))
        return context, self.alpha

    def forward(self, attention_hidden_state, inputs, processed_inputs, mask):
        if self.location_attention:
            attention, processed_query = self.get_location_attention(
                attention_hidden_state, processed_inputs)
        else:
            attention, processed_query = self.get_attention(
                attention_hidden_state, processed_inputs)
        # apply masking
        if mask is not None:
            attention.data.masked_fill_(1 - mask, self._mask_value)
        # apply windowing - only in eval mode
        if not self.training and self.windowing:
            attention = self.apply_windowing(attention, inputs)
        # normalize attention values
        if self.norm == "softmax":
            alignment = torch.softmax(attention, dim=-1)
        elif self.norm == "sigmoid":
            alignment = torch.sigmoid(attention) / torch.sigmoid(
                attention).sum(dim=1).unsqueeze(1)
        else:
            raise RuntimeError("Unknown value for attention norm type")
        if self.location_attention:
            self.update_location_attention(alignment)
        # apply forward attention if enabled
        if self.forward_attn:
            context, self.attention_weights = self.apply_forward_attention(
                inputs, alignment, attention_hidden_state)
        else:
            context = torch.bmm(alignment.unsqueeze(1), inputs)
            context = context.squeeze(1)
            self.attention_weights = alignment
        return context
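Aside: apply_forward_attention above implements the forward-attention recursion with a transition agent: the new weights are proportional to ((1 - u) * alpha_prev + u * shift(alpha_prev)) * e_t, renormalized each step. A toy, self-contained illustration with made-up values (not part of the library):

import torch

alignment = torch.tensor([[0.1, 0.7, 0.2]])    # e_t: current attention weights
alpha_prev = torch.tensor([[0.8, 0.2, 1e-7]])  # alpha_{t-1}
u = torch.tensor([[0.5]])                      # transition agent output in [0, 1]

# shift alpha_{t-1} one step forward (the "move to the next input" term)
shifted = torch.nn.functional.pad(alpha_prev[:, :-1], (1, 0, 0, 0))

alpha = (((1 - u) * alpha_prev + u * shifted) + 1e-8) * alignment
alpha = alpha / alpha.sum(dim=1, keepdim=True)

print(alpha)  # the mass moves one step ahead, mirroring the update in the code above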
@@ -1,40 +1,7 @@
# coding: utf-8
import torch
from torch import nn
from .attention import AttentionRNNCell
from .common_layers import Prenet


# class Prenet(nn.Module):
#     r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
#     It creates as many layers as given by 'out_features'
#
#     Args:
#         in_features (int): size of the input vector
#         out_features (int or list): size of each output sample.
#             If it is a list, for each value, there is created a new layer.
#     """
#
#     def __init__(self, in_features, out_features=[256, 128]):
#         super(Prenet, self).__init__()
#         in_features = [in_features] + out_features[:-1]
#         self.layers = nn.ModuleList([
#             nn.Linear(in_size, out_size)
#             for (in_size, out_size) in zip(in_features, out_features)
#         ])
#         self.relu = nn.ReLU()
#         self.dropout = nn.Dropout(0.5)
#         # self.init_layers()
#
#     def init_layers(self):
#         for layer in self.layers:
#             torch.nn.init.xavier_uniform_(
#                 layer.weight, gain=torch.nn.init.calculate_gain('relu'))
#
#     def forward(self, inputs):
#         for linear in self.layers:
#             inputs = self.dropout(self.relu(linear(inputs)))
#         return inputs
from .common_layers import Prenet, Attention


class BatchNormConv1d(nn.Module):

@@ -319,14 +286,17 @@ class Decoder(nn.Module):
            prenet_dropout,
            out_features=[256, 128])
        # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
        self.attention_rnn = AttentionRNNCell(
            out_dim=128,
            rnn_dim=256,
            annot_dim=in_features,
            memory_dim=128,
            align_model='ls',
            windowing=attn_windowing,
            norm=attn_norm)
        self.attention_rnn = nn.GRUCell(in_features + 128, 256)
        self.attention_layer = Attention(attention_rnn_dim=256,
                                         embedding_dim=in_features,
                                         attention_dim=128,
                                         location_attention=location_attn,
                                         attention_location_n_filters=32,
                                         attention_location_kernel_size=31,
                                         windowing=attn_windowing,
                                         norm=attn_norm,
                                         forward_attn=forward_attn,
                                         trans_agent=trans_agent)
        # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
        self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
        # decoder_RNN_input -> |RNN| -> RNN_state

@@ -382,6 +352,8 @@ class Decoder(nn.Module):
        # attention states
        self.attention = inputs.data.new(B, T).zero_()
        self.attention_cum = inputs.data.new(B, T).zero_()
        # cache attention inputs
        self.processed_inputs = self.attention_layer.inputs_layer(inputs)

    def _parse_outputs(self, outputs, attentions, stop_tokens):
        # Back to batch first

@@ -390,18 +362,12 @@ class Decoder(nn.Module):
        stop_tokens = torch.stack(stop_tokens).transpose(0, 1).squeeze(-1)
        return outputs, attentions, stop_tokens

    def decode(self, inputs, t, mask=None):
    def decode(self, inputs, mask=None):
        # Prenet
        processed_memory = self.prenet(self.memory_input)
        # Attention RNN
        attention_cat = torch.cat(
            (self.attention.unsqueeze(1), self.attention_cum.unsqueeze(1)),
            dim=1)
        self.attention_rnn_hidden, self.current_context_vec, self.attention = self.attention_rnn(
            processed_memory, self.current_context_vec,
            self.attention_rnn_hidden, inputs, attention_cat, mask, t)
        del attention_cat
        self.attention_cum += self.attention
        self.attention_rnn_hidden = self.attention_rnn(torch.cat((processed_memory, self.current_context_vec), -1), self.attention_rnn_hidden)
        self.current_context_vec = self.attention_layer(self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
        # Concat RNN output and attention context vector
        decoder_input = self.project_to_decoder_in(
            torch.cat((self.attention_rnn_hidden, self.current_context_vec),

@@ -424,7 +390,7 @@ class Decoder(nn.Module):
            stop_token = self.stopnet(stopnet_input.detach())
        else:
            stop_token = self.stopnet(stopnet_input)
        return output, stop_token, self.attention
        return output, stop_token, self.attention_layer.attention_weights

    def _update_memory_queue(self, new_memory):
        if self.memory_size > 0:

@@ -456,11 +422,12 @@ class Decoder(nn.Module):
        stop_tokens = []
        t = 0
        self._init_states(inputs)
        self.attention_layer.init_states(inputs)
        while len(outputs) < memory.size(0):
            if t > 0:
                new_memory = memory[t - 1]
                self._update_memory_queue(new_memory)
            output, stop_token, attention = self.decode(inputs, t, mask)
            output, stop_token, attention = self.decode(inputs, mask)
            outputs += [output]
            attentions += [attention]
            stop_tokens += [stop_token]

@@ -481,6 +448,8 @@ class Decoder(nn.Module):
        stop_tokens = []
        t = 0
        self._init_states(inputs)
        self.attention_layer.init_win_idx()
        self.attention_layer.init_states(inputs)
        while True:
            if t > 0:
                new_memory = outputs[-1]

@@ -3,83 +3,7 @@ import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F


class Linear(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 init_gain='linear'):
        super(Linear, self).__init__()
        self.linear_layer = torch.nn.Linear(
            in_features, out_features, bias=bias)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class LinearBN(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 init_gain='linear'):
        super(LinearBN, self).__init__()
        self.linear_layer = torch.nn.Linear(
            in_features, out_features, bias=bias)
        self.bn = nn.BatchNorm1d(out_features)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        out = self.linear_layer(x)
        if len(out.shape) == 3:
            out = out.permute(1, 2, 0)
        out = self.bn(out)
        if len(out.shape) == 3:
            out = out.permute(2, 0, 1)
        return out


class Prenet(nn.Module):
    def __init__(self,
                 in_features,
                 prenet_type,
                 prenet_dropout,
                 out_features=[256, 256]):
        super(Prenet, self).__init__()
        self.prenet_type = prenet_type
        self.prenet_dropout = prenet_dropout
        in_features = [in_features] + out_features[:-1]
        if prenet_type == "bn":
            self.layers = nn.ModuleList([
                LinearBN(in_size, out_size, bias=False)
                for (in_size, out_size) in zip(in_features, out_features)
            ])
        elif prenet_type == "original":
            self.layers = nn.ModuleList([
                Linear(in_size, out_size, bias=False)
                for (in_size, out_size) in zip(in_features, out_features)
            ])

    def forward(self, x):
        for linear in self.layers:
            if self.prenet_dropout:
                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
            else:
                x = F.relu(linear(x))
        return x
from .common_layers import Attention, Prenet, Linear, LinearBN


class ConvBNBlock(nn.Module):

@@ -103,178 +27,6 @@ class ConvBNBlock(nn.Module):
        return output


class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        self.location_conv = nn.Conv1d(
            in_channels=2,
            out_channels=attention_n_filters,
            kernel_size=31,
            stride=1,
            padding=(31 - 1) // 2,
            bias=False)
        self.location_dense = Linear(
            attention_n_filters, attention_dim, bias=False, init_gain='tanh')

    def forward(self, attention_cat):
        processed_attention = self.location_conv(attention_cat)
        processed_attention = self.location_dense(
            processed_attention.transpose(1, 2))
        return processed_attention


class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 location_attention, attention_location_n_filters,
                 attention_location_kernel_size, windowing, norm, forward_attn,
                 trans_agent):
        super(Attention, self).__init__()
        self.query_layer = Linear(
            attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
        self.inputs_layer = Linear(
            embedding_dim, attention_dim, bias=False, init_gain='tanh')
        self.v = Linear(attention_dim, 1, bias=True)
        if trans_agent:
            self.ta = nn.Linear(
                attention_rnn_dim + embedding_dim, 1, bias=True)
        if location_attention:
            self.location_layer = LocationLayer(
                attention_location_n_filters, attention_location_kernel_size,
                attention_dim)
        self._mask_value = -float("inf")
        self.windowing = windowing
        self.win_idx = None
        self.norm = norm
        self.forward_attn = forward_attn
        self.trans_agent = trans_agent
        self.location_attention = location_attention

    def init_win_idx(self):
        self.win_idx = -1
        self.win_back = 2
        self.win_front = 6

    def init_forward_attn(self, inputs):
        B = inputs.shape[0]
        T = inputs.shape[1]
        self.alpha = torch.cat(
            [torch.ones([B, 1]),
             torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device)
        self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)

    def init_location_attention(self, inputs):
        B = inputs.shape[0]
        T = inputs.shape[1]
        self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())

    def init_states(self, inputs):
        B = inputs.shape[0]
        T = inputs.shape[1]
        self.attention_weights = Variable(inputs.data.new(B, T).zero_())
        if self.location_attention:
            self.init_location_attention(inputs)
        if self.forward_attn:
            self.init_forward_attn(inputs)
        if self.windowing:
            self.init_win_idx()

    def update_location_attention(self, alignments):
        self.attention_weights_cum += alignments

    def get_location_attention(self, query, processed_inputs):
        attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
                                   self.attention_weights_cum.unsqueeze(1)),
                                  dim=1)
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_cat)
        energies = self.v(
            torch.tanh(processed_query + processed_attention_weights +
                       processed_inputs))
        energies = energies.squeeze(-1)
        return energies, processed_query

    def get_attention(self, query, processed_inputs):
        processed_query = self.query_layer(query.unsqueeze(1))
        energies = self.v(torch.tanh(processed_query + processed_inputs))
        energies = energies.squeeze(-1)
        return energies, processed_query

    def apply_windowing(self, attention, inputs):
        back_win = self.win_idx - self.win_back
        front_win = self.win_idx + self.win_front
        if back_win > 0:
            attention[:, :back_win] = -float("inf")
        if front_win < inputs.shape[1]:
            attention[:, front_win:] = -float("inf")
        # this is a trick to solve a special problem.
        # but it does not hurt.
        if self.win_idx == -1:
            attention[:, 0] = attention.max()
        # Update the window
        self.win_idx = torch.argmax(attention, 1).long()[0].item()
        return attention

    def apply_forward_attention(self, inputs, alignment, query):
        # forward attention
        prev_alpha = F.pad(self.alpha[:, :-1].clone(),
                           (1, 0, 0, 0)).to(inputs.device)
        # compute transition potentials
        alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
                  self.u * prev_alpha) + 1e-8) * alignment
        # force incremental alignment - TODO: make configurable
        if not self.training and alignment.shape[0] == 1:
            _, n = prev_alpha.max(1)
            val, n2 = alpha.max(1)
            for b in range(alignment.shape[0]):
                alpha[b, n + 2:] = 0
                alpha[b, :(n - 1)] = 0  # ignore all previous states to prevent repetition.
                alpha[b, (n - 2)] = 0.01 * val  # smoothing factor for the prev step
        # compute attention weights
        self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
        # compute context
        context = torch.bmm(self.alpha.unsqueeze(1), inputs)
        context = context.squeeze(1)
        # compute transition agent
        if self.trans_agent:
            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
            self.u = torch.sigmoid(self.ta(ta_input))
        return context, self.alpha

    def forward(self, attention_hidden_state, inputs, processed_inputs, mask):
        if self.location_attention:
            attention, processed_query = self.get_location_attention(
                attention_hidden_state, processed_inputs)
        else:
            attention, processed_query = self.get_attention(
                attention_hidden_state, processed_inputs)
        # apply masking
        if mask is not None:
            attention.data.masked_fill_(1 - mask, self._mask_value)
        # apply windowing - only in eval mode
        if not self.training and self.windowing:
            attention = self.apply_windowing(attention, inputs)
        # normalize attention values
        if self.norm == "softmax":
            alignment = torch.softmax(attention, dim=-1)
        elif self.norm == "sigmoid":
            alignment = torch.sigmoid(attention) / torch.sigmoid(
                attention).sum(dim=1).unsqueeze(1)
        else:
            raise RuntimeError("Unknown value for attention norm type")
        if self.location_attention:
            self.update_location_attention(alignment)
        # apply forward attention if enabled
        if self.forward_attn:
            context, self.attention_weights = self.apply_forward_attention(
                inputs, alignment, attention_hidden_state)
        else:
            context = torch.bmm(alignment.unsqueeze(1), inputs)
            context = context.squeeze(1)
            self.attention_weights = alignment
        return context


class Postnet(nn.Module):
    def __init__(self, mel_dim, num_convs=5):
        super(Postnet, self).__init__()

@@ -494,7 +246,7 @@ class Decoder(nn.Module):
        self.attention_layer.init_states(inputs)

        outputs, stop_tokens, alignments, t = [], [], [], 0
        stop_flags = [False, False, False]
        stop_flags = [True, False, False]
        stop_count = 0
        while True:
            memory = self.prenet(memory)

@@ -510,7 +262,7 @@ class Decoder(nn.Module):
            stop_flags[2] = t > inputs.shape[1] * 2
            if all(stop_flags):
                stop_count += 1
                if stop_count > 10:
                if stop_count > 20:
                    break
            elif len(outputs) == self.max_decoder_steps:
                print(" | > Decoder stopped with 'max_decoder_steps")

@@ -537,7 +289,7 @@ class Decoder(nn.Module):
        self.attention_layer.init_win_idx()
        self.attention_layer.init_states(inputs)
        outputs, stop_tokens, alignments, t = [], [], [], 0
        stop_flags = [False, False, False]
        stop_flags = [True, False, False]
        stop_count = 0
        while True:
            memory = self.prenet(self.memory_truncated)

@@ -548,12 +300,12 @@ class Decoder(nn.Module):
            alignments += [alignment]

            stop_flags[0] = stop_flags[0] or stop_token > 0.5
            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5
            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8
                                              and t > inputs.shape[1])
            stop_flags[2] = t > inputs.shape[1] * 2
            if all(stop_flags):
                stop_count += 1
                if stop_count > 2:
                if stop_count > 20:
                    break
            elif len(outputs) == self.max_decoder_steps:
                print(" | > Decoder stopped with 'max_decoder_steps")
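Aside: the last hunks change the inference stopping rule. The first stop flag now starts as True, the alignment-at-the-end threshold moves from 0.5 to 0.8, and decoding stops only after all flags have held for more than 20 consecutive steps (previously 2 or 10). A simplified, self-contained sketch of that rule follows; decoder_step() and num_encoder_steps are stand-in stubs, not the library's API.

import torch

def decoder_step(t, num_encoder_steps):
    # Stub: fake stop_token and an alignment that walks toward the input's end.
    stop_token = 1.0 if t > num_encoder_steps else 0.0
    alignment = torch.zeros(1, num_encoder_steps)
    alignment[0, min(t, num_encoder_steps - 1)] = 1.0
    return stop_token, alignment

num_encoder_steps = 30
stop_flags = [True, False, False]   # first flag starts True in this commit
stop_count, t = 0, 0
while True:
    stop_token, alignment = decoder_step(t, num_encoder_steps)
    stop_flags[0] = stop_flags[0] or stop_token > 0.5
    stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8
                                      and t > num_encoder_steps)
    stop_flags[2] = t > num_encoder_steps * 2
    if all(stop_flags):
        stop_count += 1
        if stop_count > 20:          # raised from 2/10 to 20 in this commit
            break
    t += 1
print("stopped at step", t)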