diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index c82b821e..9f4b2b68 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -32,7 +32,7 @@ class GSTConfig(Coqpit): @dataclass -class CharactersConfig: +class CharactersConfig(Coqpit): """Defines character or phoneme set used by the model""" pad: str = None @@ -41,6 +41,7 @@ class CharactersConfig: characters: str = None punctuations: str = None phonemes: str = None + unique: bool = True # for backwards compatibility of models trained with char sets with duplicates def check_values( self, diff --git a/TTS/tts/configs/tacotron_config.py b/TTS/tts/configs/tacotron_config.py index 8b1ed20c..5c86f500 100644 --- a/TTS/tts/configs/tacotron_config.py +++ b/TTS/tts/configs/tacotron_config.py @@ -11,6 +11,7 @@ class TacotronConfig(BaseTTSConfig): """Defines parameters for Tacotron based models.""" model: str = "tacotron" + use_gst: bool = False gst: GSTConfig = None gst_style_input: str = None # model specific params @@ -61,10 +62,3 @@ class TacotronConfig(BaseTTSConfig): decoder_ssim_alpha: float = 0.25 postnet_ssim_alpha: float = 0.25 ga_alpha: float = 5.0 - - -@dataclass -class Tacotron2Config(TacotronConfig): - """Defines parameters for Tacotron2 based models.""" - - model: str = "tacotron2" diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index 1ffe9786..89d98e9f 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -41,6 +41,7 @@ class Tacotron(TacotronAbstract): encoder_in_features (int, optional): input channels for the encoder. Defaults to 512. decoder_in_features (int, optional): input channels for the decoder. Defaults to 512. speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None. + use_gst (bool, optional): enable/disable Global style token module. gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None. memory_size (int, optional): size of the history queue fed to the prenet. Model feeds the last ```memory_size``` output frames to the prenet. @@ -71,6 +72,7 @@ class Tacotron(TacotronAbstract): encoder_in_features=256, decoder_in_features=256, speaker_embedding_dim=None, + use_gst=False, gst=None, memory_size=5, ): @@ -98,6 +100,7 @@ class Tacotron(TacotronAbstract): encoder_in_features, decoder_in_features, speaker_embedding_dim, + use_gst, gst, ) @@ -142,7 +145,7 @@ class Tacotron(TacotronAbstract): self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference # global style token layers - if self.gst: + if self.gst and self.use_gst: self.gst_layer = GST( num_mel=decoder_output_dim, speaker_embedding_dim=speaker_embedding_dim, @@ -191,7 +194,7 @@ class Tacotron(TacotronAbstract): # sequence masking encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) # global style token - if self.gst: + if self.gst and self.use_gst: # B x gst_dim encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings) # speaker embedding @@ -247,7 +250,7 @@ class Tacotron(TacotronAbstract): def inference(self, characters, speaker_ids=None, style_mel=None, speaker_embeddings=None): inputs = self.embedding(characters) encoder_outputs = self.encoder(inputs) - if self.gst: + if self.gst and self.use_gst: # B x gst_dim encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings) if self.num_speakers > 1: diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index 1945a6f7..fded8f87 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -41,6 +41,7 @@ class Tacotron2(TacotronAbstract): encoder_in_features (int, optional): input channels for the encoder. Defaults to 512. decoder_in_features (int, optional): input channels for the decoder. Defaults to 512. speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None. + use_gst (bool, optional): enable/disable Global style token module. gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None. """ @@ -69,6 +70,7 @@ class Tacotron2(TacotronAbstract): encoder_in_features=512, decoder_in_features=512, speaker_embedding_dim=None, + use_gst=False, gst=None, ): super().__init__( @@ -95,6 +97,7 @@ class Tacotron2(TacotronAbstract): encoder_in_features, decoder_in_features, speaker_embedding_dim, + use_gst, gst, ) @@ -136,7 +139,7 @@ class Tacotron2(TacotronAbstract): self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference # global style token layers - if self.gst: + if self.gst and use_gst: self.gst_layer = GST( num_mel=decoder_output_dim, speaker_embedding_dim=speaker_embedding_dim, @@ -190,7 +193,7 @@ class Tacotron2(TacotronAbstract): embedded_inputs = self.embedding(text).transpose(1, 2) # B x T_in_max x D_en encoder_outputs = self.encoder(embedded_inputs, text_lengths) - if self.gst: + if self.gst and self.use_gst: # B x gst_dim encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings) if self.num_speakers > 1: @@ -246,7 +249,7 @@ class Tacotron2(TacotronAbstract): embedded_inputs = self.embedding(text).transpose(1, 2) encoder_outputs = self.encoder.inference(embedded_inputs) - if self.gst: + if self.gst and self.use_gst: # B x gst_dim encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings) if self.num_speakers > 1: diff --git a/TTS/tts/models/tacotron_abstract.py b/TTS/tts/models/tacotron_abstract.py index 42411656..e684ce7c 100644 --- a/TTS/tts/models/tacotron_abstract.py +++ b/TTS/tts/models/tacotron_abstract.py @@ -33,6 +33,7 @@ class TacotronAbstract(ABC, nn.Module): encoder_in_features=512, decoder_in_features=512, speaker_embedding_dim=None, + use_gst=False, gst=None, ): """Abstract Tacotron class""" @@ -41,6 +42,7 @@ class TacotronAbstract(ABC, nn.Module): self.r = r self.decoder_output_dim = decoder_output_dim self.postnet_output_dim = postnet_output_dim + self.use_gst = use_gst self.gst = gst self.num_speakers = num_speakers self.bidirectional_decoder = bidirectional_decoder @@ -77,7 +79,7 @@ class TacotronAbstract(ABC, nn.Module): self.embeddings_per_sample = True # global style token - if self.gst: + if self.gst and use_gst: self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim self.gst_layer = None @@ -186,18 +188,18 @@ class TacotronAbstract(ABC, nn.Module): """Compute global style token""" device = inputs.device if isinstance(style_input, dict): - query = torch.zeros(1, 1, self.gst_embedding_dim // 2).to(device) + query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).to(device) if speaker_embedding is not None: query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1) _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens) - gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device) + gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).to(device) for k_token, v_amplifier in style_input.items(): key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1) gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) gst_outputs = gst_outputs + gst_outputs_att * v_amplifier elif style_input is None: - gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device) + gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).to(device) else: gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable inputs = self._concat_speaker_embedding(inputs, gst_outputs) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index b81a75ff..b0e53f33 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -22,6 +22,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): r=c.r, postnet_output_dim=int(c.audio["fft_size"] / 2 + 1), decoder_output_dim=c.audio["num_mels"], + use_gst=c.use_gst, gst=c.gst, memory_size=c.memory_size, attn_type=c.attention_type, @@ -48,6 +49,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): r=c.r, postnet_output_dim=c.audio["num_mels"], decoder_output_dim=c.audio["num_mels"], + use_gst=c.use_gst, gst=c.gst, attn_type=c.attention_type, attn_win=c.windowing,