From aa06a0609d0a8d5e5b2780ab0b6f059ae7b0398f Mon Sep 17 00:00:00 2001
From: erogol
Date: Thu, 4 Jun 2020 14:26:30 +0200
Subject: [PATCH] tacotron abstract class

---
 models/tacotron_abstract.py | 163 ++++++++++++++++++++++++++++++++++++
 1 file changed, 163 insertions(+)
 create mode 100644 models/tacotron_abstract.py

diff --git a/models/tacotron_abstract.py b/models/tacotron_abstract.py
new file mode 100644
index 00000000..db6cbdee
--- /dev/null
+++ b/models/tacotron_abstract.py
@@ -0,0 +1,163 @@
+import copy
+from abc import ABC, abstractmethod
+
+import torch
+from torch import nn
+
+from TTS.utils.generic_utils import sequence_mask
+
+
+class TacotronAbstract(ABC, nn.Module):
+    def __init__(self, num_chars,
+                 num_speakers,
+                 r,
+                 postnet_output_dim=80,
+                 decoder_output_dim=80,
+                 attn_type='original',
+                 attn_win=False,
+                 attn_norm="softmax",
+                 prenet_type="original",
+                 prenet_dropout=True,
+                 forward_attn=False,
+                 trans_agent=False,
+                 forward_attn_mask=False,
+                 location_attn=True,
+                 attn_K=5,
+                 separate_stopnet=True,
+                 bidirectional_decoder=False,
+                 double_decoder_consistency=False,
+                 ddc_r=None,
+                 gst=False):
+        """Abstract Tacotron class."""
+        super().__init__()
+        self.r = r
+        self.decoder_output_dim = decoder_output_dim
+        self.postnet_output_dim = postnet_output_dim
+        self.gst = gst
+        self.num_speakers = num_speakers
+        self.bidirectional_decoder = bidirectional_decoder
+        self.double_decoder_consistency = double_decoder_consistency
+        self.ddc_r = ddc_r
+
+        # layers to be defined by the child class
+        self.embedding = None
+        self.encoder = None
+        self.decoder = None
+        self.postnet = None
+
+        # global style token layer, defined by the child class
+        if self.gst:
+            self.gst_layer = None
+
+    #############################
+    # INIT FUNCTIONS
+    #############################
+
+    def _init_states(self):
+        self.speaker_embeddings = None
+        self.speaker_embeddings_projected = None
+
+    def _init_backward_decoder(self):
+        """Create a second decoder for the backward (bidirectional) pass."""
+        self.decoder_backward = copy.deepcopy(self.decoder)
+
+    def _init_coarse_decoder(self):
+        """Create a coarse decoder with a larger reduction factor for DDC."""
+        self.coarse_decoder = copy.deepcopy(self.decoder)
+        self.coarse_decoder.r_init = self.ddc_r
+        self.coarse_decoder.set_r(self.ddc_r)
+
+    #############################
+    # CORE FUNCTIONS
+    #############################
+
+    @abstractmethod
+    def forward(self):
+        pass
+
+    @abstractmethod
+    def inference(self):
+        pass
+
+    #############################
+    # COMMON COMPUTE FUNCTIONS
+    #############################
+
+    def compute_masks(self, text_lengths, mel_lengths):
+        """Compute masks against sequence paddings."""
+        # B x T_in_max (boolean)
+        device = text_lengths.device
+        input_mask = sequence_mask(text_lengths).to(device)
+        output_mask = None
+        if mel_lengths is not None:
+            max_len = mel_lengths.max()
+            r = self.decoder.r
+            # pad the mask length up to a multiple of the reduction factor r
+            max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
+            output_mask = sequence_mask(mel_lengths, max_len=max_len).to(device)
+        return input_mask, output_mask
+
+    def _backward_pass(self, mel_specs, encoder_outputs, mask):
+        """Run the backward decoder on time-reversed mel frames."""
+        decoder_outputs_b, alignments_b, _ = self.decoder_backward(
+            encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
+            self.speaker_embeddings_projected)
+        decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
+        return decoder_outputs_b, alignments_b
+
+    def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments,
+                             input_mask):
+        """Double Decoder Consistency pass with the coarse decoder."""
+        T = mel_specs.shape[1]
+        # pad the target frames to a multiple of the coarse reduction factor
+        if T % self.coarse_decoder.r > 0:
+            padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r)
+            mel_specs = torch.nn.functional.pad(mel_specs,
+                                                (0, 0, 0, padding_size, 0, 0))
+        decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
+            encoder_outputs.detach(), mel_specs, input_mask)
+        # upsample the coarse alignments to the fine decoder's time resolution
+        alignments_backward = torch.nn.functional.interpolate(
+            alignments_backward.transpose(1, 2),
+            size=alignments.shape[1],
+            mode='nearest').transpose(1, 2)
+        decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
+        decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
+        return decoder_outputs_backward, alignments_backward
+
+    #############################
+    # EMBEDDING FUNCTIONS
+    #############################
+
+    def compute_speaker_embedding(self, speaker_ids):
+        """Compute speaker embedding vectors."""
+        if hasattr(self, "speaker_embedding") and speaker_ids is None:
+            raise RuntimeError(
+                " [!] Model has a speaker embedding layer but speaker_ids are not provided"
+            )
+        if hasattr(self, "speaker_embedding") and speaker_ids is not None:
+            self.speaker_embeddings = self.speaker_embedding(speaker_ids).unsqueeze(1)
+        if hasattr(self, "speaker_project_mel") and speaker_ids is not None:
+            self.speaker_embeddings_projected = self.speaker_project_mel(
+                self.speaker_embeddings).squeeze(1)
+
+    def compute_gst(self, inputs, mel_specs):
+        """Compute global style token embeddings and add them to the inputs."""
+        gst_outputs = self.gst_layer(mel_specs)
+        inputs = self._add_speaker_embedding(inputs, gst_outputs)
+        return inputs
+
+    @staticmethod
+    def _add_speaker_embedding(outputs, speaker_embeddings):
+        # broadcast-add a single embedding vector over every time step
+        speaker_embeddings_ = speaker_embeddings.expand(
+            outputs.size(0), outputs.size(1), -1)
+        outputs = outputs + speaker_embeddings_
+        return outputs
+
+    @staticmethod
+    def _concat_speaker_embedding(outputs, speaker_embeddings):
+        speaker_embeddings_ = speaker_embeddings.expand(
+            outputs.size(0), outputs.size(1), -1)
+        outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
+        return outputs
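
A few usage notes follow; none of the code below is part of the patch. First, the contract for child classes: the base class leaves `embedding`, `encoder`, `decoder`, and `postnet` as `None` and declares `forward` and `inference` abstract, so a subclass assigns real modules and implements both methods. A minimal toy sketch, assuming the module is importable as `TTS.models.tacotron_abstract`; the `nn.Linear` stand-ins are placeholders, not the repo's Tacotron layers:

```python
import torch
from torch import nn

from TTS.models.tacotron_abstract import TacotronAbstract  # assumed import path


class ToyTacotron(TacotronAbstract):
    """Toy subclass; a real model plugs in Tacotron encoder/decoder/postnet."""

    def __init__(self, num_chars, num_speakers, r):
        super().__init__(num_chars, num_speakers, r)
        self.embedding = nn.Embedding(num_chars, 256)
        self.encoder = nn.Linear(256, 256)   # stand-in for a real encoder
        self.decoder = nn.Linear(256, 80)    # stand-in for a real decoder
        self.postnet = nn.Linear(80, 80)     # stand-in for a real postnet

    def forward(self, text):
        x = self.encoder(self.embedding(text))
        decoder_outputs = self.decoder(x)
        postnet_outputs = decoder_outputs + self.postnet(decoder_outputs)
        return decoder_outputs, postnet_outputs

    def inference(self, text):
        return self.forward(text)
```

With these stand-ins, `ToyTacotron(30, 1, 7)(torch.randint(0, 30, (2, 11)))` returns two `(2, 11, 80)` tensors; implementing both abstract methods is what makes the class instantiable.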
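
`compute_masks` pads the output mask length up to the next multiple of the decoder's reduction factor `r`, since the decoder emits `r` frames per step. A standalone check of that arithmetic, using a local stand-in that mirrors how the patch calls `sequence_mask` (an assumption, not the repo implementation):

```python
import torch


def sequence_mask(lengths, max_len=None):
    # local stand-in mirroring the call shape of
    # TTS.utils.generic_utils.sequence_mask
    max_len = max_len or int(lengths.max())
    ids = torch.arange(max_len, device=lengths.device)
    return ids < lengths.unsqueeze(1)


mel_lengths = torch.tensor([95, 100])
r = 7
max_len = int(mel_lengths.max())
if max_len % r > 0:
    max_len += r - (max_len % r)

print(max_len)  # 105, the next multiple of 7 after 100
print(sequence_mask(mel_lengths, max_len=max_len).shape)  # torch.Size([2, 105])
```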
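
In `_coarse_decoder_pass`, the coarse decoder runs with the larger reduction factor `ddc_r`, so its alignment matrix has fewer decoder steps than the main decoder's; nearest-neighbour interpolation stretches it back so the two attention maps can be compared by the DDC loss. A shape-level illustration with arbitrary example sizes:

```python
import torch

B, T_dec_coarse, T_dec_fine, T_enc = 2, 25, 50, 37
alignments_coarse = torch.softmax(torch.randn(B, T_dec_coarse, T_enc), dim=-1)

# same transpose/interpolate/transpose pattern as _coarse_decoder_pass
upsampled = torch.nn.functional.interpolate(
    alignments_coarse.transpose(1, 2),  # B x T_enc x T_dec_coarse
    size=T_dec_fine,
    mode='nearest').transpose(1, 2)     # B x T_dec_fine x T_enc

print(upsampled.shape)  # torch.Size([2, 50, 37])
```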
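
Finally, `_add_speaker_embedding` and `_concat_speaker_embedding` broadcast one embedding vector across batch and time before combining, and `compute_gst` reuses the additive helper for style-token embeddings. A small demonstration of the two combination modes (shapes are arbitrary):

```python
import torch

B, T, D = 2, 5, 8
outputs = torch.zeros(B, T, D)
embedding = torch.randn(1, 1, D)  # one vector shared by the whole batch

# additive: feature size stays D (this is what compute_gst relies on)
added = outputs + embedding.expand(B, T, -1)

# concatenative: feature size grows to 2 * D
concat = torch.cat([outputs, embedding.expand(B, T, -1)], dim=-1)

print(added.shape)   # torch.Size([2, 5, 8])
print(concat.shape)  # torch.Size([2, 5, 16])
```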