From 9559e3fc308d70c029b1d703d1eeb06295d4bee3 Mon Sep 17 00:00:00 2001 From: erogol Date: Thu, 4 Jun 2020 14:28:16 +0200 Subject: [PATCH] inherit TacotronAbstact with both tacotron and tacotron2 --- models/tacotron.py | 135 ++++++++++++++++++------------------------- models/tacotron2.py | 136 ++++++++++++++++++++++++++++---------------- 2 files changed, 142 insertions(+), 129 deletions(-) diff --git a/models/tacotron.py b/models/tacotron.py index fba82b1b..c526374a 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -1,23 +1,21 @@ # coding: utf-8 import torch -import copy from torch import nn -from TTS.layers.tacotron import Encoder, Decoder, PostCBHG -from TTS.utils.generic_utils import sequence_mask + from TTS.layers.gst_layers import GST +from TTS.layers.tacotron import Decoder, Encoder, PostCBHG +from TTS.models.tacotron_abstract import TacotronAbstract -class Tacotron(nn.Module): +class Tacotron(TacotronAbstract): def __init__(self, num_chars, num_speakers, r=5, postnet_output_dim=1025, decoder_output_dim=80, - memory_size=5, attn_type='original', attn_win=False, - gst=False, attn_norm="sigmoid", prenet_type="original", prenet_dropout=True, @@ -27,38 +25,41 @@ class Tacotron(nn.Module): location_attn=True, attn_K=5, separate_stopnet=True, - bidirectional_decoder=False): - super(Tacotron, self).__init__() - self.r = r - self.decoder_output_dim = decoder_output_dim - self.postnet_output_dim = postnet_output_dim - self.gst = gst - self.num_speakers = num_speakers - self.bidirectional_decoder = bidirectional_decoder - decoder_dim = 512 if num_speakers > 1 else 256 - encoder_dim = 512 if num_speakers > 1 else 256 + bidirectional_decoder=False, + double_decoder_consistency=False, + ddc_r=None, + gst=False, + memory_size=5): + super(Tacotron, + self).__init__(num_chars, num_speakers, r, postnet_output_dim, + decoder_output_dim, attn_type, attn_win, + attn_norm, prenet_type, prenet_dropout, + forward_attn, trans_agent, forward_attn_mask, + location_attn, attn_K, separate_stopnet, + bidirectional_decoder, double_decoder_consistency, + ddc_r, gst) + decoder_in_features = 512 if num_speakers > 1 else 256 + encoder_in_features = 512 if num_speakers > 1 else 256 + speaker_embedding_dim = 256 proj_speaker_dim = 80 if num_speakers > 1 else 0 - # embedding layer + # base model layers self.embedding = nn.Embedding(num_chars, 256, padding_idx=0) self.embedding.weight.data.normal_(0, 0.3) - # boilerplate model - self.encoder = Encoder(encoder_dim) - self.decoder = Decoder(decoder_dim, decoder_output_dim, r, memory_size, attn_type, attn_win, + self.encoder = Encoder(encoder_in_features) + self.decoder = Decoder(decoder_in_features, decoder_output_dim, r, memory_size, attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, location_attn, attn_K, separate_stopnet, proj_speaker_dim) - if self.bidirectional_decoder: - self.decoder_backward = copy.deepcopy(self.decoder) self.postnet = PostCBHG(decoder_output_dim) self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim) # speaker embedding layers if num_speakers > 1: - self.speaker_embedding = nn.Embedding(num_speakers, 256) + self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim) self.speaker_embedding.weight.data.normal_(0, 0.3) self.speaker_project_mel = nn.Sequential( - nn.Linear(256, proj_speaker_dim), nn.Tanh()) + nn.Linear(speaker_embedding_dim, proj_speaker_dim), nn.Tanh()) self.speaker_embeddings = None self.speaker_embeddings_projected = None # global style token layers @@ -68,28 +69,15 @@ class Tacotron(nn.Module): num_heads=4, num_style_tokens=10, embedding_dim=gst_embedding_dim) + # backward pass decoder + if self.bidirectional_decoder: + self._init_backward_decoder() + # setup DDC + if self.double_decoder_consistency: + self._init_coarse_decoder() - def _init_states(self): - self.speaker_embeddings = None - self.speaker_embeddings_projected = None - def compute_speaker_embedding(self, speaker_ids): - if hasattr(self, "speaker_embedding") and speaker_ids is None: - raise RuntimeError( - " [!] Model has speaker embedding layer but speaker_id is not provided" - ) - if hasattr(self, "speaker_embedding") and speaker_ids is not None: - self.speaker_embeddings = self._compute_speaker_embedding( - speaker_ids) - self.speaker_embeddings_projected = self.speaker_project_mel( - self.speaker_embeddings).squeeze(1) - - def compute_gst(self, inputs, mel_specs): - gst_outputs = self.gst_layer(mel_specs) - inputs = self._add_speaker_embedding(inputs, gst_outputs) - return inputs - - def forward(self, characters, text_lengths, mel_specs, speaker_ids=None): + def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None): """ Shapes: - characters: B x T_in @@ -98,45 +86,59 @@ class Tacotron(nn.Module): - speaker_ids: B x 1 """ self._init_states() - mask = sequence_mask(text_lengths).to(characters.device) + input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) # B x T_in x embed_dim inputs = self.embedding(characters) # B x speaker_embed_dim - self.compute_speaker_embedding(speaker_ids) + if speaker_ids is not None: + self.compute_speaker_embedding(speaker_ids) if self.num_speakers > 1: # B x T_in x embed_dim + speaker_embed_dim inputs = self._concat_speaker_embedding(inputs, self.speaker_embeddings) - # B x T_in x encoder_dim + # B x T_in x encoder_in_features encoder_outputs = self.encoder(inputs) + # sequence masking + encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) + # global style token if self.gst: # B x gst_dim encoder_outputs = self.compute_gst(encoder_outputs, mel_specs) if self.num_speakers > 1: encoder_outputs = self._concat_speaker_embedding( encoder_outputs, self.speaker_embeddings) - # decoder_outputs: B x decoder_dim x T_out - # alignments: B x T_in x encoder_dim + # decoder_outputs: B x decoder_in_features x T_out + # alignments: B x T_in x encoder_in_features # stop_tokens: B x T_in decoder_outputs, alignments, stop_tokens = self.decoder( - encoder_outputs, mel_specs, mask, + encoder_outputs, mel_specs, input_mask, self.speaker_embeddings_projected) - # B x T_out x decoder_dim + # sequence masking + if output_mask is not None: + decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) + # B x T_out x decoder_in_features postnet_outputs = self.postnet(decoder_outputs) + # sequence masking + if output_mask is not None: + postnet_outputs = postnet_outputs * output_mask.unsqueeze(2).expand_as(postnet_outputs) # B x T_out x posnet_dim postnet_outputs = self.last_linear(postnet_outputs) - # B x T_out x decoder_dim + # B x T_out x decoder_in_features decoder_outputs = decoder_outputs.transpose(1, 2).contiguous() if self.bidirectional_decoder: - decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask) + decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask) return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward + if self.double_decoder_consistency: + decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(mel_specs, encoder_outputs, alignments, input_mask) + return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward return decoder_outputs, postnet_outputs, alignments, stop_tokens @torch.no_grad() def inference(self, characters, speaker_ids=None, style_mel=None): inputs = self.embedding(characters) self._init_states() - self.compute_speaker_embedding(speaker_ids) + if speaker_ids is not None: + self.compute_speaker_embedding(speaker_ids) if self.num_speakers > 1: inputs = self._concat_speaker_embedding(inputs, self.speaker_embeddings) @@ -152,28 +154,3 @@ class Tacotron(nn.Module): postnet_outputs = self.last_linear(postnet_outputs) decoder_outputs = decoder_outputs.transpose(1, 2) return decoder_outputs, postnet_outputs, alignments, stop_tokens - - def _backward_inference(self, mel_specs, encoder_outputs, mask): - decoder_outputs_b, alignments_b, _ = self.decoder_backward( - encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask, - self.speaker_embeddings_projected) - decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous() - return decoder_outputs_b, alignments_b - - def _compute_speaker_embedding(self, speaker_ids): - speaker_embeddings = self.speaker_embedding(speaker_ids) - return speaker_embeddings.unsqueeze_(1) - - @staticmethod - def _add_speaker_embedding(outputs, speaker_embeddings): - speaker_embeddings_ = speaker_embeddings.expand( - outputs.size(0), outputs.size(1), -1) - outputs = outputs + speaker_embeddings_ - return outputs - - @staticmethod - def _concat_speaker_embedding(outputs, speaker_embeddings): - speaker_embeddings_ = speaker_embeddings.expand( - outputs.size(0), outputs.size(1), -1) - outputs = torch.cat([outputs, speaker_embeddings_], dim=-1) - return outputs diff --git a/models/tacotron2.py b/models/tacotron2.py index 3e7adfca..bbce4be9 100644 --- a/models/tacotron2.py +++ b/models/tacotron2.py @@ -1,13 +1,15 @@ -import copy -import torch from math import sqrt + +import torch from torch import nn -from TTS.layers.tacotron2 import Encoder, Decoder, Postnet -from TTS.utils.generic_utils import sequence_mask + +from TTS.layers.gst_layers import GST +from TTS.layers.tacotron2 import Decoder, Encoder, Postnet +from TTS.models.tacotron_abstract import TacotronAbstract # TODO: match function arguments with tacotron -class Tacotron2(nn.Module): +class Tacotron2(TacotronAbstract): def __init__(self, num_chars, num_speakers, @@ -25,16 +27,22 @@ class Tacotron2(nn.Module): location_attn=True, attn_K=5, separate_stopnet=True, - bidirectional_decoder=False): - super(Tacotron2, self).__init__() - self.postnet_output_dim = postnet_output_dim - self.decoder_output_dim = decoder_output_dim - self.r = r - self.bidirectional_decoder = bidirectional_decoder - decoder_dim = 512 if num_speakers > 1 else 512 - encoder_dim = 512 if num_speakers > 1 else 512 + bidirectional_decoder=False, + double_decoder_consistency=False, + ddc_r=None, + gst=False): + super(Tacotron2, + self).__init__(num_chars, num_speakers, r, postnet_output_dim, + decoder_output_dim, attn_type, attn_win, + attn_norm, prenet_type, prenet_dropout, + forward_attn, trans_agent, forward_attn_mask, + location_attn, attn_K, separate_stopnet, + bidirectional_decoder, double_decoder_consistency, + ddc_r, gst) + decoder_in_features = 512 if num_speakers > 1 else 512 + encoder_in_features = 512 if num_speakers > 1 else 512 proj_speaker_dim = 80 if num_speakers > 1 else 0 - # embedding layer + # base layers self.embedding = nn.Embedding(num_chars, 512, padding_idx=0) std = sqrt(2.0 / (num_chars + 512)) val = sqrt(3.0) * std # uniform bounds for std @@ -42,20 +50,25 @@ class Tacotron2(nn.Module): if num_speakers > 1: self.speaker_embedding = nn.Embedding(num_speakers, 512) self.speaker_embedding.weight.data.normal_(0, 0.3) - self.speaker_embeddings = None - self.speaker_embeddings_projected = None - self.encoder = Encoder(encoder_dim) - self.decoder = Decoder(decoder_dim, self.decoder_output_dim, r, attn_type, attn_win, + self.encoder = Encoder(encoder_in_features) + self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, location_attn, attn_K, separate_stopnet, proj_speaker_dim) - if self.bidirectional_decoder: - self.decoder_backward = copy.deepcopy(self.decoder) self.postnet = Postnet(self.postnet_output_dim) - - def _init_states(self): - self.speaker_embeddings = None - self.speaker_embeddings_projected = None + # global style token layers + if self.gst: + gst_embedding_dim = encoder_in_features + self.gst_layer = GST(num_mel=80, + num_heads=4, + num_style_tokens=10, + embedding_dim=gst_embedding_dim) + # backward pass decoder + if self.bidirectional_decoder: + self._init_backward_decoder() + # setup DDC + if self.double_decoder_consistency: + self._init_coarse_decoder() @staticmethod def shape_outputs(mel_outputs, mel_outputs_postnet, alignments): @@ -63,31 +76,60 @@ class Tacotron2(nn.Module): mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2) return mel_outputs, mel_outputs_postnet, alignments - def forward(self, text, text_lengths, mel_specs=None, speaker_ids=None): + def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None): self._init_states() # compute mask for padding - mask = sequence_mask(text_lengths).to(text.device) + # B x T_in_max (boolean) + input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) + # B x D_embed x T_in_max embedded_inputs = self.embedding(text).transpose(1, 2) + # B x T_in_max x D_en encoder_outputs = self.encoder(embedded_inputs, text_lengths) - encoder_outputs = self._add_speaker_embedding(encoder_outputs, - speaker_ids) + # adding speaker embeddding to encoder output + # TODO: multi-speaker + # B x speaker_embed_dim + if speaker_ids is not None: + self.compute_speaker_embedding(speaker_ids) + if self.num_speakers > 1: + # B x T_in x embed_dim + speaker_embed_dim + encoder_outputs = self._add_speaker_embedding(encoder_outputs, + self.speaker_embeddings) + encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) + # global style token + if self.gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, mel_specs) + # B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r decoder_outputs, alignments, stop_tokens = self.decoder( - encoder_outputs, mel_specs, mask) + encoder_outputs, mel_specs, input_mask) + # sequence masking + if mel_lengths is not None: + decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) + # B x mel_dim x T_out postnet_outputs = self.postnet(decoder_outputs) - postnet_outputs = decoder_outputs + postnet_outputs + # sequence masking + if output_mask is not None: + postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs) + # B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in decoder_outputs, postnet_outputs, alignments = self.shape_outputs( decoder_outputs, postnet_outputs, alignments) if self.bidirectional_decoder: - decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask) + decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask) return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward + if self.double_decoder_consistency: + decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(mel_specs, encoder_outputs, alignments, input_mask) + return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward return decoder_outputs, postnet_outputs, alignments, stop_tokens @torch.no_grad() def inference(self, text, speaker_ids=None): embedded_inputs = self.embedding(text).transpose(1, 2) encoder_outputs = self.encoder.inference(embedded_inputs) - encoder_outputs = self._add_speaker_embedding(encoder_outputs, - speaker_ids) + if speaker_ids is not None: + self.compute_speaker_embedding(speaker_ids) + if self.num_speakers > 1: + encoder_outputs = self._add_speaker_embedding(encoder_outputs, + self.speaker_embeddings) mel_outputs, alignments, stop_tokens = self.decoder.inference( encoder_outputs) mel_outputs_postnet = self.postnet(mel_outputs) @@ -112,22 +154,16 @@ class Tacotron2(nn.Module): mel_outputs, mel_outputs_postnet, alignments) return mel_outputs, mel_outputs_postnet, alignments, stop_tokens - def _backward_inference(self, mel_specs, encoder_outputs, mask): - decoder_outputs_b, alignments_b, _ = self.decoder_backward( - encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask, - self.speaker_embeddings_projected) - decoder_outputs_b = decoder_outputs_b.transpose(1, 2) - return decoder_outputs_b, alignments_b - def _add_speaker_embedding(self, encoder_outputs, speaker_ids): - if hasattr(self, "speaker_embedding") and speaker_ids is None: - raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided") - if hasattr(self, "speaker_embedding") and speaker_ids is not None: - speaker_embeddings = self.speaker_embedding(speaker_ids) + def _speaker_embedding_pass(self, encoder_outputs, speaker_ids): + # TODO: multi-speaker + # if hasattr(self, "speaker_embedding") and speaker_ids is None: + # raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided") + # if hasattr(self, "speaker_embedding") and speaker_ids is not None: - speaker_embeddings.unsqueeze_(1) - speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0), - encoder_outputs.size(1), - -1) - encoder_outputs = encoder_outputs + speaker_embeddings - return encoder_outputs + # speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0), + # encoder_outputs.size(1), + # -1) + # encoder_outputs = encoder_outputs + speaker_embeddings + # return encoder_outputs + pass