From 1d782487f5e6870170bea77758a7dd16d174379b Mon Sep 17 00:00:00 2001 From: Edresson Date: Tue, 4 Aug 2020 14:43:31 -0300 Subject: [PATCH] use tacotron abstract for multispeaker common definitions --- mozilla_voice_tts/bin/train_tts.py | 1 + mozilla_voice_tts/tts/models/tacotron.py | 43 +++++++--------- mozilla_voice_tts/tts/models/tacotron2.py | 49 ++++++++----------- .../tts/models/tacotron_abstract.py | 15 ++++++ 4 files changed, 53 insertions(+), 55 deletions(-) diff --git a/mozilla_voice_tts/bin/train_tts.py b/mozilla_voice_tts/bin/train_tts.py index 1b9bc032..2b6cbfd0 100644 --- a/mozilla_voice_tts/bin/train_tts.py +++ b/mozilla_voice_tts/bin/train_tts.py @@ -536,6 +536,7 @@ def main(args): # pylint: disable=redefined-outer-name else: num_speakers = 0 speaker_embedding_dim = None + speaker_mapping = None model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim) diff --git a/mozilla_voice_tts/tts/models/tacotron.py b/mozilla_voice_tts/tts/models/tacotron.py index 3837e63c..ac88133b 100644 --- a/mozilla_voice_tts/tts/models/tacotron.py +++ b/mozilla_voice_tts/tts/models/tacotron.py @@ -27,6 +27,8 @@ class Tacotron(TacotronAbstract): bidirectional_decoder=False, double_decoder_consistency=False, ddc_r=None, + encoder_in_features=256, + decoder_in_features=256, speaker_embedding_dim=None, gst=False, gst_embedding_dim=256, @@ -40,39 +42,28 @@ class Tacotron(TacotronAbstract): forward_attn, trans_agent, forward_attn_mask, location_attn, attn_K, separate_stopnet, bidirectional_decoder, double_decoder_consistency, - ddc_r, gst, gst_embedding_dim, gst_num_heads, gst_style_tokens) + ddc_r, encoder_in_features, decoder_in_features, + speaker_embedding_dim, gst, gst_embedding_dim, + gst_num_heads, gst_style_tokens) - # init layer dims - decoder_in_features = 256 - encoder_in_features = 256 - - if speaker_embedding_dim is None: - # if speaker_embedding_dim is None we need use the nn.Embedding, with default speaker_embedding_dim - self.embeddings_per_sample = False - speaker_embedding_dim = 256 - else: - # if speaker_embedding_dim is not None we need use speaker embedding per sample - self.embeddings_per_sample = True + # speaker embedding layers + if self.num_speakers > 1: + if not self.embeddings_per_sample: + speaker_embedding_dim = 256 + self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim) + self.speaker_embedding.weight.data.normal_(0, 0.3) # speaker and gst embeddings is concat in decoder input - if num_speakers > 1: - decoder_in_features = decoder_in_features + speaker_embedding_dim # add speaker embedding dim - if self.gst: - decoder_in_features = decoder_in_features + gst_embedding_dim # add gst embedding dim + if self.num_speakers > 1: + self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim # embedding layer self.embedding = nn.Embedding(num_chars, 256, padding_idx=0) - - # speaker embedding layers - if num_speakers > 1: - if not self.embeddings_per_sample: - self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim) - self.speaker_embedding.weight.data.normal_(0, 0.3) + self.embedding.weight.data.normal_(0, 0.3) # base model layers - self.embedding.weight.data.normal_(0, 0.3) - self.encoder = Encoder(encoder_in_features) - self.decoder = Decoder(decoder_in_features, decoder_output_dim, r, + self.encoder = Encoder(self.encoder_in_features) + self.decoder = Decoder(self.decoder_in_features, decoder_output_dim, r, memory_size, attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, location_attn, @@ -93,7 +84,7 @@ class Tacotron(TacotronAbstract): # setup DDC if self.double_decoder_consistency: self.coarse_decoder = Decoder( - decoder_in_features, decoder_output_dim, ddc_r, memory_size, + self.decoder_in_features, decoder_output_dim, ddc_r, memory_size, attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, location_attn, attn_K, separate_stopnet) diff --git a/mozilla_voice_tts/tts/models/tacotron2.py b/mozilla_voice_tts/tts/models/tacotron2.py index 9aeeb3d2..9fa640b0 100644 --- a/mozilla_voice_tts/tts/models/tacotron2.py +++ b/mozilla_voice_tts/tts/models/tacotron2.py @@ -33,6 +33,8 @@ class Tacotron2(TacotronAbstract): bidirectional_decoder=False, double_decoder_consistency=False, ddc_r=None, + encoder_in_features=512, + decoder_in_features=512, speaker_embedding_dim=None, gst=False, gst_embedding_dim=512, @@ -45,38 +47,27 @@ class Tacotron2(TacotronAbstract): forward_attn, trans_agent, forward_attn_mask, location_attn, attn_K, separate_stopnet, bidirectional_decoder, double_decoder_consistency, - ddc_r, gst, gst_embedding_dim, gst_num_heads, gst_style_tokens) + ddc_r, encoder_in_features, decoder_in_features, + speaker_embedding_dim, gst, gst_embedding_dim, + gst_num_heads, gst_style_tokens) - # init layer dims - decoder_in_features = 512 - encoder_in_features = 512 - - if speaker_embedding_dim is None: - # if speaker_embedding_dim is None we need use the nn.Embedding, with default speaker_embedding_dim - self.embeddings_per_sample = False - speaker_embedding_dim = 512 - else: - # if speaker_embedding_dim is not None we need use speaker embedding per sample - self.embeddings_per_sample = True + # speaker embedding layer + if self.num_speakers > 1: + if not self.embeddings_per_sample: + speaker_embedding_dim = 512 + self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim) + self.speaker_embedding.weight.data.normal_(0, 0.3) # speaker and gst embeddings is concat in decoder input - if num_speakers > 1: - decoder_in_features = decoder_in_features + speaker_embedding_dim # add speaker embedding dim - if self.gst: - decoder_in_features = decoder_in_features + gst_embedding_dim # add gst embedding dim - + if self.num_speakers > 1: + self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim + # embedding layer self.embedding = nn.Embedding(num_chars, 512, padding_idx=0) - # speaker embedding layer - if num_speakers > 1: - if not self.embeddings_per_sample: - self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim) - self.speaker_embedding.weight.data.normal_(0, 0.3) - # base model layers - self.encoder = Encoder(encoder_in_features) - self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win, + self.encoder = Encoder(self.encoder_in_features) + self.decoder = Decoder(self.decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, location_attn, attn_K, separate_stopnet) @@ -85,16 +76,16 @@ class Tacotron2(TacotronAbstract): # global style token layers if self.gst: self.gst_layer = GST(num_mel=80, - num_heads=gst_num_heads, - num_style_tokens=gst_style_tokens, - embedding_dim=gst_embedding_dim) + num_heads=self.gst_num_heads, + num_style_tokens=self.gst_style_tokens, + embedding_dim=self.gst_embedding_dim) # backward pass decoder if self.bidirectional_decoder: self._init_backward_decoder() # setup DDC if self.double_decoder_consistency: self.coarse_decoder = Decoder( - decoder_in_features, self.decoder_output_dim, ddc_r, attn_type, + self.decoder_in_features, self.decoder_output_dim, ddc_r, attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, location_attn, attn_K, separate_stopnet) diff --git a/mozilla_voice_tts/tts/models/tacotron_abstract.py b/mozilla_voice_tts/tts/models/tacotron_abstract.py index 6f3d32ad..0077f3e4 100644 --- a/mozilla_voice_tts/tts/models/tacotron_abstract.py +++ b/mozilla_voice_tts/tts/models/tacotron_abstract.py @@ -28,6 +28,9 @@ class TacotronAbstract(ABC, nn.Module): bidirectional_decoder=False, double_decoder_consistency=False, ddc_r=None, + encoder_in_features=512, + decoder_in_features=512, + speaker_embedding_dim=None, gst=False, gst_embedding_dim=512, gst_num_heads=4, @@ -57,6 +60,9 @@ class TacotronAbstract(ABC, nn.Module): self.location_attn = location_attn self.attn_K = attn_K self.separate_stopnet = separate_stopnet + self.encoder_in_features = encoder_in_features + self.decoder_in_features = decoder_in_features + self.speaker_embedding_dim = speaker_embedding_dim # layers self.embedding = None @@ -64,8 +70,17 @@ class TacotronAbstract(ABC, nn.Module): self.decoder = None self.postnet = None + # multispeaker + if self.speaker_embedding_dim is None: + # if speaker_embedding_dim is None we need use the nn.Embedding, with default speaker_embedding_dim + self.embeddings_per_sample = False + else: + # if speaker_embedding_dim is not None we need use speaker embedding per sample + self.embeddings_per_sample = True + # global style token if self.gst: + self.decoder_in_features += gst_embedding_dim # add gst embedding dim self.gst_layer = None # model states