From 05d9543ed8732a7aa9e1bfb21aaa39ad118108e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 3 May 2021 16:44:34 +0200 Subject: [PATCH] init GST module using gst config in Tacotron models --- TTS/tts/models/tacotron.py | 36 +++++++------------------ TTS/tts/models/tacotron2.py | 41 ++++++++--------------------- TTS/tts/models/tacotron_abstract.py | 12 ++------- TTS/tts/utils/generic_utils.py | 16 +++-------- 4 files changed, 26 insertions(+), 79 deletions(-) diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index 85d90116..1ffe9786 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -41,11 +41,7 @@ class Tacotron(TacotronAbstract): encoder_in_features (int, optional): input channels for the encoder. Defaults to 512. decoder_in_features (int, optional): input channels for the decoder. Defaults to 512. speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None. - gst (bool, optional): enable/disable global style token learning. Defaults to False. - gst_embedding_dim (int, optional): size of channels for GST vectors. Defaults to 512. - gst_num_heads (int, optional): number of attention heads for GST. Defaults to 4. - gst_num_style_tokens (int, optional): number of GST tokens. Defaults to 10. - gst_use_speaker_embedding (bool, optional): enable/disable inputing speaker embedding to GST. Defaults to False. + gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None. memory_size (int, optional): size of the history queue fed to the prenet. Model feeds the last ```memory_size``` output frames to the prenet. """ @@ -75,12 +71,8 @@ class Tacotron(TacotronAbstract): encoder_in_features=256, decoder_in_features=256, speaker_embedding_dim=None, - gst=False, - gst_embedding_dim=256, - gst_num_heads=4, - gst_style_tokens=10, + gst=None, memory_size=5, - gst_use_speaker_embedding=False, ): super().__init__( num_chars, @@ -107,10 +99,6 @@ class Tacotron(TacotronAbstract): decoder_in_features, speaker_embedding_dim, gst, - gst_embedding_dim, - gst_num_heads, - gst_style_tokens, - gst_use_speaker_embedding, ) # speaker embedding layers @@ -156,13 +144,11 @@ class Tacotron(TacotronAbstract): # global style token layers if self.gst: self.gst_layer = GST( - num_mel=80, - num_heads=gst_num_heads, - num_style_tokens=gst_style_tokens, - gst_embedding_dim=self.gst_embedding_dim, - speaker_embedding_dim=speaker_embedding_dim - if self.embeddings_per_sample and self.gst_use_speaker_embedding - else None, + num_mel=decoder_output_dim, + speaker_embedding_dim=speaker_embedding_dim, + num_heads=gst.gst_num_heads, + num_style_tokens=gst.gst_num_style_tokens, + gst_embedding_dim=gst.gst_embedding_dim, ) # backward pass decoder if self.bidirectional_decoder: @@ -207,9 +193,7 @@ class Tacotron(TacotronAbstract): # global style token if self.gst: # B x gst_dim - encoder_outputs = self.compute_gst( - encoder_outputs, mel_specs, speaker_embeddings if self.gst_use_speaker_embedding else None - ) + encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings) # speaker embedding if self.num_speakers > 1: if not self.embeddings_per_sample: @@ -265,9 +249,7 @@ class Tacotron(TacotronAbstract): encoder_outputs = self.encoder(inputs) if self.gst: # B x gst_dim - encoder_outputs = self.compute_gst( - encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None - ) + encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings) if self.num_speakers > 1: if not self.embeddings_per_sample: # B x 1 x speaker_embed_dim diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index 44c81735..1945a6f7 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -41,11 +41,7 @@ class Tacotron2(TacotronAbstract): encoder_in_features (int, optional): input channels for the encoder. Defaults to 512. decoder_in_features (int, optional): input channels for the decoder. Defaults to 512. speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None. - gst (bool, optional): enable/disable global style token learning. Defaults to False. - gst_embedding_dim (int, optional): size of channels for GST vectors. Defaults to 512. - gst_num_heads (int, optional): number of attention heads for GST. Defaults to 4. - gst_num_style_tokens (int, optional): number of GST tokens. Defaults to 10. - gst_use_speaker_embedding (bool, optional): enable/disable inputing speaker embedding to GST. Defaults to False. + gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None. """ def __init__( @@ -73,11 +69,7 @@ class Tacotron2(TacotronAbstract): encoder_in_features=512, decoder_in_features=512, speaker_embedding_dim=None, - gst=False, - gst_embedding_dim=512, - gst_num_heads=4, - gst_style_tokens=10, - gst_use_speaker_embedding=False, + gst=None, ): super().__init__( num_chars, @@ -104,10 +96,6 @@ class Tacotron2(TacotronAbstract): decoder_in_features, speaker_embedding_dim, gst, - gst_embedding_dim, - gst_num_heads, - gst_style_tokens, - gst_use_speaker_embedding, ) # speaker embedding layer @@ -150,14 +138,13 @@ class Tacotron2(TacotronAbstract): # global style token layers if self.gst: self.gst_layer = GST( - num_mel=80, - num_heads=self.gst_num_heads, - num_style_tokens=self.gst_style_tokens, - gst_embedding_dim=self.gst_embedding_dim, - speaker_embedding_dim=speaker_embedding_dim - if self.embeddings_per_sample and self.gst_use_speaker_embedding - else None, + num_mel=decoder_output_dim, + speaker_embedding_dim=speaker_embedding_dim, + num_heads=gst.gst_num_heads, + num_style_tokens=gst.gst_num_style_tokens, + gst_embedding_dim=gst.gst_embedding_dim, ) + # backward pass decoder if self.bidirectional_decoder: self._init_backward_decoder() @@ -205,9 +192,7 @@ class Tacotron2(TacotronAbstract): encoder_outputs = self.encoder(embedded_inputs, text_lengths) if self.gst: # B x gst_dim - encoder_outputs = self.compute_gst( - encoder_outputs, mel_specs, speaker_embeddings if self.gst_use_speaker_embedding else None - ) + encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings) if self.num_speakers > 1: if not self.embeddings_per_sample: # B x 1 x speaker_embed_dim @@ -263,9 +248,7 @@ class Tacotron2(TacotronAbstract): if self.gst: # B x gst_dim - encoder_outputs = self.compute_gst( - encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None - ) + encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings) if self.num_speakers > 1: if not self.embeddings_per_sample: speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] @@ -286,9 +269,7 @@ class Tacotron2(TacotronAbstract): if self.gst: # B x gst_dim - encoder_outputs = self.compute_gst( - encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None - ) + encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings) if self.num_speakers > 1: if not self.embeddings_per_sample: diff --git a/TTS/tts/models/tacotron_abstract.py b/TTS/tts/models/tacotron_abstract.py index c6bdb19e..42411656 100644 --- a/TTS/tts/models/tacotron_abstract.py +++ b/TTS/tts/models/tacotron_abstract.py @@ -33,11 +33,7 @@ class TacotronAbstract(ABC, nn.Module): encoder_in_features=512, decoder_in_features=512, speaker_embedding_dim=None, - gst=False, - gst_embedding_dim=512, - gst_num_heads=4, - gst_style_tokens=10, - gst_use_speaker_embedding=False, + gst=None, ): """Abstract Tacotron class""" super().__init__() @@ -46,10 +42,6 @@ class TacotronAbstract(ABC, nn.Module): self.decoder_output_dim = decoder_output_dim self.postnet_output_dim = postnet_output_dim self.gst = gst - self.gst_embedding_dim = gst_embedding_dim - self.gst_num_heads = gst_num_heads - self.gst_num_style_tokens = gst_num_style_tokens - self.gst_use_speaker_embedding = gst_use_speaker_embedding self.num_speakers = num_speakers self.bidirectional_decoder = bidirectional_decoder self.double_decoder_consistency = double_decoder_consistency @@ -86,7 +78,7 @@ class TacotronAbstract(ABC, nn.Module): # global style token if self.gst: - self.decoder_in_features += gst_embedding_dim # add gst embedding dim + self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim self.gst_layer = None # model states diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 9f17da0b..e6934bc9 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -14,7 +14,7 @@ def sequence_mask(sequence_length, max_len=None): def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): print(" > Using model: {}".format(c.model)) - MyModel = find_module("TTS.tts.models", c.model.lower()) + MyModel = find_module("TTS.tts.models", c.model.lower()) if c.model.lower() in "tacotron": model = MyModel( num_chars=num_chars + getattr(c, "add_blank", False), @@ -23,17 +23,13 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): postnet_output_dim=int(c.audio["fft_size"] / 2 + 1), decoder_output_dim=c.audio["num_mels"], gst=c.use_gst, - gst_embedding_dim=c.gst["gst_embedding_dim"], - gst_num_heads=c.gst["gst_num_heads"], - gst_style_tokens=c.gst["gst_style_tokens"], - gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"], memory_size=c.memory_size, attn_type=c.attention_type, attn_win=c.windowing, attn_norm=c.attention_norm, prenet_type=c.prenet_type, prenet_dropout=c.prenet_dropout, - prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False, + prenet_dropout_at_inference=c.prenet_dropout_at_inference, forward_attn=c.use_forward_attn, trans_agent=c.transition_agent, forward_attn_mask=c.forward_attn_mask, @@ -52,17 +48,13 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): r=c.r, postnet_output_dim=c.audio["num_mels"], decoder_output_dim=c.audio["num_mels"], - gst=c.gst is not None, - gst_embedding_dim=None if c.gst is None else c.gst['gst_embedding_dim'], - gst_num_heads=None if c.gst is None else c.gst['gst_num_heads'], - gst_num_style_tokens=None if c.gst is None else c.gst['gst_num_style_tokens'], - gst_use_speaker_embedding=None if c.gst is None else c.gst['gst_use_speaker_embedding'], + gst=c.gst, attn_type=c.attention_type, attn_win=c.windowing, attn_norm=c.attention_norm, prenet_type=c.prenet_type, prenet_dropout=c.prenet_dropout, - prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False, + prenet_dropout_at_inference=c.prenet_dropout_at_inference, forward_attn=c.use_forward_attn, trans_agent=c.transition_agent, forward_attn_mask=c.forward_attn_mask,