From e83a4b07d2558a25b0ad84dea642385a61660bd6 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Mon, 28 Oct 2019 14:51:19 +0100
Subject: [PATCH] comment model outputs for tacotron, align output shapes of
 tacotron and tacotron2, merge bidirectional decoder

---
 config.json            |   4 +-
 layers/tacotron.py     | 104 ++++++++++++++++++++---------------
 layers/tacotron2.py    |  45 ++++++++++--------
 models/tacotron.py     |  76 +++++++++++++++++++++---------
 models/tacotron2.py    |  51 +++++++++++++++-----
 train.py               |  28 +++++++++--
 utils/generic_utils.py |  12 +++--
 7 files changed, 207 insertions(+), 113 deletions(-)

diff --git a/config.json b/config.json
index 1226e1ac..47308f4f 100644
--- a/config.json
+++ b/config.json
@@ -46,6 +46,7 @@
     "forward_attn_mask": false,
     "transition_agent": false,  // enable/disable transition agent of forward attention.
     "location_attn": true,  // enable_disable location sensitive attention. It is enabled for TACOTRON by default.
+    "bidirectional_decoder": true,  // use https://arxiv.org/abs/1907.09006. Use it if attention does not work well with your dataset.
     "loss_masking": true,  // enable / disable loss masking against the sequence padding.
     "enable_eos_bos_chars": false,  // enable/disable beginning of sentence and end of sentence chars.
     "stopnet": true,  // Train stopnet predicting the end of synthesis.
@@ -82,7 +83,8 @@
     [
       {
         "name": "ljspeech",
-        "path": "/data/ro/shared/data/keithito/LJSpeech-1.1/",
+        // "path": "/data/ro/shared/data/keithito/LJSpeech-1.1/",
+        "path": "/home/erogol/Data/LJSpeech-1.1",
         "meta_file_train": "metadata_train.csv",
         "meta_file_val": "metadata_val.csv"
       }
diff --git a/layers/tacotron.py b/layers/tacotron.py
index 657eefe7..27693ed0 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -1,7 +1,7 @@
 # coding: utf-8
 import torch
 from torch import nn
-from .common_layers import Prenet, Attention
+from .common_layers import Prenet, Attention, Linear


 class BatchNormConv1d(nn.Module):
@@ -125,13 +125,12 @@ class CBHG(nn.Module):
         # list of conv1d bank with filter size k=1...K
         # TODO: try dilational layers instead
         self.conv1d_banks = nn.ModuleList([
-            BatchNormConv1d(
-                in_features,
-                conv_bank_features,
-                kernel_size=k,
-                stride=1,
-                padding=[(k - 1) // 2, k // 2],
-                activation=self.relu) for k in range(1, K + 1)
+            BatchNormConv1d(in_features,
+                            conv_bank_features,
+                            kernel_size=k,
+                            stride=1,
+                            padding=[(k - 1) // 2, k // 2],
+                            activation=self.relu) for k in range(1, K + 1)
         ])
         # max pooling of conv bank, with padding
         # TODO: try average pooling OR larger kernel size
@@ -142,39 +141,33 @@ class CBHG(nn.Module):
         layer_set = []
         for (in_size, out_size, ac) in zip(out_features, conv_projections,
                                            activations):
-            layer = BatchNormConv1d(
-                in_size,
-                out_size,
-                kernel_size=3,
-                stride=1,
-                padding=[1, 1],
-                activation=ac)
+            layer = BatchNormConv1d(in_size,
+                                    out_size,
+                                    kernel_size=3,
+                                    stride=1,
+                                    padding=[1, 1],
+                                    activation=ac)
             layer_set.append(layer)
         self.conv1d_projections = nn.ModuleList(layer_set)
         # setup Highway layers
         if self.highway_features != conv_projections[-1]:
-            self.pre_highway = nn.Linear(
-                conv_projections[-1], highway_features, bias=False)
+            self.pre_highway = nn.Linear(conv_projections[-1],
+                                         highway_features,
+                                         bias=False)
         self.highways = nn.ModuleList([
             Highway(highway_features, highway_features)
             for _ in range(num_highways)
         ])
         # bi-directional GRU layer
-        self.gru = nn.GRU(
-            gru_features,
-            gru_features,
-            1,
-            batch_first=True,
-            bidirectional=True)
+        self.gru = nn.GRU(gru_features,
+                          gru_features,
+                          1,
+                          batch_first=True,
+                          bidirectional=True)

     def forward(self, inputs):
-        # (B, T_in, in_features)
-        x = inputs
-        # Needed to perform conv1d on time-axis
         # (B, in_features, T_in)
-        if x.size(-1) == self.in_features:
-            x = x.transpose(1, 2)
-        # T = x.size(-1)
+        x = inputs
         # (B, hid_features*K, T_in)
         # Concat conv1d bank outputs
         outs = []
@@ -185,10 +178,8 @@ class CBHG(nn.Module):
         assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks)
         for conv1d in self.conv1d_projections:
             x = conv1d(x)
-        # (B, T_in, hid_feature)
-        x = x.transpose(1, 2)
-        # Back to the original shape
         x += inputs
+        x = x.transpose(1, 2)
         if self.highway_features != self.conv_projections[-1]:
             x = self.pre_highway(x)
         # Residual connection
@@ -236,8 +227,10 @@ class Encoder(nn.Module):
             - inputs: batch x time x in_features
             - outputs: batch x time x 128*2
         """
-        inputs = self.prenet(inputs)
-        return self.cbhg(inputs)
+        # B x T x prenet_dim
+        outputs = self.prenet(inputs)
+        outputs = self.cbhg(outputs.transpose(1, 2))
+        return outputs


 class PostCBHG(nn.Module):
@@ -314,7 +307,12 @@ class Decoder(nn.Module):
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * self.r_init)
         # learn init values instead of zero init.
-        self.stopnet = StopNet(256 + memory_dim * self.r_init)
+        self.stopnet = nn.Sequential(
+            nn.Dropout(0.1),
+            Linear(256 + memory_dim * self.r_init,
+                   1,
+                   bias=True,
+                   init_gain='sigmoid'))

     def set_r(self, new_r):
         self.r = new_r
@@ -356,8 +354,9 @@ class Decoder(nn.Module):
     def _parse_outputs(self, outputs, attentions, stop_tokens):
         # Back to batch first
         attentions = torch.stack(attentions).transpose(0, 1)
+        stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        stop_tokens = torch.stack(stop_tokens).transpose(0, 1).squeeze(-1)
+        outputs = outputs.view(outputs.size(0), -1, self.memory_dim)
+        outputs = outputs.transpose(1, 2)
         return outputs, attentions, stop_tokens

     def decode(self, inputs, mask=None):
@@ -438,9 +437,8 @@ class Decoder(nn.Module):
             output, stop_token, attention = self.decode(inputs, mask)
             outputs += [output]
             attentions += [attention]
-            stop_tokens += [stop_token]
+            stop_tokens += [stop_token.squeeze(1)]
             t += 1
-
         return self._parse_outputs(outputs, attentions, stop_tokens)

     def inference(self, inputs, speaker_embeddings=None):
@@ -481,20 +479,20 @@ class Decoder(nn.Module):
         return self._parse_outputs(outputs, attentions, stop_tokens)


-class StopNet(nn.Module):
-    r"""
-    Args:
-        in_features (int): feature dimension of input.
-    """
+# class StopNet(nn.Module):
+#     r"""
+#     Args:
+#         in_features (int): feature dimension of input.
+#     """

-    def __init__(self, in_features):
-        super(StopNet, self).__init__()
-        self.dropout = nn.Dropout(0.1)
-        self.linear = nn.Linear(in_features, 1)
-        torch.nn.init.xavier_uniform_(
-            self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
+#     def __init__(self, in_features):
+#         super(StopNet, self).__init__()
+#         self.dropout = nn.Dropout(0.1)
+#         self.linear = nn.Linear(in_features, 1)
+#         torch.nn.init.xavier_uniform_(
+#             self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))

-    def forward(self, inputs):
-        outputs = self.dropout(inputs)
-        outputs = self.linear(outputs)
-        return outputs
+#     def forward(self, inputs):
+#         outputs = self.dropout(inputs)
+#         outputs = self.linear(outputs)
+#         return outputs
diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index 0d7472fd..e0c38e30 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -98,11 +98,12 @@ class Encoder(nn.Module):
 class Decoder(nn.Module):
     # Pylint gets confused by PyTorch conventions here
     #pylint: disable=attribute-defined-outside-init
-    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm,
+    def __init__(self, in_features, memory_dim, r, attn_win, attn_norm,
                  prenet_type, prenet_dropout, forward_attn, trans_agent,
-                 forward_attn_mask, location_attn, separate_stopnet):
+                 forward_attn_mask, location_attn, separate_stopnet,
+                 speaker_embedding_dim):
         super(Decoder, self).__init__()
-        self.mel_channels = inputs_dim
+        self.memory_dim = memory_dim
         self.r_init = r
         self.r = r
         self.encoder_embedding_dim = in_features
@@ -114,11 +115,15 @@ class Decoder(nn.Module):
         self.gate_threshold = 0.5
         self.p_attention_dropout = 0.1
         self.p_decoder_dropout = 0.1
-        self.prenet = Prenet(self.mel_channels,
-                             prenet_type,
-                             prenet_dropout,
-                             [self.prenet_dim, self.prenet_dim],
-                             bias=False)
+
+        # memory -> |Prenet| -> processed_memory
+        prenet_dim = self.memory_dim + speaker_embedding_dim
+        self.prenet = Prenet(
+            prenet_dim,
+            prenet_type,
+            prenet_dropout,
+            out_features=[self.prenet_dim, self.prenet_dim],
+            bias=False)

         self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                          self.query_dim)
@@ -139,11 +144,11 @@ class Decoder(nn.Module):
                                        self.decoder_rnn_dim, 1)

         self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
-                                        self.mel_channels * self.r_init)
+                                        self.memory_dim * self.r_init)

         self.stopnet = nn.Sequential(
             nn.Dropout(0.1),
-            Linear(self.decoder_rnn_dim + self.mel_channels * self.r_init,
+            Linear(self.decoder_rnn_dim + self.memory_dim * self.r_init,
                    1,
                    bias=True,
                    init_gain='sigmoid'))
@@ -155,7 +160,7 @@ class Decoder(nn.Module):
     def get_go_frame(self, inputs):
         B = inputs.size(0)
         memory = torch.zeros(1, device=inputs.device).repeat(B,
-                                                             self.mel_channels * self.r)
+                                                             self.memory_dim * self.r)
         return memory

     def _init_states(self, inputs, mask, keep_states=False):
@@ -185,16 +190,14 @@ class Decoder(nn.Module):
     def _parse_outputs(self, outputs, stop_tokens, alignments):
         alignments = torch.stack(alignments).transpose(0, 1)
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
-        stop_tokens = stop_tokens.contiguous()
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        outputs = outputs.view(outputs.size(0), -1, self.mel_channels)
-        outputs = outputs.transpose(1, 2)
+        outputs = outputs.view(outputs.size(0), -1, self.memory_dim)
+        outputs = outputs.transpose(1, 2)
         return outputs, stop_tokens, alignments

     def _update_memory(self, memory):
         if len(memory.shape) == 2:
-            return memory[:, self.mel_channels * (self.r - 1):]
-        return memory[:, :, self.mel_channels * (self.r - 1):]
+            return memory[:, self.memory_dim * (self.r - 1):]
+        return memory[:, :, self.memory_dim * (self.r - 1):]

     def decode(self, memory):
         query_input = torch.cat((memory, self.context), -1)
@@ -228,10 +231,10 @@ class Decoder(nn.Module):
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
-        decoder_output = decoder_output[:, :self.r * self.mel_channels]
+        decoder_output = decoder_output[:, :self.r * self.memory_dim]
         return decoder_output, stop_token, self.attention.attention_weights

-    def forward(self, inputs, memories, mask):
+    def forward(self, inputs, memories, mask, speaker_embeddings=None):
         memory = self.get_go_frame(inputs).unsqueeze(0)
         memories = self._reshape_memory(memories)
         memories = torch.cat((memory, memories), dim=0)
@@ -243,6 +246,8 @@ class Decoder(nn.Module):
         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
             memory = memories[len(outputs)]
+            if speaker_embeddings is not None:
+                memory = torch.cat([memory, speaker_embeddings], dim=-1)
             mel_output, stop_token, attention_weights = self.decode(memory)
             outputs += [mel_output.squeeze(1)]
             stop_tokens += [stop_token.squeeze(1)]
@@ -253,7 +258,7 @@ class Decoder(nn.Module):

         return outputs, stop_tokens, alignments

-    def inference(self, inputs):
+    def inference(self, inputs, speaker_embeddings=None):
         memory = self.get_go_frame(inputs)
         memory = self._update_memory(memory)

@@ -266,6 +271,8 @@ class Decoder(nn.Module):
         stop_flags = [True, False, False]
         while True:
             memory = self.prenet(memory)
+            if speaker_embeddings is not None:
+                memory = torch.cat([memory, speaker_embeddings], dim=-1)
             mel_output, stop_token, alignment = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
             outputs += [mel_output.squeeze(1)]
diff --git a/models/tacotron.py b/models/tacotron.py
index 8f711364..74d28f10 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -1,5 +1,6 @@
 # coding: utf-8
 import torch
+import copy
 from torch import nn
 from TTS.layers.tacotron import Encoder, Decoder, PostCBHG
 from TTS.utils.generic_utils import sequence_mask
@@ -11,8 +12,8 @@ class Tacotron(nn.Module):
                  num_chars,
                  num_speakers,
                  r=5,
-                 linear_dim=1025,
-                 mel_dim=80,
+                 postnet_output_dim=1025,
+                 decoder_output_dim=80,
                  memory_size=5,
                  attn_win=False,
                  gst=False,
@@ -23,28 +24,33 @@ class Tacotron(nn.Module):
                  trans_agent=False,
                  forward_attn_mask=False,
                  location_attn=True,
-                 separate_stopnet=True):
+                 separate_stopnet=True,
+                 bidirectional_decoder=False):
         super(Tacotron, self).__init__()
         self.r = r
-        self.mel_dim = mel_dim
-        self.linear_dim = linear_dim
+        self.decoder_output_dim = decoder_output_dim
+        self.postnet_output_dim = postnet_output_dim
         self.gst = gst
         self.num_speakers = num_speakers
-        self.embedding = nn.Embedding(num_chars, 256)
-        self.embedding.weight.data.normal_(0, 0.3)
+        self.bidirectional_decoder = bidirectional_decoder
         decoder_dim = 512 if num_speakers > 1 else 256
         encoder_dim = 512 if num_speakers > 1 else 256
         proj_speaker_dim = 80 if num_speakers > 1 else 0
+        # embedding layer
+        self.embedding = nn.Embedding(num_chars, 256)
+        self.embedding.weight.data.normal_(0, 0.3)
         # boilerplate model
         self.encoder = Encoder(encoder_dim)
-        self.decoder = Decoder(decoder_dim, mel_dim, r, memory_size, attn_win,
+        self.decoder = Decoder(decoder_dim, decoder_output_dim, r, memory_size, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
                                forward_attn, trans_agent, forward_attn_mask,
                                location_attn, separate_stopnet,
                                proj_speaker_dim)
-        self.postnet = PostCBHG(mel_dim)
+        if self.bidirectional_decoder:
+            self.decoder_backward = copy.deepcopy(self.decoder)
+        self.postnet = PostCBHG(decoder_output_dim)
         self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2,
-                                     linear_dim)
+                                     postnet_output_dim)
         # speaker embedding layers
         if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, 256)
@@ -82,27 +88,48 @@ class Tacotron(nn.Module):
         return inputs

     def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
+        """
+        Shapes:
+            - characters: B x T_in
+            - text_lengths: B
+            - mel_specs: B x T_out x D
+            - speaker_ids: B x 1
+        """
+        self._init_states()
         B = characters.size(0)
         mask = sequence_mask(text_lengths).to(characters.device)
+        # B x T_in x embed_dim
         inputs = self.embedding(characters)
-        self._init_states()
+        # B x speaker_embed_dim
         self.compute_speaker_embedding(speaker_ids)
         if self.num_speakers > 1:
+            # B x T_in x embed_dim + speaker_embed_dim
            inputs = self._concat_speaker_embedding(inputs,
                                                     self.speaker_embeddings)
+        # B x T_in x encoder_dim
         encoder_outputs = self.encoder(inputs)
         if self.gst:
+            # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
         if self.num_speakers > 1:
             encoder_outputs = self._concat_speaker_embedding(
                 encoder_outputs, self.speaker_embeddings)
-        mel_outputs, alignments, stop_tokens = self.decoder(
+        # decoder_outputs: B x decoder_output_dim x T_out
+        # alignments: B x T_out x T_in
+        # stop_tokens: B x T_out
+        decoder_outputs, alignments, stop_tokens = self.decoder(
             encoder_outputs, mel_specs, mask,
             self.speaker_embeddings_projected)
-        mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
-        linear_outputs = self.postnet(mel_outputs)
-        linear_outputs = self.last_linear(linear_outputs)
-        return mel_outputs, linear_outputs, alignments, stop_tokens
+        # B x T_out x postnet_dim
+        postnet_outputs = self.postnet(decoder_outputs)
+        # B x T_out x postnet_output_dim
+        postnet_outputs = self.last_linear(postnet_outputs)
+        # B x T_out x decoder_output_dim
+        decoder_outputs = decoder_outputs.transpose(1, 2)
+        if self.bidirectional_decoder:
+            decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask)
+            return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
+        return decoder_outputs, postnet_outputs, alignments, stop_tokens

     def inference(self, characters, speaker_ids=None, style_mel=None):
         B = characters.size(0)
@@ -118,12 +145,19 @@ class Tacotron(nn.Module):
         if self.num_speakers > 1:
             encoder_outputs = self._concat_speaker_embedding(
                 encoder_outputs, self.speaker_embeddings)
-        mel_outputs, alignments, stop_tokens = self.decoder.inference(
+        decoder_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs, self.speaker_embeddings_projected)
-        mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
-        linear_outputs = self.postnet(mel_outputs)
-        linear_outputs = self.last_linear(linear_outputs)
-        return mel_outputs, linear_outputs, alignments, stop_tokens
+        postnet_outputs = self.postnet(decoder_outputs)
+        postnet_outputs = self.last_linear(postnet_outputs)
+        decoder_outputs = decoder_outputs.transpose(1, 2)
+        return decoder_outputs, postnet_outputs, alignments, stop_tokens
+
+    def _backward_inference(self, mel_specs, encoder_outputs, mask):
+        decoder_outputs_b, alignments_b, _ = self.decoder_backward(
+            encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
+            self.speaker_embeddings_projected)
+        decoder_outputs_b = decoder_outputs_b.transpose(1, 2)
+        return decoder_outputs_b, alignments_b

     def _compute_speaker_embedding(self, speaker_ids):
         speaker_embeddings = self.speaker_embedding(speaker_ids)
diff --git a/models/tacotron2.py b/models/tacotron2.py
index a91d6e2e..9f67335d 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -1,3 +1,5 @@
+import copy
+import torch
 from math import sqrt
 from torch import nn
 from TTS.layers.tacotron2 import Encoder, Decoder, Postnet
@@ -10,6 +12,8 @@ class Tacotron2(nn.Module):
                  num_chars,
                  num_speakers,
                  r,
+                 postnet_output_dim=80,
+                 decoder_output_dim=80,
                  attn_win=False,
                  attn_norm="softmax",
                  prenet_type="original",
@@ -18,10 +22,16 @@ class Tacotron2(nn.Module):
                  trans_agent=False,
                  forward_attn_mask=False,
                  location_attn=True,
-                 separate_stopnet=True):
+                 separate_stopnet=True,
+                 bidirectional_decoder=False):
         super(Tacotron2, self).__init__()
-        self.n_mel_channels = 80
+        self.decoder_output_dim = decoder_output_dim
         self.n_frames_per_step = r
+        self.bidirectional_decoder = bidirectional_decoder
+        decoder_dim = 512 + 256 if num_speakers > 1 else 512
+        encoder_dim = 512 + 256 if num_speakers > 1 else 512
+        proj_speaker_dim = 80 if num_speakers > 1 else 0
+        # embedding layer
         self.embedding = nn.Embedding(num_chars, 512)
         std = sqrt(2.0 / (num_chars + 512))
         val = sqrt(3.0) * std  # uniform bounds for std
@@ -29,12 +39,18 @@ class Tacotron2(nn.Module):
         if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, 512)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
-        self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win,
+        self.encoder = Encoder(encoder_dim)
+        self.decoder = Decoder(decoder_dim, self.decoder_output_dim, r, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
                                forward_attn, trans_agent, forward_attn_mask,
-                               location_attn, separate_stopnet)
-        self.postnet = Postnet(self.n_mel_channels)
+                               location_attn, separate_stopnet, proj_speaker_dim)
+        if self.bidirectional_decoder:
+            self.decoder_backward = copy.deepcopy(self.decoder)
+        self.postnet = Postnet(self.decoder_output_dim)
+
+    def _init_states(self):
+        self.speaker_embeddings = None
+        self.speaker_embeddings_projected = None

     @staticmethod
     def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):
@@ -43,19 +59,23 @@ class Tacotron2(nn.Module):
         return mel_outputs, mel_outputs_postnet, alignments

     def forward(self, text, text_lengths, mel_specs=None, speaker_ids=None):
+        self._init_states()
         # compute mask for padding
         mask = sequence_mask(text_lengths).to(text.device)
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
         encoder_outputs = self._add_speaker_embedding(encoder_outputs,
                                                       speaker_ids)
-        mel_outputs, stop_tokens, alignments = self.decoder(
+        decoder_outputs, stop_tokens, alignments = self.decoder(
             encoder_outputs, mel_specs, mask)
-        mel_outputs_postnet = self.postnet(mel_outputs)
-        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
-        mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
-            mel_outputs, mel_outputs_postnet, alignments)
-        return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
+        postnet_outputs = self.postnet(decoder_outputs)
+        postnet_outputs = decoder_outputs + postnet_outputs
+        decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
+            decoder_outputs, postnet_outputs, alignments)
+        if self.bidirectional_decoder:
+            decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask)
+            return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
+        return decoder_outputs, postnet_outputs, alignments, stop_tokens

     def inference(self, text, speaker_ids=None):
         embedded_inputs = self.embedding(text).transpose(1, 2)
@@ -86,6 +106,13 @@ class Tacotron2(nn.Module):
             mel_outputs, mel_outputs_postnet, alignments)
         return mel_outputs, mel_outputs_postnet, alignments, stop_tokens

+    def _backward_inference(self, mel_specs, encoder_outputs, mask):
+        decoder_outputs_b, alignments_b, _ = self.decoder_backward(
+            encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
+            self.speaker_embeddings_projected)
+        decoder_outputs_b = decoder_outputs_b.transpose(1, 2)
+        return decoder_outputs_b, alignments_b
+
     def _add_speaker_embedding(self, encoder_outputs, speaker_ids):
         if hasattr(self, "speaker_embedding") and speaker_ids is None:
             raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
diff --git a/train.py b/train.py
index eafd2d0e..a721306d 100644
--- a/train.py
+++ b/train.py
@@ -88,6 +88,9 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         'avg_loader_time': 0,
         'avg_alignment_score': 0
     }
+    if c.bidirectional_decoder:
+        train_values['avg_decoder_b_loss'] = 0  # decoder backward loss
+        train_values['avg_decoder_c_loss'] = 0  # decoder consistency loss
     keep_avg = KeepAverage()
     keep_avg.add_values(train_values)
     print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
@@ -150,8 +153,12 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             speaker_ids = speaker_ids.cuda(non_blocking=True)

         # forward pass model
-        decoder_output, postnet_output, alignments, stop_tokens = model(
-            text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
+        if c.bidirectional_decoder:
+            decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
+                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
+        else:
+            decoder_output, postnet_output, alignments, stop_tokens = model(
+                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)

         # loss computation
         stop_loss = criterion_st(stop_tokens,
@@ -174,6 +181,16 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if not c.separate_stopnet and c.stopnet:
             loss += stop_loss

+        # backward decoder
+        if c.bidirectional_decoder:
+            if c.loss_masking:
+                decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input, mel_lengths)
+            else:
+                decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input)
+            decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output)
+            loss += decoder_backward_loss + decoder_c_loss
+            keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()})
+
         loss.backward()
         optimizer, current_lr = adam_weight_decay(optimizer)
         grad_norm, _ = check_update(model, c.grad_clip)
@@ -445,7 +462,6 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                 "ground_truth": plot_spectrogram(gt_spec, ap),
                 "alignment": plot_alignment(align_img)
             }
-            tb_logger.tb_eval_figures(global_step, eval_figures)

             # Sample audio
             if c.model in ["Tacotron", "TacotronGST"]:
@@ -461,7 +477,13 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                 "loss_decoder": keep_avg['avg_decoder_loss'],
                 "stop_loss": keep_avg['avg_stop_loss']
             }
+
+            if c.bidirectional_decoder:
+                epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_b_loss']
+                eval_figures['alignment_backward'] = alignments_backward[idx].data.cpu().numpy()
             tb_logger.tb_eval_stats(global_step, epoch_stats)
+            tb_logger.tb_eval_figures(global_step, eval_figures)
+
     if args.rank == 0 and epoch > c.test_delay_epochs:
         # test sentences
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 50d611b8..bc292edd 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -283,8 +283,8 @@ def setup_model(num_chars, num_speakers, c):
         model = MyModel(num_chars=num_chars,
                         num_speakers=num_speakers,
                         r=c.r,
-                        linear_dim=1025,
-                        mel_dim=80,
+                        postnet_output_dim=c.audio['num_freq'],
+                        decoder_output_dim=c.audio['num_mels'],
                         gst=c.use_gst,
                         memory_size=c.memory_size,
                         attn_win=c.windowing,
@@ -295,11 +295,14 @@ def setup_model(num_chars, num_speakers, c):
                         trans_agent=c.transition_agent,
                         forward_attn_mask=c.forward_attn_mask,
                         location_attn=c.location_attn,
-                        separate_stopnet=c.separate_stopnet)
+                        separate_stopnet=c.separate_stopnet,
+                        bidirectional_decoder=c.bidirectional_decoder)
    elif c.model.lower() == "tacotron2":
         model = MyModel(num_chars=num_chars,
                         num_speakers=num_speakers,
                         r=c.r,
+                        postnet_output_dim=c.audio['num_mels'],
+                        decoder_output_dim=c.audio['num_mels'],
                         attn_win=c.windowing,
                         attn_norm=c.attention_norm,
                         prenet_type=c.prenet_type,
@@ -308,7 +311,8 @@ def setup_model(num_chars, num_speakers, c):
                         trans_agent=c.transition_agent,
                         forward_attn_mask=c.forward_attn_mask,
                         location_attn=c.location_attn,
-                        separate_stopnet=c.separate_stopnet)
+                        separate_stopnet=c.separate_stopnet,
+                        bidirectional_decoder=c.bidirectional_decoder)
     return model
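
A quick sanity check on the `_parse_outputs` reshape used in both decoders: each decoder step emits `r` frames packed into a single `r * memory_dim` vector, so the stacked tensor must be split frame-major with `view(B, -1, memory_dim)` before the feature axis is transposed forward. A minimal sketch, not part of the patch, with made-up toy sizes:

    import torch

    # Toy sizes for illustration only.
    B, T_steps, r, memory_dim = 2, 3, 2, 4

    # One (B, r * memory_dim) tensor per decoder step, as in Decoder.forward.
    flat = torch.arange(B * T_steps * r * memory_dim, dtype=torch.float32)
    outputs = flat.view(T_steps, B, r * memory_dim).transpose(0, 1).contiguous()

    # Frame-major split, then transpose -> B x memory_dim x (T_steps * r).
    frames = outputs.view(B, -1, memory_dim).transpose(1, 2)

    # Viewing directly to (B, memory_dim, -1) gives the same shape but
    # interleaves values from different frames along the time axis.
    scrambled = outputs.view(B, memory_dim, -1)
    assert frames.shape == scrambled.shape
    assert not torch.equal(frames, scrambled)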
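
On the new `speaker_embedding_dim` argument of the tacotron2 `Decoder`: the speaker vector is concatenated onto each memory frame right before the prenet, which is why the prenet input grows to `memory_dim + speaker_embedding_dim`. A hedged sketch of just that step, with placeholder dimensions:

    import torch

    # Placeholder sizes; the real values come from the model config.
    B, memory_dim, speaker_embedding_dim = 2, 80, 64
    memory = torch.randn(B, memory_dim)  # one decoder input frame
    speaker_embeddings = torch.randn(B, speaker_embedding_dim)

    # Mirrors the decode loop: concat along the feature axis, then prenet.
    prenet_input = torch.cat([memory, speaker_embeddings], dim=-1)
    assert prenet_input.shape == (B, memory_dim + speaker_embedding_dim)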
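
And a sketch of the loss wiring train.py adds for the bidirectional decoder (https://arxiv.org/abs/1907.09006). Here `F.l1_loss` stands in for the configured `criterion`, and the tensors are random placeholders for real model outputs in the patch's B x T_out x D layout:

    import torch
    import torch.nn.functional as F

    B, T_out, D = 2, 50, 80
    mel_input = torch.randn(B, T_out, D)                # ground-truth frames
    decoder_output = torch.randn(B, T_out, D)           # forward decoder
    decoder_backward_output = torch.randn(B, T_out, D)  # backward decoder,
                                                        # in reversed time order

    # The backward decoder was teacher-forced on torch.flip(mel_specs, dims=(1,)),
    # so its output is flipped back to forward time before comparing.
    backward_fwd_time = torch.flip(decoder_backward_output, dims=(1,))
    decoder_backward_loss = F.l1_loss(backward_fwd_time, mel_input)

    # Consistency term: the two decoders should agree frame by frame.
    decoder_c_loss = F.l1_loss(backward_fwd_time, decoder_output)

    # Both terms are added on top of the usual forward decoder/postnet losses.
    forward_loss = F.l1_loss(decoder_output, mel_input)  # stand-in
    loss = forward_loss + decoder_backward_loss + decoder_c_loss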