diff --git a/models/tacotron_abstract.py b/models/tacotron_abstract.py
index db6cbdee..75a1a5cd 100644
--- a/models/tacotron_abstract.py
+++ b/models/tacotron_abstract.py
@@ -4,12 +4,12 @@
 from abc import ABC, abstractmethod
 import torch
 from torch import nn
 
-from TTS.layers.gst_layers import GST
 from TTS.utils.generic_utils import sequence_mask
 
 class TacotronAbstract(ABC, nn.Module):
-    def __init__(self, num_chars,
+    def __init__(self,
+                 num_chars,
                  num_speakers,
                  r,
                  postnet_output_dim=80,
@@ -31,6 +31,7 @@ class TacotronAbstract(ABC, nn.Module):
                  gst=False):
         """ Abstract Tacotron class """
         super().__init__()
+        self.num_chars = num_chars
         self.r = r
         self.decoder_output_dim = decoder_output_dim
         self.postnet_output_dim = postnet_output_dim
@@ -39,6 +40,17 @@ class TacotronAbstract(ABC, nn.Module):
         self.bidirectional_decoder = bidirectional_decoder
         self.double_decoder_consistency = double_decoder_consistency
         self.ddc_r = ddc_r
+        self.attn_type = attn_type
+        self.attn_win = attn_win
+        self.attn_norm = attn_norm
+        self.prenet_type = prenet_type
+        self.prenet_dropout = prenet_dropout
+        self.forward_attn = forward_attn
+        self.trans_agent = trans_agent
+        self.forward_attn_mask = forward_attn_mask
+        self.location_attn = location_attn
+        self.attn_K = attn_K
+        self.separate_stopnet = separate_stopnet
 
         # layers
         self.embedding = None
@@ -48,9 +60,16 @@ class TacotronAbstract(ABC, nn.Module):
 
         # global style token
         if self.gst:
-            gst_embedding_dim = None
             self.gst_layer = None
 
+        # model states
+        self.speaker_embeddings = None
+        self.speaker_embeddings_projected = None
+
+        # additional layers
+        self.decoder_backward = None
+        self.coarse_decoder = None
+
     #############################
     # INIT FUNCTIONS
     #############################
@@ -114,7 +133,7 @@ class TacotronAbstract(ABC, nn.Module):
             (0, 0, 0, padding_size, 0, 0))
         decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
             encoder_outputs.detach(), mel_specs, input_mask)
-        scale_factor = self.decoder.r_init / self.decoder.r
+        # scale_factor = self.decoder.r_init / self.decoder.r
         alignments_backward = torch.nn.functional.interpolate(
             alignments_backward.transpose(1, 2),
             size=alignments.shape[1],
@@ -141,6 +160,7 @@ class TacotronAbstract(ABC, nn.Module):
 
     def compute_gst(self, inputs, mel_specs):
         """ Compute global style token """
+        # pylint: disable=not-callable
         gst_outputs = self.gst_layer(mel_specs)
         inputs = self._add_speaker_embedding(inputs, gst_outputs)
         return inputs
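
Note on the `_coarse_decoder_pass` hunk: with `scale_factor` commented out, the backward (coarse) alignments are resized to the main decoder's alignment length by passing `size=alignments.shape[1]` to `torch.nn.functional.interpolate`, rather than deriving a ratio from the reduction factors. Below is a minimal standalone sketch of that resizing step; the tensor shapes and the 'nearest' mode are illustrative assumptions, not taken from this patch.

import torch

# Illustrative shapes only: batch, coarse-decoder steps, main-decoder steps, encoder steps.
B, T_coarse, T_dec, T_enc = 4, 50, 100, 120

alignments_backward = torch.rand(B, T_coarse, T_enc)  # coarse decoder attention
alignments = torch.rand(B, T_dec, T_enc)              # main decoder attention

# Put decoder time on the "length" axis expected by interpolate, resize it to the
# main decoder's length, then move it back. The target size is explicit, so no
# reduction-factor ratio (the old scale_factor) is needed.
alignments_backward = torch.nn.functional.interpolate(
    alignments_backward.transpose(1, 2),   # (B, T_enc, T_coarse)
    size=alignments.shape[1],              # resize to T_dec
    mode='nearest').transpose(1, 2)        # (B, T_dec, T_enc)

assert alignments_backward.shape == (B, T_dec, T_enc)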