diff --git a/config_tacotron.json b/config_tacotron.json
new file mode 100644
index 00000000..968eae1e
--- /dev/null
+++ b/config_tacotron.json
@@ -0,0 +1,79 @@
+{
+ "run_name": "mozilla-tacotron-tagent",
+ "run_description": "using forward attention with transition agent, with original prenet, loss masking, separate stopnet, sigmoid norm. Compare this with 4841",
+
+ "audio":{
+ // Audio processing parameters
+ "num_mels": 80, // size of the mel spec frame.
+ "num_freq": 1025, // number of stft frequency levels. Size of the linear spectrogram frame.
+ "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different from the original data, it is resampled.
+ "frame_length_ms": 50, // stft window length in ms.
+ "frame_shift_ms": 12.5, // stft window hop-length in ms.
+ "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
+ "min_level_db": -100, // normalization range
+ "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
+ "power": 1.5, // value to sharpen wav signals after GL algorithm.
+ "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. The larger the value, the slower the generation.
+ // Normalization parameters
+ "signal_norm": true, // normalize the spec values in range [0, 1]
+ "symmetric_norm": false, // move normalization to range [-1, 1]
+ "max_norm": 1, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
+ "clip_norm": true, // clip normalized values into the range.
+ "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
+ "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
+ "do_trim_silence": true // enable trimming of silence in audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
+ },
+
+ "distributed":{
+ "backend": "nccl",
+ "url": "tcp:\/\/localhost:54321"
+ },
+
+ "reinit_layers": [],
+
+ "model": "Tacotron", // one of the models in models/
+ "grad_clip": 1, // upper limit for gradients for clipping.
+ "epochs": 1000, // total number of epochs to train.
+ "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
+ "lr_decay": false, // if true, Noam learning rate decaying is applied through training.
+ "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
+ "windowing": false, // Enables attention windowing. Used only in eval mode.
+ "memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
+ "attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
+ "prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
+ "prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet.
+ "use_forward_attn": true, // ONLY TACOTRON2 - whether to use forward attention. In general, it aligns faster.
+ "transition_agent": true, // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
+ "location_attn": false, // ONLY TACOTRON2 - enable/disable location sensitive attention. It is enabled for TACOTRON by default.
+ "loss_masking": true, // enable/disable loss masking against the sequence padding.
+ "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
+ "stopnet": true, // Train stopnet predicting the end of synthesis.
+ "separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER. + "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. + + "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. + "eval_batch_size":16, + "r": 5, // Number of frames to predict for step. + "wd": 0.000001, // Weight decay weight. + "checkpoint": true, // If true, it saves checkpoints per "save_step" + "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints. + "print_step": 10, // Number of steps to log traning on console. + "batch_group_size": 0, //Number of batches to shuffle after bucketing. + + "run_eval": true, + "test_delay_epochs": 5, //Until attention is aligned, testing only wastes computation time. + "data_path": "/media/erogol/data_ssd/Data/Mozilla/", // DATASET-RELATED: can overwritten from command argument + "meta_file_train": "metadata_train.txt", // DATASET-RELATED: metafile for training dataloader. + "meta_file_val": "metadata_val.txt", // DATASET-RELATED: metafile for evaluation dataloader. + "dataset": "mozilla", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py + "min_seq_len": 0, // DATASET-RELATED: minimum text length to use in training + "max_seq_len": 150, // DATASET-RELATED: maximum text length + "output_path": "../keep/", // DATASET-RELATED: output path for all training outputs. + "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values. + "num_val_loader_workers": 4, // number of evaluation data loader processes. + "phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder. + "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. 
+ "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages + "text_cleaner": "phoneme_cleaners" + } + \ No newline at end of file diff --git a/layers/attention.py b/layers/attention.py deleted file mode 100644 index 08765e70..00000000 --- a/layers/attention.py +++ /dev/null @@ -1,176 +0,0 @@ -import torch -from torch import nn -from torch.nn import functional as F -from utils.generic_utils import sequence_mask - - -class BahdanauAttention(nn.Module): - def __init__(self, annot_dim, query_dim, attn_dim): - super(BahdanauAttention, self).__init__() - self.query_layer = nn.Linear(query_dim, attn_dim, bias=True) - self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True) - self.v = nn.Linear(attn_dim, 1, bias=False) - - def forward(self, annots, query): - """ - Shapes: - - annots: (batch, max_time, dim) - - query: (batch, 1, dim) or (batch, dim) - """ - if query.dim() == 2: - # insert time-axis for broadcasting - query = query.unsqueeze(1) - # (batch, 1, dim) - processed_query = self.query_layer(query) - processed_annots = self.annot_layer(annots) - # (batch, max_time, 1) - alignment = self.v(torch.tanh(processed_query + processed_annots)) - # (batch, max_time) - return alignment.squeeze(-1) - - -class LocationSensitiveAttention(nn.Module): - """Location sensitive attention following - https://arxiv.org/pdf/1506.07503.pdf""" - - def __init__(self, - annot_dim, - query_dim, - attn_dim, - kernel_size=31, - filters=32): - super(LocationSensitiveAttention, self).__init__() - self.kernel_size = kernel_size - self.filters = filters - padding = [(kernel_size - 1) // 2, (kernel_size - 1) // 2] - self.loc_conv = nn.Sequential( - nn.ConstantPad1d(padding, 0), - nn.Conv1d( - 2, - filters, - kernel_size=kernel_size, - stride=1, - padding=0, - bias=False)) - self.loc_linear = nn.Linear(filters, attn_dim, bias=True) - self.query_layer = nn.Linear(query_dim, attn_dim, bias=True) - self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True) - self.v = nn.Linear(attn_dim, 1, bias=False) - self.processed_annots = None - # self.init_layers() - - def init_layers(self): - torch.nn.init.xavier_uniform_( - self.loc_linear.weight, - gain=torch.nn.init.calculate_gain('tanh')) - torch.nn.init.xavier_uniform_( - self.query_layer.weight, - gain=torch.nn.init.calculate_gain('tanh')) - torch.nn.init.xavier_uniform_( - self.annot_layer.weight, - gain=torch.nn.init.calculate_gain('tanh')) - torch.nn.init.xavier_uniform_( - self.v.weight, - gain=torch.nn.init.calculate_gain('linear')) - - def reset(self): - self.processed_annots = None - - def forward(self, annot, query, loc): - """ - Shapes: - - annot: (batch, max_time, dim) - - query: (batch, 1, dim) or (batch, dim) - - loc: (batch, 2, max_time) - """ - if query.dim() == 2: - # insert time-axis for broadcasting - query = query.unsqueeze(1) - processed_loc = self.loc_linear(self.loc_conv(loc).transpose(1, 2)) - processed_query = self.query_layer(query) - # cache annots - if self.processed_annots is None: - self.processed_annots = self.annot_layer(annot) - alignment = self.v( - torch.tanh(processed_query + self.processed_annots + processed_loc)) - del processed_loc - del processed_query - # (batch, max_time) - return alignment.squeeze(-1) - - -class AttentionRNNCell(nn.Module): - def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model, windowing=False, norm="sigmoid"): - r""" - General Attention RNN wrapper - - Args: - out_dim (int): context vector feature dimension. 
- rnn_dim (int): rnn hidden state dimension. - annot_dim (int): annotation vector feature dimension. - memory_dim (int): memory vector (decoder output) feature dimension. - align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment. - windowing (bool): attention windowing forcing monotonic attention. - It is only active in eval mode. - norm (str): norm method to compute alignment weights. - """ - super(AttentionRNNCell, self).__init__() - self.align_model = align_model - self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim) - self.windowing = windowing - if self.windowing: - self.win_back = 3 - self.win_front = 6 - self.win_idx = None - self.norm = norm - if align_model == 'b': - self.alignment_model = BahdanauAttention(annot_dim, rnn_dim, - out_dim) - if align_model == 'ls': - self.alignment_model = LocationSensitiveAttention( - annot_dim, rnn_dim, out_dim) - else: - raise RuntimeError(" Wrong alignment model name: {}. Use\ - 'b' (Bahdanau) or 'ls' (Location Sensitive).".format( - align_model)) - - def forward(self, memory, context, rnn_state, annots, atten, mask, t): - """ - Shapes: - - memory: (batch, 1, dim) or (batch, dim) - - context: (batch, dim) - - rnn_state: (batch, out_dim) - - annots: (batch, max_time, annot_dim) - - atten: (batch, 2, max_time) - - mask: (batch,) - """ - if t == 0: - self.alignment_model.reset() - self.win_idx = 0 - rnn_output = self.rnn_cell(torch.cat((memory, context), -1), rnn_state) - if self.align_model is 'b': - alignment = self.alignment_model(annots, rnn_output) - else: - alignment = self.alignment_model(annots, rnn_output, atten) - if mask is not None: - mask = mask.view(memory.size(0), -1) - alignment.masked_fill_(1 - mask, -float("inf")) - # Windowing - if not self.training and self.windowing: - back_win = self.win_idx - self.win_back - front_win = self.win_idx + self.win_front - if back_win > 0: - alignment[:, :back_win] = -float("inf") - if front_win < memory.shape[1]: - alignment[:, front_win:] = -float("inf") - # Update the window - self.win_idx = torch.argmax(alignment,1).long()[0].item() - if self.norm == "softmax": - alignment = torch.softmax(alignment, dim=-1) - elif self.norm == "sigmoid": - alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1) - else: - raise RuntimeError("Unknown value for attention norm type") - context = torch.bmm(alignment.unsqueeze(1), annots) - context = context.squeeze(1) - return rnn_output, context, alignment diff --git a/layers/common_layers.py b/layers/common_layers.py index c5704f62..c15c0b10 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -80,4 +80,179 @@ class Prenet(nn.Module): x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training) else: x = F.relu(linear(x)) - return x \ No newline at end of file + return x + + +class LocationLayer(nn.Module): + def __init__(self, attention_n_filters, attention_kernel_size, + attention_dim): + super(LocationLayer, self).__init__() + self.location_conv = nn.Conv1d( + in_channels=2, + out_channels=attention_n_filters, + kernel_size=31, + stride=1, + padding=(31 - 1) // 2, + bias=False) + self.location_dense = Linear( + attention_n_filters, attention_dim, bias=False, init_gain='tanh') + + def forward(self, attention_cat): + processed_attention = self.location_conv(attention_cat) + processed_attention = self.location_dense( + processed_attention.transpose(1, 2)) + return processed_attention + + +class Attention(nn.Module): + def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, + 
location_attention, attention_location_n_filters, + attention_location_kernel_size, windowing, norm, forward_attn, + trans_agent): + super(Attention, self).__init__() + self.query_layer = Linear( + attention_rnn_dim, attention_dim, bias=False, init_gain='tanh') + self.inputs_layer = Linear( + embedding_dim, attention_dim, bias=False, init_gain='tanh') + self.v = Linear(attention_dim, 1, bias=True) + if trans_agent: + self.ta = nn.Linear( + attention_rnn_dim + embedding_dim, 1, bias=True) + if location_attention: + self.location_layer = LocationLayer( + attention_location_n_filters, attention_location_kernel_size, + attention_dim) + self._mask_value = -float("inf") + self.windowing = windowing + self.win_idx = None + self.norm = norm + self.forward_attn = forward_attn + self.trans_agent = trans_agent + self.location_attention = location_attention + + def init_win_idx(self): + self.win_idx = -1 + self.win_back = 2 + self.win_front = 6 + + def init_forward_attn(self, inputs): + B = inputs.shape[0] + T = inputs.shape[1] + self.alpha = torch.cat( + [torch.ones([B, 1]), + torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device) + self.u = (0.5 * torch.ones([B, 1])).to(inputs.device) + + def init_location_attention(self, inputs): + B = inputs.shape[0] + T = inputs.shape[1] + self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_()) + + def init_states(self, inputs): + B = inputs.shape[0] + T = inputs.shape[1] + self.attention_weights = Variable(inputs.data.new(B, T).zero_()) + if self.location_attention: + self.init_location_attention(inputs) + if self.forward_attn: + self.init_forward_attn(inputs) + if self.windowing: + self.init_win_idx() + + def update_location_attention(self, alignments): + self.attention_weights_cum += alignments + + def get_location_attention(self, query, processed_inputs): + attention_cat = torch.cat((self.attention_weights.unsqueeze(1), + self.attention_weights_cum.unsqueeze(1)), + dim=1) + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_cat) + energies = self.v( + torch.tanh(processed_query + processed_attention_weights + + processed_inputs)) + energies = energies.squeeze(-1) + return energies, processed_query + + def get_attention(self, query, processed_inputs): + processed_query = self.query_layer(query.unsqueeze(1)) + energies = self.v(torch.tanh(processed_query + processed_inputs)) + energies = energies.squeeze(-1) + return energies, processed_query + + def apply_windowing(self, attention, inputs): + back_win = self.win_idx - self.win_back + front_win = self.win_idx + self.win_front + if back_win > 0: + attention[:, :back_win] = -float("inf") + if front_win < inputs.shape[1]: + attention[:, front_win:] = -float("inf") + # this is a trick to solve a special problem. + # but it does not hurt. 
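+        # (the "special problem": at the first decoding step win_idx is still -1, so the line below sets the first input position's energy to the current maximum, making the attention window start at the beginning of the sequence)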
+ if self.win_idx == -1: + attention[:, 0] = attention.max() + # Update the window + self.win_idx = torch.argmax(attention, 1).long()[0].item() + return attention + + def apply_forward_attention(self, inputs, alignment, query): + # forward attention + prev_alpha = F.pad(self.alpha[:, :-1].clone(), + (1, 0, 0, 0)).to(inputs.device) + # compute transition potentials + alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) + + self.u * prev_alpha) + 1e-8) * alignment + # force incremental alignment - TODO: make configurable + if not self.training and alignment.shape[0] == 1: + _, n = prev_alpha.max(1) + val, n2 = alpha.max(1) + for b in range(alignment.shape[0]): + alpha[b, n + 2:] = 0 + alpha[b, :( + n - 1 + )] = 0 # ignore all previous states to prevent repetition. + alpha[b, ( + n - 2)] = 0.01 * val # smoothing factor for the prev step + # compute attention weights + self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1) + # compute context + context = torch.bmm(self.alpha.unsqueeze(1), inputs) + context = context.squeeze(1) + # compute transition agent + if self.trans_agent: + ta_input = torch.cat([context, query.squeeze(1)], dim=-1) + self.u = torch.sigmoid(self.ta(ta_input)) + return context, self.alpha + + def forward(self, attention_hidden_state, inputs, processed_inputs, mask): + if self.location_attention: + attention, processed_query = self.get_location_attention( + attention_hidden_state, processed_inputs) + else: + attention, processed_query = self.get_attention( + attention_hidden_state, processed_inputs) + # apply masking + if mask is not None: + attention.data.masked_fill_(1 - mask, self._mask_value) + # apply windowing - only in eval mode + if not self.training and self.windowing: + attention = self.apply_windowing(attention, inputs) + # normalize attention values + if self.norm == "softmax": + alignment = torch.softmax(attention, dim=-1) + elif self.norm == "sigmoid": + alignment = torch.sigmoid(attention) / torch.sigmoid( + attention).sum(dim=1).unsqueeze(1) + else: + raise RuntimeError("Unknown value for attention norm type") + if self.location_attention: + self.update_location_attention(alignment) + # apply forward attention if enabled + if self.forward_attn: + context, self.attention_weights = self.apply_forward_attention( + inputs, alignment, attention_hidden_state) + else: + context = torch.bmm(alignment.unsqueeze(1), inputs) + context = context.squeeze(1) + self.attention_weights = alignment + return context \ No newline at end of file diff --git a/layers/tacotron.py b/layers/tacotron.py index 690407b7..125d56c7 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -1,40 +1,7 @@ # coding: utf-8 import torch from torch import nn -from .attention import AttentionRNNCell -from .common_layers import Prenet - - -# class Prenet(nn.Module): -# r""" Prenet as explained at https://arxiv.org/abs/1703.10135. -# It creates as many layers as given by 'out_features' - -# Args: -# in_features (int): size of the input vector -# out_features (int or list): size of each output sample. -# If it is a list, for each value, there is created a new layer. 
-# """ - -# def __init__(self, in_features, out_features=[256, 128]): -# super(Prenet, self).__init__() -# in_features = [in_features] + out_features[:-1] -# self.layers = nn.ModuleList([ -# nn.Linear(in_size, out_size) -# for (in_size, out_size) in zip(in_features, out_features) -# ]) -# self.relu = nn.ReLU() -# self.dropout = nn.Dropout(0.5) -# # self.init_layers() - -# def init_layers(self): -# for layer in self.layers: -# torch.nn.init.xavier_uniform_( -# layer.weight, gain=torch.nn.init.calculate_gain('relu')) - -# def forward(self, inputs): -# for linear in self.layers: -# inputs = self.dropout(self.relu(linear(inputs))) -# return inputs +from .common_layers import Prenet, Attention class BatchNormConv1d(nn.Module): @@ -319,14 +286,17 @@ class Decoder(nn.Module): prenet_dropout, out_features=[256, 128]) # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State - self.attention_rnn = AttentionRNNCell( - out_dim=128, - rnn_dim=256, - annot_dim=in_features, - memory_dim=128, - align_model='ls', - windowing=attn_windowing, - norm=attn_norm) + self.attention_rnn = nn.GRUCell(in_features + 128, 256) + self.attention_layer = Attention(attention_rnn_dim=256, + embedding_dim=in_features, + attention_dim=128, + location_attention=location_attn, + attention_location_n_filters=32, + attention_location_kernel_size=31, + windowing=attn_windowing, + norm=attn_norm, + forward_attn=forward_attn, + trans_agent=trans_agent) # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input self.project_to_decoder_in = nn.Linear(256 + in_features, 256) # decoder_RNN_input -> |RNN| -> RNN_state @@ -382,6 +352,8 @@ class Decoder(nn.Module): # attention states self.attention = inputs.data.new(B, T).zero_() self.attention_cum = inputs.data.new(B, T).zero_() + # cache attention inputs + self.processed_inputs = self.attention_layer.inputs_layer(inputs) def _parse_outputs(self, outputs, attentions, stop_tokens): # Back to batch first @@ -390,18 +362,12 @@ class Decoder(nn.Module): stop_tokens = torch.stack(stop_tokens).transpose(0, 1).squeeze(-1) return outputs, attentions, stop_tokens - def decode(self, inputs, t, mask=None): + def decode(self, inputs, mask=None): # Prenet processed_memory = self.prenet(self.memory_input) # Attention RNN - attention_cat = torch.cat( - (self.attention.unsqueeze(1), self.attention_cum.unsqueeze(1)), - dim=1) - self.attention_rnn_hidden, self.current_context_vec, self.attention = self.attention_rnn( - processed_memory, self.current_context_vec, - self.attention_rnn_hidden, inputs, attention_cat, mask, t) - del attention_cat - self.attention_cum += self.attention + self.attention_rnn_hidden = self.attention_rnn(torch.cat((processed_memory, self.current_context_vec), -1), self.attention_rnn_hidden) + self.current_context_vec = self.attention_layer(self.attention_rnn_hidden, inputs, self.processed_inputs, mask) # Concat RNN output and attention context vector decoder_input = self.project_to_decoder_in( torch.cat((self.attention_rnn_hidden, self.current_context_vec), @@ -424,7 +390,7 @@ class Decoder(nn.Module): stop_token = self.stopnet(stopnet_input.detach()) else: stop_token = self.stopnet(stopnet_input) - return output, stop_token, self.attention + return output, stop_token, self.attention_layer.attention_weights def _update_memory_queue(self, new_memory): if self.memory_size > 0: @@ -456,11 +422,12 @@ class Decoder(nn.Module): stop_tokens = [] t = 0 self._init_states(inputs) + self.attention_layer.init_states(inputs) while len(outputs) < 
memory.size(0): if t > 0: new_memory = memory[t - 1] self._update_memory_queue(new_memory) - output, stop_token, attention = self.decode(inputs, t, mask) + output, stop_token, attention = self.decode(inputs, mask) outputs += [output] attentions += [attention] stop_tokens += [stop_token] @@ -481,6 +448,8 @@ class Decoder(nn.Module): stop_tokens = [] t = 0 self._init_states(inputs) + self.attention_layer.init_win_idx() + self.attention_layer.init_states(inputs) while True: if t > 0: new_memory = outputs[-1] diff --git a/layers/tacotron2.py b/layers/tacotron2.py index 4fe6c5b8..8cdc9c16 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -3,83 +3,7 @@ import torch from torch.autograd import Variable from torch import nn from torch.nn import functional as F - - -class Linear(nn.Module): - def __init__(self, - in_features, - out_features, - bias=True, - init_gain='linear'): - super(Linear, self).__init__() - self.linear_layer = torch.nn.Linear( - in_features, out_features, bias=bias) - self._init_w(init_gain) - - def _init_w(self, init_gain): - torch.nn.init.xavier_uniform_( - self.linear_layer.weight, - gain=torch.nn.init.calculate_gain(init_gain)) - - def forward(self, x): - return self.linear_layer(x) - - -class LinearBN(nn.Module): - def __init__(self, - in_features, - out_features, - bias=True, - init_gain='linear'): - super(LinearBN, self).__init__() - self.linear_layer = torch.nn.Linear( - in_features, out_features, bias=bias) - self.bn = nn.BatchNorm1d(out_features) - self._init_w(init_gain) - - def _init_w(self, init_gain): - torch.nn.init.xavier_uniform_( - self.linear_layer.weight, - gain=torch.nn.init.calculate_gain(init_gain)) - - def forward(self, x): - out = self.linear_layer(x) - if len(out.shape) == 3: - out = out.permute(1, 2, 0) - out = self.bn(out) - if len(out.shape) == 3: - out = out.permute(2, 0, 1) - return out - - -class Prenet(nn.Module): - def __init__(self, - in_features, - prenet_type, - prenet_dropout, - out_features=[256, 256]): - super(Prenet, self).__init__() - self.prenet_type = prenet_type - self.prenet_dropout = prenet_dropout - in_features = [in_features] + out_features[:-1] - if prenet_type == "bn": - self.layers = nn.ModuleList([ - LinearBN(in_size, out_size, bias=False) - for (in_size, out_size) in zip(in_features, out_features) - ]) - elif prenet_type == "original": - self.layers = nn.ModuleList([ - Linear(in_size, out_size, bias=False) - for (in_size, out_size) in zip(in_features, out_features) - ]) - - def forward(self, x): - for linear in self.layers: - if self.prenet_dropout: - x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training) - else: - x = F.relu(linear(x)) - return x +from .common_layers import Attention, Prenet, Linear, LinearBN class ConvBNBlock(nn.Module): @@ -103,178 +27,6 @@ class ConvBNBlock(nn.Module): return output -class LocationLayer(nn.Module): - def __init__(self, attention_n_filters, attention_kernel_size, - attention_dim): - super(LocationLayer, self).__init__() - self.location_conv = nn.Conv1d( - in_channels=2, - out_channels=attention_n_filters, - kernel_size=31, - stride=1, - padding=(31 - 1) // 2, - bias=False) - self.location_dense = Linear( - attention_n_filters, attention_dim, bias=False, init_gain='tanh') - - def forward(self, attention_cat): - processed_attention = self.location_conv(attention_cat) - processed_attention = self.location_dense( - processed_attention.transpose(1, 2)) - return processed_attention - - -class Attention(nn.Module): - def __init__(self, attention_rnn_dim, embedding_dim, 
attention_dim, - location_attention, attention_location_n_filters, - attention_location_kernel_size, windowing, norm, forward_attn, - trans_agent): - super(Attention, self).__init__() - self.query_layer = Linear( - attention_rnn_dim, attention_dim, bias=False, init_gain='tanh') - self.inputs_layer = Linear( - embedding_dim, attention_dim, bias=False, init_gain='tanh') - self.v = Linear(attention_dim, 1, bias=True) - if trans_agent: - self.ta = nn.Linear( - attention_rnn_dim + embedding_dim, 1, bias=True) - if location_attention: - self.location_layer = LocationLayer( - attention_location_n_filters, attention_location_kernel_size, - attention_dim) - self._mask_value = -float("inf") - self.windowing = windowing - self.win_idx = None - self.norm = norm - self.forward_attn = forward_attn - self.trans_agent = trans_agent - self.location_attention = location_attention - - def init_win_idx(self): - self.win_idx = -1 - self.win_back = 2 - self.win_front = 6 - - def init_forward_attn(self, inputs): - B = inputs.shape[0] - T = inputs.shape[1] - self.alpha = torch.cat( - [torch.ones([B, 1]), - torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device) - self.u = (0.5 * torch.ones([B, 1])).to(inputs.device) - - def init_location_attention(self, inputs): - B = inputs.shape[0] - T = inputs.shape[1] - self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_()) - - def init_states(self, inputs): - B = inputs.shape[0] - T = inputs.shape[1] - self.attention_weights = Variable(inputs.data.new(B, T).zero_()) - if self.location_attention: - self.init_location_attention(inputs) - if self.forward_attn: - self.init_forward_attn(inputs) - if self.windowing: - self.init_win_idx() - - def update_location_attention(self, alignments): - self.attention_weights_cum += alignments - - def get_location_attention(self, query, processed_inputs): - attention_cat = torch.cat((self.attention_weights.unsqueeze(1), - self.attention_weights_cum.unsqueeze(1)), - dim=1) - processed_query = self.query_layer(query.unsqueeze(1)) - processed_attention_weights = self.location_layer(attention_cat) - energies = self.v( - torch.tanh(processed_query + processed_attention_weights + - processed_inputs)) - energies = energies.squeeze(-1) - return energies, processed_query - - def get_attention(self, query, processed_inputs): - processed_query = self.query_layer(query.unsqueeze(1)) - energies = self.v(torch.tanh(processed_query + processed_inputs)) - energies = energies.squeeze(-1) - return energies, processed_query - - def apply_windowing(self, attention, inputs): - back_win = self.win_idx - self.win_back - front_win = self.win_idx + self.win_front - if back_win > 0: - attention[:, :back_win] = -float("inf") - if front_win < inputs.shape[1]: - attention[:, front_win:] = -float("inf") - # this is a trick to solve a special problem. - # but it does not hurt. 
- if self.win_idx == -1: - attention[:, 0] = attention.max() - # Update the window - self.win_idx = torch.argmax(attention, 1).long()[0].item() - return attention - - def apply_forward_attention(self, inputs, alignment, query): - # forward attention - prev_alpha = F.pad(self.alpha[:, :-1].clone(), - (1, 0, 0, 0)).to(inputs.device) - # compute transition potentials - alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) + - self.u * prev_alpha) + 1e-8) * alignment - # force incremental alignment - TODO: make configurable - if not self.training and alignment.shape[0] == 1: - _, n = prev_alpha.max(1) - val, n2 = alpha.max(1) - for b in range(alignment.shape[0]): - alpha[b, n+2:] = 0 - alpha[b, :(n - 1)] = 0 # ignore all previous states to prevent repetition. - alpha[b, (n - 2)] = 0.01 * val # smoothing factor for the prev step - # compute attention weights - self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1) - # compute context - context = torch.bmm(self.alpha.unsqueeze(1), inputs) - context = context.squeeze(1) - # compute transition agent - if self.trans_agent: - ta_input = torch.cat([context, query.squeeze(1)], dim=-1) - self.u = torch.sigmoid(self.ta(ta_input)) - return context, self.alpha - - def forward(self, attention_hidden_state, inputs, processed_inputs, mask): - if self.location_attention: - attention, processed_query = self.get_location_attention( - attention_hidden_state, processed_inputs) - else: - attention, processed_query = self.get_attention( - attention_hidden_state, processed_inputs) - # apply masking - if mask is not None: - attention.data.masked_fill_(1 - mask, self._mask_value) - # apply windowing - only in eval mode - if not self.training and self.windowing: - attention = self.apply_windowing(attention, inputs) - # normalize attention values - if self.norm == "softmax": - alignment = torch.softmax(attention, dim=-1) - elif self.norm == "sigmoid": - alignment = torch.sigmoid(attention) / torch.sigmoid( - attention).sum(dim=1).unsqueeze(1) - else: - raise RuntimeError("Unknown value for attention norm type") - if self.location_attention: - self.update_location_attention(alignment) - # apply forward attention if enabled - if self.forward_attn: - context, self.attention_weights = self.apply_forward_attention( - inputs, alignment, attention_hidden_state) - else: - context = torch.bmm(alignment.unsqueeze(1), inputs) - context = context.squeeze(1) - self.attention_weights = alignment - return context - - class Postnet(nn.Module): def __init__(self, mel_dim, num_convs=5): super(Postnet, self).__init__() @@ -494,7 +246,7 @@ class Decoder(nn.Module): self.attention_layer.init_states(inputs) outputs, stop_tokens, alignments, t = [], [], [], 0 - stop_flags = [False, False, False] + stop_flags = [True, False, False] stop_count = 0 while True: memory = self.prenet(memory) @@ -510,7 +262,7 @@ class Decoder(nn.Module): stop_flags[2] = t > inputs.shape[1] * 2 if all(stop_flags): stop_count += 1 - if stop_count > 10: + if stop_count > 20: break elif len(outputs) == self.max_decoder_steps: print(" | > Decoder stopped with 'max_decoder_steps") @@ -537,7 +289,7 @@ class Decoder(nn.Module): self.attention_layer.init_win_idx() self.attention_layer.init_states(inputs) outputs, stop_tokens, alignments, t = [], [], [], 0 - stop_flags = [False, False, False] + stop_flags = [True, False, False] stop_count = 0 while True: memory = self.prenet(self.memory_truncated) @@ -548,12 +300,12 @@ class Decoder(nn.Module): alignments += [alignment] stop_flags[0] = stop_flags[0] or stop_token > 0.5 - 
stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5 + stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8 and t > inputs.shape[1]) stop_flags[2] = t > inputs.shape[1] * 2 if all(stop_flags): stop_count += 1 - if stop_count > 2: + if stop_count > 20: break elif len(outputs) == self.max_decoder_steps: print(" | > Decoder stopped with 'max_decoder_steps")
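For reviewers, below is a minimal standalone sketch of the forward-attention update with transition agent that this patch adds in layers/common_layers.py (Attention.apply_forward_attention), combined with the normalized-sigmoid alignment used when attention_norm is "sigmoid". It is an illustration distilled from the patch, not part of it: the function and tensor names are hypothetical, and the eval-time incremental-alignment forcing is omitted.

import torch
import torch.nn.functional as F

def normalized_sigmoid(energies):
    # "sigmoid" attention_norm: sigmoid energies rescaled to sum to 1 over time.
    sig = torch.sigmoid(energies)
    return sig / sig.sum(dim=1, keepdim=True)

def forward_attention_step(alpha_prev, u_prev, energies, inputs, query, ta_layer):
    # alpha_prev: (B, T) previous forward-attention weights
    # u_prev:     (B, 1) previous transition-agent probability
    # energies:   (B, T) raw attention energies for the current decoder step
    # inputs:     (B, T, D_enc) encoder outputs
    # query:      (B, D_att) attention-RNN hidden state
    alignment = normalized_sigmoid(energies)
    # shift previous weights one step forward (prepend a zero column)
    prev_shifted = F.pad(alpha_prev[:, :-1], (1, 0))
    # blend "stay" and "move" hypotheses, then rescale by the new alignment
    alpha = ((1.0 - u_prev) * alpha_prev + u_prev * prev_shifted + 1e-8) * alignment
    alpha = alpha / alpha.sum(dim=1, keepdim=True)
    # context vector and transition-agent update: u = sigmoid(W [context; query])
    context = torch.bmm(alpha.unsqueeze(1), inputs).squeeze(1)
    u = torch.sigmoid(ta_layer(torch.cat([context, query], dim=-1)))
    return context, alpha, u

# example usage with the same initial state as init_forward_attn
B, T, D_enc, D_att = 2, 7, 256, 256
ta = torch.nn.Linear(D_enc + D_att, 1)
alpha0 = torch.cat([torch.ones(B, 1), torch.full((B, T - 1), 1e-7)], dim=1)
u0 = torch.full((B, 1), 0.5)
ctx, alpha1, u1 = forward_attention_step(alpha0, u0, torch.randn(B, T),
                                         torch.randn(B, T, D_enc),
                                         torch.randn(B, D_att), ta)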