From b031a656778867b7ef28fafe59ee190a13c4dbf5 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Wed, 6 Mar 2019 13:14:58 +0100
Subject: [PATCH] compute sequence mask in model, add tacotron2 related files

---
 layers/tacotron2.py      | 385 +++++++++++++++++++++++++++++++++++++++
 models/tacotron.py       |  10 +-
 models/tacotron2.py      |  51 ++++++
 tests/tacotron2_tests.py |  69 +++++++
 4 files changed, 508 insertions(+), 7 deletions(-)
 create mode 100644 layers/tacotron2.py
 create mode 100644 models/tacotron2.py
 create mode 100644 tests/tacotron2_tests.py

diff --git a/layers/tacotron2.py b/layers/tacotron2.py
new file mode 100644
index 00000000..296ea7ec
--- /dev/null
+++ b/layers/tacotron2.py
@@ -0,0 +1,385 @@
+from math import sqrt
+import torch
+from torch.autograd import Variable
+from torch import nn
+from torch.nn import functional as F
+
+
+class Linear(nn.Module):
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 bias=True,
+                 init_gain='linear'):
+        super(Linear, self).__init__()
+        self.linear_layer = torch.nn.Linear(
+            in_features, out_features, bias=bias)
+        self._init_w(init_gain)
+
+    def _init_w(self, init_gain):
+        torch.nn.init.xavier_uniform_(
+            self.linear_layer.weight,
+            gain=torch.nn.init.calculate_gain(init_gain))
+
+    def forward(self, x):
+        return self.linear_layer(x)
+
+
+class Prenet(nn.Module):
+    def __init__(self, in_features, out_features=[256, 256]):
+        super(Prenet, self).__init__()
+        in_features = [in_features] + out_features[:-1]
+        self.layers = nn.ModuleList([
+            Linear(in_size, out_size, bias=False)
+            for (in_size, out_size) in zip(in_features, out_features)
+        ])
+
+    def forward(self, x):
+        for linear in self.layers:
+            # The prenet applies dropout at inference time as well;
+            # disabling it degrades attention quality at inference.
+            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
+        return x
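The always-on dropout above is the one intentional oddity in this file: passing training=True bypasses the module's train/eval switch entirely. A minimal standalone sketch of the consequence (illustrative values, not part of the patch):

    import torch
    from torch.nn import functional as F

    x = torch.ones(1, 4)
    # training=True keeps dropout stochastic even under model.eval()
    a = F.dropout(F.relu(x), p=0.5, training=True)
    b = F.dropout(F.relu(x), p=0.5, training=True)
    # a and b generally differ, so repeated prenet passes stay noisy.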
+
+
+class ConvBNBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, nonlinear=None):
+        super(ConvBNBlock, self).__init__()
+        assert (kernel_size - 1) % 2 == 0
+        padding = (kernel_size - 1) // 2
+        conv1d = nn.Conv1d(
+            in_channels, out_channels, kernel_size, padding=padding)
+        norm = nn.BatchNorm1d(out_channels)
+        dropout = nn.Dropout(p=0.5)
+        if nonlinear == 'relu':
+            self.net = nn.Sequential(conv1d, norm, nn.ReLU(), dropout)
+        elif nonlinear == 'tanh':
+            self.net = nn.Sequential(conv1d, norm, nn.Tanh(), dropout)
+        else:
+            self.net = nn.Sequential(conv1d, norm, dropout)
+
+    def forward(self, x):
+        output = self.net(x)
+        return output
+
+
+class LocationLayer(nn.Module):
+    def __init__(self, attention_n_filters, attention_kernel_size,
+                 attention_dim):
+        super(LocationLayer, self).__init__()
+        self.location_conv = nn.Conv1d(
+            in_channels=2,
+            out_channels=attention_n_filters,
+            kernel_size=attention_kernel_size,
+            stride=1,
+            padding=(attention_kernel_size - 1) // 2,
+            bias=False)
+        self.location_dense = Linear(
+            attention_n_filters, attention_dim, bias=False, init_gain='tanh')
+
+    def forward(self, attention_cat):
+        processed_attention = self.location_conv(attention_cat)
+        processed_attention = self.location_dense(
+            processed_attention.transpose(1, 2))
+        return processed_attention
+
+
+class Attention(nn.Module):
+    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
+                 attention_location_n_filters, attention_location_kernel_size,
+                 windowing):
+        super(Attention, self).__init__()
+        self.query_layer = Linear(
+            attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
+        self.inputs_layer = Linear(
+            embedding_dim, attention_dim, bias=False, init_gain='tanh')
+        self.v = Linear(attention_dim, 1, bias=False)
+        self.location_layer = LocationLayer(attention_location_n_filters,
+                                            attention_location_kernel_size,
+                                            attention_dim)
+        self._mask_value = -float("inf")
+        self.windowing = windowing
+        if self.windowing:
+            self.win_back = 1
+            self.win_front = 3
+            self.win_idx = None
+
+    def init_win_idx(self):
+        self.win_idx = 0
+
+    def get_attention(self, query, processed_inputs, attention_cat):
+        processed_query = self.query_layer(query.unsqueeze(1))
+        processed_attention_weights = self.location_layer(attention_cat)
+        energies = self.v(
+            torch.tanh(processed_query + processed_attention_weights +
+                       processed_inputs))
+
+        energies = energies.squeeze(-1)
+        return energies
+
+    def forward(self, attention_hidden_state, inputs, processed_inputs,
+                attention_cat, mask):
+        attention = self.get_attention(
+            attention_hidden_state, processed_inputs, attention_cat)
+
+        if mask is not None:
+            attention.data.masked_fill_(1 - mask, self._mask_value)
+        # Windowing: at inference, restrict attention to a moving window
+        # around the last attended input position.
+        if not self.training and self.windowing:
+            back_win = self.win_idx - self.win_back
+            front_win = self.win_idx + self.win_front
+            if back_win > 0:
+                attention[:, :back_win] = -float("inf")
+            if front_win < inputs.shape[1]:
+                attention[:, front_win:] = -float("inf")
+            # Update the window position
+            self.win_idx = torch.argmax(attention, 1).long()[0].item()
+        # Sigmoid normalization instead of softmax
+        alignment = torch.sigmoid(attention) / torch.sigmoid(
+            attention).sum(dim=1).unsqueeze(1)
+        context = torch.bmm(alignment.unsqueeze(1), inputs)
+        context = context.squeeze(1)
+        return context, alignment
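For intuition, here is the windowing branch above applied to a toy energy vector (values and window position are illustrative only; win_back and win_front match the defaults in Attention.__init__):

    import torch

    energies = torch.randn(1, 10)           # B=1, T_in=10
    win_idx, win_back, win_front = 4, 1, 3
    energies[:, :win_idx - win_back] = -float("inf")
    energies[:, win_idx + win_front:] = -float("inf")
    # Only positions 3..6 keep finite energy, so only they can
    # receive attention mass at this decoder step.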
+
+
+class Postnet(nn.Module):
+    def __init__(self, mel_dim, num_convs=5):
+        super(Postnet, self).__init__()
+        self.convolutions = nn.ModuleList()
+        self.convolutions.append(
+            ConvBNBlock(mel_dim, 512, kernel_size=5, nonlinear='tanh'))
+        for _ in range(1, num_convs - 1):
+            self.convolutions.append(
+                ConvBNBlock(512, 512, kernel_size=5, nonlinear='tanh'))
+        self.convolutions.append(
+            ConvBNBlock(512, mel_dim, kernel_size=5, nonlinear=None))
+
+    def forward(self, x):
+        for layer in self.convolutions:
+            x = layer(x)
+        return x
+
+
+class Encoder(nn.Module):
+    def __init__(self, in_features=512):
+        super(Encoder, self).__init__()
+        convolutions = []
+        for _ in range(3):
+            convolutions.append(
+                ConvBNBlock(in_features, in_features, 5, 'relu'))
+        self.convolutions = nn.Sequential(*convolutions)
+        self.lstm = nn.LSTM(
+            in_features,
+            in_features // 2,
+            num_layers=1,
+            batch_first=True,
+            bidirectional=True)
+
+    def forward(self, x, input_lengths):
+        x = self.convolutions(x)
+        x = x.transpose(1, 2)
+        input_lengths = input_lengths.cpu().numpy()
+        x = nn.utils.rnn.pack_padded_sequence(
+            x, input_lengths, batch_first=True)
+        self.lstm.flatten_parameters()
+        outputs, _ = self.lstm(x)
+        outputs, _ = nn.utils.rnn.pad_packed_sequence(
+            outputs, batch_first=True)
+        return outputs
+
+    def inference(self, x):
+        x = self.convolutions(x)
+        x = x.transpose(1, 2)
+        self.lstm.flatten_parameters()
+        outputs, _ = self.lstm(x)
+        return outputs
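Encoder.forward handles padding through the pack/pad round trip rather than an explicit mask. A self-contained sketch of that mechanism (shapes are illustrative; lengths must already be sorted in descending order, as the training code guarantees):

    import torch
    from torch import nn

    lstm = nn.LSTM(8, 4, batch_first=True, bidirectional=True)
    x = torch.randn(2, 5, 8)        # B=2, T=5, D=8
    lengths = torch.tensor([5, 3])  # sorted descending
    packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
    out, _ = lstm(packed)
    out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
    # out: (2, 5, 8); steps 3 and 4 of the shorter item are zero-padded.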
+
+
+# adapted from https://github.com/NVIDIA/tacotron2/
+class Decoder(nn.Module):
+    def __init__(self, in_features, inputs_dim, r, attn_win):
+        super(Decoder, self).__init__()
+        self.mel_channels = inputs_dim
+        self.r = r
+        self.encoder_embedding_dim = in_features
+        self.attention_rnn_dim = 1024
+        self.decoder_rnn_dim = 1024
+        self.prenet_dim = 256
+        self.max_decoder_steps = 1000
+        self.gate_threshold = 0.5
+        self.p_attention_dropout = 0.1
+        self.p_decoder_dropout = 0.1
+
+        self.prenet = Prenet(self.mel_channels * r,
+                             [self.prenet_dim, self.prenet_dim])
+
+        self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
+                                         self.attention_rnn_dim)
+
+        self.attention_layer = Attention(self.attention_rnn_dim, in_features,
+                                         128, 32, 31, attn_win)
+
+        self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
+                                       self.decoder_rnn_dim, bias=True)
+
+        self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
+                                        self.mel_channels * r)
+
+        self.stopnet = nn.Sequential(
+            nn.Dropout(0.1),
+            Linear(self.decoder_rnn_dim + self.mel_channels * r,
+                   1,
+                   bias=True,
+                   init_gain='sigmoid'))
+
+        self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim)
+        self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
+        self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
+
+    def get_go_frame(self, inputs):
+        B = inputs.size(0)
+        memory = self.go_frame_init(inputs.data.new_zeros(B).long())
+        return memory
+
+    def _init_states(self, inputs, mask):
+        B = inputs.size(0)
+        T = inputs.size(1)
+
+        self.attention_hidden = self.attention_rnn_init(
+            inputs.data.new_zeros(B).long())
+        self.attention_cell = Variable(
+            inputs.data.new(B, self.attention_rnn_dim).zero_())
+
+        self.decoder_hidden = self.decoder_rnn_inits(
+            inputs.data.new_zeros(B).long())
+        self.decoder_cell = Variable(
+            inputs.data.new(B, self.decoder_rnn_dim).zero_())
+
+        self.attention_weights = Variable(inputs.data.new(B, T).zero_())
+        self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())
+        self.context = Variable(
+            inputs.data.new(B, self.encoder_embedding_dim).zero_())
+
+        self.inputs = inputs
+        self.processed_inputs = self.attention_layer.inputs_layer(inputs)
+        self.mask = mask
+
+    def _reshape_memory(self, memories):
+        memories = memories.view(
+            memories.size(0), memories.size(1) // self.r, -1)
+        memories = memories.transpose(0, 1)
+        return memories
+
+    def _parse_outputs(self, outputs, gate_outputs, alignments):
+        alignments = torch.stack(alignments).transpose(0, 1)
+        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
+        gate_outputs = gate_outputs.contiguous()
+        outputs = torch.stack(outputs).transpose(0, 1).contiguous()
+        outputs = outputs.view(outputs.size(0), -1, self.mel_channels)
+        outputs = outputs.transpose(1, 2)
+        return outputs, gate_outputs, alignments
+
+    def decode(self, memory):
+        cell_input = torch.cat((memory, self.context), -1)
+        self.attention_hidden, self.attention_cell = self.attention_rnn(
+            cell_input, (self.attention_hidden, self.attention_cell))
+        self.attention_hidden = F.dropout(
+            self.attention_hidden, self.p_attention_dropout, self.training)
+        self.attention_cell = F.dropout(
+            self.attention_cell, self.p_attention_dropout, self.training)
+
+        attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
+                                   self.attention_weights_cum.unsqueeze(1)),
+                                  dim=1)
+        self.context, self.attention_weights = self.attention_layer(
+            self.attention_hidden, self.inputs, self.processed_inputs,
+            attention_cat, self.mask)
+
+        self.attention_weights_cum += self.attention_weights
+        memory = torch.cat((self.attention_hidden, self.context), -1)
+        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
+            memory, (self.decoder_hidden, self.decoder_cell))
+        self.decoder_hidden = F.dropout(self.decoder_hidden,
+                                        self.p_decoder_dropout, self.training)
+        self.decoder_cell = F.dropout(self.decoder_cell,
+                                      self.p_decoder_dropout, self.training)
+
+        decoder_hidden_context = torch.cat(
+            (self.decoder_hidden, self.context), dim=1)
+
+        decoder_output = self.linear_projection(decoder_hidden_context)
+
+        stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
+
+        gate_prediction = self.stopnet(stopnet_input)
+        return decoder_output, gate_prediction, self.attention_weights
+
+    def forward(self, inputs, memories, mask):
+        memory = self.get_go_frame(inputs).unsqueeze(0)
+        memories = self._reshape_memory(memories)
+        memories = torch.cat((memory, memories), dim=0)
+        memories = self.prenet(memories)
+
+        self._init_states(inputs, mask=mask)
+
+        outputs, gate_outputs, alignments = [], [], []
+        while len(outputs) < memories.size(0) - 1:
+            memory = memories[len(outputs)]
+            mel_output, gate_output, attention_weights = self.decode(memory)
+            outputs += [mel_output.squeeze(1)]
+            gate_outputs += [gate_output.squeeze(1)]
+            alignments += [attention_weights]
+
+        outputs, gate_outputs, alignments = self._parse_outputs(
+            outputs, gate_outputs, alignments)
+
+        return outputs, gate_outputs, alignments
+
+    def inference(self, inputs):
+        memory = self.get_go_frame(inputs)
+        self._init_states(inputs, mask=None)
+
+        self.attention_layer.init_win_idx()
+        outputs, gate_outputs, alignments, t = [], [], [], 0
+        stop_flags = [False, False]
+        while True:
+            memory = self.prenet(memory)
+            mel_output, gate_output, alignment = self.decode(memory)
+            gate_output = torch.sigmoid(gate_output.data)
+            outputs += [mel_output.squeeze(1)]
+            gate_outputs += [gate_output]
+            alignments += [alignment]
+
+            stop_flags[0] = stop_flags[0] or gate_output > self.gate_threshold
+            stop_flags[1] = stop_flags[1] or alignment[0, -3:].sum() > 0.5
+            if all(stop_flags):
+                break
+            elif len(outputs) == self.max_decoder_steps:
+                print(" | > Decoder stopped with 'max_decoder_steps'")
+                break
+
+            memory = mel_output
+            t += 1
+
+        outputs, gate_outputs, alignments = self._parse_outputs(
+            outputs, gate_outputs, alignments)
+
+        return outputs, gate_outputs, alignments
+
+    def inference_step(self, inputs, t, memory=None):
+        """
+        For debugging purposes.
+        """
+        if t == 0:
+            memory = self.get_go_frame(inputs)
+            self._init_states(inputs, mask=None)
+
+        memory = self.prenet(memory)
+        mel_output, gate_output, alignment = self.decode(memory)
+        gate_output = torch.sigmoid(gate_output.data)
+        return mel_output, gate_output, alignment
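With layers/tacotron2.py in place, inference_step can drive the decoder one frame at a time, which is handy when debugging attention. A hedged sketch of such a loop (assumes a constructed Decoder named decoder and a (1, T_in, 512) tensor encoder_outputs; none of this is part of the patch):

    # Feed each predicted frame back in as the next step's memory.
    decoder.eval()
    decoder.attention_layer.init_win_idx()
    memory = None
    for t in range(decoder.max_decoder_steps):
        frame, gate, align = decoder.inference_step(encoder_outputs, t, memory)
        if gate.item() > decoder.gate_threshold:
            break  # stopnet says the utterance is finished
        memory = frame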
diff --git a/models/tacotron.py b/models/tacotron.py
index 0192c074..3ecd3f9e 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -3,6 +3,7 @@ import torch
 from torch import nn
 from math import sqrt
 from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG
+from utils.generic_utils import sequence_mask
 
 
 class Tacotron(nn.Module):
@@ -27,15 +28,13 @@ class Tacotron(nn.Module):
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
             nn.Sigmoid())
 
-    def forward(self, characters, mel_specs=None, mask=None):
+    def forward(self, characters, text_lengths, mel_specs=None):
         B = characters.size(0)
+        mask = sequence_mask(text_lengths).to(characters.device)
         inputs = self.embedding(characters)
-        # batch x time x dim
         encoder_outputs = self.encoder(inputs)
-        # batch x time x dim*r
         mel_outputs, alignments, stop_tokens = self.decoder(
             encoder_outputs, mel_specs, mask)
-        # batch x time x dim
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
         linear_outputs = self.postnet(mel_outputs)
         linear_outputs = self.last_linear(linear_outputs)
@@ -44,12 +43,9 @@
     def inference(self, characters):
         B = characters.size(0)
         inputs = self.embedding(characters)
-        # batch x time x dim
         encoder_outputs = self.encoder(inputs)
-        # batch x time x dim*r
         mel_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs)
-        # batch x time x dim
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
         linear_outputs = self.postnet(mel_outputs)
         linear_outputs = self.last_linear(linear_outputs)
diff --git a/models/tacotron2.py b/models/tacotron2.py
new file mode 100644
index 00000000..70bdd89a
--- /dev/null
+++ b/models/tacotron2.py
@@ -0,0 +1,51 @@
+from math import sqrt
+import torch
+from torch.autograd import Variable
+from torch import nn
+from torch.nn import functional as F
+from layers.tacotron2 import Encoder, Decoder, Postnet
+from utils.generic_utils import sequence_mask
+
+
+# TODO: match function arguments with tacotron
+class Tacotron2(nn.Module):
+    def __init__(self, num_chars, r, attn_win=False):
+        super(Tacotron2, self).__init__()
+        self.n_mel_channels = 80
+        self.n_frames_per_step = r
+        self.embedding = nn.Embedding(num_chars, 512)
+        std = sqrt(2.0 / (num_chars + 512))
+        val = sqrt(3.0) * std  # uniform bounds for std
+        self.embedding.weight.data.uniform_(-val, val)
+        self.encoder = Encoder(512)
+        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win)
+        self.postnet = Postnet(self.n_mel_channels)
+
+    def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
+        mel_outputs = mel_outputs.transpose(1, 2)
+        mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
+        return mel_outputs, mel_outputs_postnet, alignments
+
+    def forward(self, text, text_lengths, mel_specs=None):
+        # compute padding mask from the true input lengths
+        mask = sequence_mask(text_lengths).to(text.device)
+        embedded_inputs = self.embedding(text).transpose(1, 2)
+        encoder_outputs = self.encoder(embedded_inputs, text_lengths)
+        mel_outputs, stop_tokens, alignments = self.decoder(
+            encoder_outputs, mel_specs, mask)
+        mel_outputs_postnet = self.postnet(mel_outputs)
+        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
+        mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
+            mel_outputs, mel_outputs_postnet, alignments)
+        return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
+
+    def inference(self, text):
+        embedded_inputs = self.embedding(text).transpose(1, 2)
+        encoder_outputs = self.encoder.inference(embedded_inputs)
+        mel_outputs, stop_tokens, alignments = self.decoder.inference(
+            encoder_outputs)
+        mel_outputs_postnet = self.postnet(mel_outputs)
+        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
+        mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
+            mel_outputs, mel_outputs_postnet, alignments)
+        return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
\ No newline at end of file
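Both model files import sequence_mask from utils.generic_utils, which this patch does not touch. For readers without the repo at hand, a minimal sketch of the behavior the models rely on (an assumption about the helper, not its verbatim source):

    import torch

    def sequence_mask(lengths, max_len=None):
        # (B,) lengths -> (B, T_max) mask, nonzero at valid positions.
        max_len = max_len or int(lengths.max().item())
        ids = torch.arange(max_len, device=lengths.device)
        return ids < lengths.unsqueeze(1)

    # sequence_mask(torch.tensor([3, 1]))
    # -> [[1, 1, 1],
    #     [1, 0, 0]]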
diff --git a/tests/tacotron2_tests.py b/tests/tacotron2_tests.py
new file mode 100644
index 00000000..56c5a1a1
--- /dev/null
+++ b/tests/tacotron2_tests.py
@@ -0,0 +1,69 @@
+import os
+import copy
+import torch
+import unittest
+import numpy as np
+
+from torch import optim
+from torch import nn
+from utils.generic_utils import load_config
+from layers.losses import MSELossMasked
+from models.tacotron2 import Tacotron2
+
+torch.manual_seed(1)
+use_cuda = torch.cuda.is_available()
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+file_path = os.path.dirname(os.path.realpath(__file__))
+c = load_config(os.path.join(file_path, 'test_config.json'))
+
+
+class TacotronTrainTest(unittest.TestCase):
+    def test_train_step(self):
+        input = torch.randint(0, 24, (8, 128)).long().to(device)
+        input_lengths = torch.randint(100, 128, (8, )).long().to(device)
+        input_lengths = torch.sort(input_lengths, descending=True)[0]
+        mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
+        mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
+        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
+        stop_targets = torch.zeros(8, 30, 1).float().to(device)
+
+        for b, l in enumerate(mel_lengths):
+            stop_targets[b, l.item():, 0] = 1.0
+
+        stop_targets = stop_targets.view(input.shape[0],
+                                         stop_targets.size(1) // c.r, -1)
+        stop_targets = (stop_targets.sum(2) >
+                        0.0).unsqueeze(2).float().squeeze()
+
+        criterion = MSELossMasked().to(device)
+        criterion_st = nn.BCEWithLogitsLoss().to(device)
+        model = Tacotron2(24, c.r).to(device)
+        model.train()
+        model_ref = copy.deepcopy(model)
+        count = 0
+        for param, param_ref in zip(model.parameters(),
+                                    model_ref.parameters()):
+            assert (param - param_ref).sum() == 0, param
+            count += 1
+        optimizer = optim.Adam(model.parameters(), lr=c.lr)
+        for _ in range(5):
+            mel_out, mel_postnet_out, align, stop_tokens = model.forward(
+                input, input_lengths, mel_spec)
+            assert torch.sigmoid(stop_tokens).data.max() <= 1.0
+            assert torch.sigmoid(stop_tokens).data.min() >= 0.0
+            optimizer.zero_grad()
+            loss = criterion(mel_out, mel_spec, mel_lengths)
+            stop_loss = criterion_st(stop_tokens, stop_targets)
+            loss = loss + criterion(mel_postnet_out, mel_postnet_spec,
+                                    mel_lengths) + stop_loss
+            loss.backward()
+            optimizer.step()
+        # check parameter changes
+        count = 0
+        for param, param_ref in zip(model.parameters(),
+                                    model_ref.parameters()):
+            # ignore the pre-highway layer since it is applied conditionally
+            # if count not in [145, 59]:
+            assert (param != param_ref).any(), \
+                "param {} with shape {} not updated!! \n{}\n{}".format(
+                    count, param.shape, param, param_ref)
+            count += 1
\ No newline at end of file
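Outside the unittest harness, the new model can also be smoke-tested end to end in a few lines. A hedged sketch (illustrative shapes; assumes the repo utilities imported by models/tacotron2.py are on the path, and uses r=1 rather than the test config's c.r):

    import torch
    from models.tacotron2 import Tacotron2

    model = Tacotron2(num_chars=24, r=1)
    model.train()
    text = torch.randint(0, 24, (1, 50)).long()
    lengths = torch.tensor([50])
    mel = torch.rand(1, 30, 80)  # (B, T_mel, n_mel_channels)
    mel_out, mel_post, align, stops = model(text, lengths, mel)
    # mel_out, mel_post: (1, 30, 80); align: (1, 30, 50); stops: (1, 30)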