From b71f31eae4f74b6a426861f8085d73f2eb1c063a Mon Sep 17 00:00:00 2001
From: SanjaESC
Date: Fri, 10 Jul 2020 12:14:55 +0200
Subject: [PATCH 1/9] Added support for Tacotron2 GST + ability to condition style input with wav or tokens

---
 config.json                 |  13 +++--
 models/tacotron.py          |   8 +--
 models/tacotron2.py         | 102 +++++++++++++++++++++---------------
 models/tacotron_abstract.py |  36 ++++++++++---
 synthesize.py               |   2 +-
 utils/generic_utils.py      |  16 +++++-
 utils/synthesis.py          |  25 ++++++---
 utils/text/cleaners.py      |   7 +++
 8 files changed, 144 insertions(+), 65 deletions(-)

diff --git a/config.json b/config.json
index 23868a33..9c4b2271 100644
--- a/config.json
+++ b/config.json
@@ -131,8 +131,16 @@
 
     // MULTI-SPEAKER and GST
     "use_speaker_embedding": false,     // use speaker embedding to enable multi-speaker learning.
-    "style_wav_for_test": null,         // path to style wav file to be used in TacotronGST inference.
-    "use_gst": false,                   // TACOTRON ONLY: use global style tokens
+    "use_gst": true,                    // use global style tokens
+    "gst": {                            // gst parameter if gst is enabled
+        "gst_style_input": null,        // Condition the style input either on a
+                                        // -> wave file [path to wave] or
+                                        // -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15}
+                                        // with the dictionary being len(dict) == len(gst_style_tokens).
+        "gst_embedding_dim": 512,
+        "gst_num_heads": 4,
+        "gst_style_tokens": 10
+    },
 
     // DATASETS
     "datasets":   // List of datasets. They all merged and they get different speaker_ids.
@@ -144,6 +152,5 @@
             "meta_file_val": null
         }
     ]
-
 }
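
Reviewer note (illustrative sketch, not part of the patch): the new `gst_style_input` key accepts two mutually exclusive forms, a wav path or a token-weight dictionary. A minimal Python illustration, with made-up path and token weights:

    # 1) Reference-audio conditioning: a wav whose mel spectrogram is encoded
    #    by the GST layer at inference time.
    gst_style_input = "/path/to/reference_style.wav"   # hypothetical path

    # 2) Direct token weighting: keys are style-token indices (as strings) in
    #    the range [0, gst_style_tokens); values are signed amplifiers.
    gst_style_input = {"0": 0.15, "1": 0.15, "5": -0.15}
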
diff --git a/models/tacotron.py b/models/tacotron.py
index ba42610c..e2733661 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -29,6 +29,9 @@ class Tacotron(TacotronAbstract):
                  double_decoder_consistency=False,
                  ddc_r=None,
                  gst=False,
+                 gst_embedding_dim=256,
+                 gst_num_heads=4,
+                 gst_style_tokens=10,
                  memory_size=5):
         super(Tacotron, self).__init__(num_chars, num_speakers, r, postnet_output_dim,
@@ -64,10 +67,9 @@ class Tacotron(TacotronAbstract):
         self.speaker_embeddings_projected = None
         # global style token layers
         if self.gst:
-            gst_embedding_dim = 256
             self.gst_layer = GST(num_mel=80,
-                                 num_heads=4,
-                                 num_style_tokens=10,
+                                 num_heads=gst_num_heads,
+                                 num_style_tokens=gst_style_tokens,
                                  embedding_dim=gst_embedding_dim)
         # backward pass decoder
         if self.bidirectional_decoder:
diff --git a/models/tacotron2.py b/models/tacotron2.py
index 4a22b7fa..23a40d4f 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -28,7 +28,10 @@ class Tacotron2(TacotronAbstract):
                  bidirectional_decoder=False,
                  double_decoder_consistency=False,
                  ddc_r=None,
-                 gst=False):
+                 gst=False,
+                 gst_embedding_dim=512,
+                 gst_num_heads=4,
+                 gst_style_tokens=10):
         super(Tacotron2, self).__init__(num_chars, num_speakers, r, postnet_output_dim,
                                         decoder_output_dim, attn_type, attn_win,
@@ -37,13 +40,17 @@ class Tacotron2(TacotronAbstract):
                                         location_attn, attn_K, separate_stopnet,
                                         bidirectional_decoder, double_decoder_consistency,
                                         ddc_r, gst)
-        decoder_in_features = 512 if num_speakers > 1 else 512
+
+        # init layer dims
+        speaker_embedding_dim = 512 if num_speakers > 1 else 0
+        gst_embedding_dim = gst_embedding_dim if self.gst else 0
+        decoder_in_features = 512+speaker_embedding_dim+gst_embedding_dim
         encoder_in_features = 512 if num_speakers > 1 else 512
         proj_speaker_dim = 80 if num_speakers > 1 else 0
         # base layers
         self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
         if num_speakers > 1:
-            self.speaker_embedding = nn.Embedding(num_speakers, 512)
+            self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(encoder_in_features)
         self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win,
@@ -53,10 +60,9 @@ class Tacotron2(TacotronAbstract):
         self.postnet = Postnet(self.postnet_output_dim)
         # global style token layers
         if self.gst:
-            gst_embedding_dim = encoder_in_features
             self.gst_layer = GST(num_mel=80,
-                                 num_heads=4,
-                                 num_style_tokens=10,
+                                 num_heads=gst_num_heads,
+                                 num_style_tokens=gst_style_tokens,
                                  embedding_dim=gst_embedding_dim)
         # backward pass decoder
         if self.bidirectional_decoder:
@@ -76,7 +82,6 @@ class Tacotron2(TacotronAbstract):
         return mel_outputs, mel_outputs_postnet, alignments
 
     def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None):
-        self._init_states()
         # compute mask for padding
         # B x T_in_max (boolean)
         input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
@@ -84,20 +89,24 @@ class Tacotron2(TacotronAbstract):
         embedded_inputs = self.embedding(text).transpose(1, 2)
         # B x T_in_max x D_en
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
-        # adding speaker embeddding to encoder output
-        # TODO: multi-speaker
-        # B x speaker_embed_dim
-        if speaker_ids is not None:
-            self.compute_speaker_embedding(speaker_ids)
+
         if self.num_speakers > 1:
-            # B x T_in x embed_dim + speaker_embed_dim
-            encoder_outputs = self._add_speaker_embedding(encoder_outputs,
-                                                          self.speaker_embeddings)
+            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
+            embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
+            else:
+                encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
+        else:
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
+
         encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
-        # global style token
-        if self.gst:
-            # B x gst_dim
-            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
+
         # B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
         decoder_outputs, alignments, stop_tokens = self.decoder(
             encoder_outputs, mel_specs, input_mask)
@@ -122,14 +131,25 @@ class Tacotron2(TacotronAbstract):
         return decoder_outputs, postnet_outputs, alignments, stop_tokens
 
     @torch.no_grad()
-    def inference(self, text, speaker_ids=None):
+    def inference(self, text, speaker_ids=None, style_mel=None):
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder.inference(embedded_inputs)
-        if speaker_ids is not None:
-            self.compute_speaker_embedding(speaker_ids)
+
         if self.num_speakers > 1:
-            encoder_outputs = self._add_speaker_embedding(encoder_outputs,
-                                                          self.speaker_embeddings)
+            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
+            embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
+            else:
+                encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
+        else:
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
+
         decoder_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs)
         postnet_outputs = self.postnet(decoder_outputs)
@@ -138,14 +158,28 @@ class Tacotron2(TacotronAbstract):
             decoder_outputs, postnet_outputs, alignments)
         return decoder_outputs, postnet_outputs, alignments, stop_tokens
 
-    def inference_truncated(self, text, speaker_ids=None):
+    def inference_truncated(self, text, speaker_ids=None, style_mel=None):
         """
         Preserve model states for continuous inference
         """
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
-        encoder_outputs = self._add_speaker_embedding(encoder_outputs,
-                                                      speaker_ids)
+
+        if self.num_speakers > 1:
+            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
+            embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
+            else:
+                encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
+        else:
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
+
         mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(
             encoder_outputs)
         mel_outputs_postnet = self.postnet(mel_outputs)
@@ -153,17 +187,3 @@ class Tacotron2(TacotronAbstract):
         mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
             mel_outputs, mel_outputs_postnet, alignments)
         return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
-
-
-    def _speaker_embedding_pass(self, encoder_outputs, speaker_ids):
-        # TODO: multi-speaker
-        # if hasattr(self, "speaker_embedding") and speaker_ids is None:
-        #     raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
-        # if hasattr(self, "speaker_embedding") and speaker_ids is not None:
-
-        #     speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
-        #                                                    encoder_outputs.size(1),
-        #                                                    -1)
-        #     encoder_outputs = encoder_outputs + speaker_embeddings
-        # return encoder_outputs
-        pass
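
Reviewer note (sanity-check sketch, not part of the patch): `decoder_in_features` is now the encoder width plus whichever optional embeddings get concatenated along the last axis in `forward()`. A quick check of the arithmetic with the defaults used in this patch:

    # Defaults from this patch: encoder output = 512, speaker embedding = 512
    # (multi-speaker only), GST embedding = 512 (GST only).
    def decoder_in_features(num_speakers, use_gst, gst_embedding_dim=512):
        speaker_embedding_dim = 512 if num_speakers > 1 else 0
        gst_dim = gst_embedding_dim if use_gst else 0
        return 512 + speaker_embedding_dim + gst_dim

    assert decoder_in_features(1, False) == 512    # plain single-speaker model
    assert decoder_in_features(1, True) == 1024    # single speaker + GST
    assert decoder_in_features(5, True) == 1536    # multi-speaker + GST
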
diff --git a/models/tacotron_abstract.py b/models/tacotron_abstract.py
index 75a1a5cd..b7d9faf2 100644
--- a/models/tacotron_abstract.py
+++ b/models/tacotron_abstract.py
@@ -28,7 +28,10 @@ class TacotronAbstract(ABC, nn.Module):
                  bidirectional_decoder=False,
                  double_decoder_consistency=False,
                  ddc_r=None,
-                 gst=False):
+                 gst=False,
+                 gst_embedding_dim=512,
+                 gst_num_heads=4,
+                 gst_style_tokens=10):
         """ Abstract Tacotron class """
         super().__init__()
         self.num_chars = num_chars
@@ -36,6 +39,9 @@ class TacotronAbstract(ABC, nn.Module):
         self.decoder_output_dim = decoder_output_dim
         self.postnet_output_dim = postnet_output_dim
         self.gst = gst
+        self.gst_embedding_dim = gst_embedding_dim
+        self.gst_num_heads = gst_num_heads
+        self.gst_style_tokens = gst_style_tokens
         self.num_speakers = num_speakers
         self.bidirectional_decoder = bidirectional_decoder
         self.double_decoder_consistency = double_decoder_consistency
@@ -158,12 +164,28 @@ class TacotronAbstract(ABC, nn.Module):
             self.speaker_embeddings_projected = self.speaker_project_mel(
                 self.speaker_embeddings).squeeze(1)
 
-    def compute_gst(self, inputs, mel_specs):
-        """ Compute global style token """
-        # pylint: disable=not-callable
-        gst_outputs = self.gst_layer(mel_specs)
-        inputs = self._add_speaker_embedding(inputs, gst_outputs)
-        return inputs
+    def compute_gst(self, inputs, style_input):
+        device = inputs.device
+        if isinstance(style_input, dict):
+            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
+            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
+            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+            for k_token, v_amplifier in style_input.items():
+                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
+                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
+                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
+        elif style_input is None:
+            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
+            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
+            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+            for k_token in range(self.gst_style_tokens):
+                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
+                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
+                gst_outputs = gst_outputs + gst_outputs_att * 0
+        else:
+            gst_outputs = self.gst_layer(style_input)
+        embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
+        return inputs, embedded_gst
 
     @staticmethod
     def _add_speaker_embedding(outputs, speaker_embeddings):
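
Reviewer note (stripped-down sketch, not part of the patch): the dictionary branch of `compute_gst()` builds the style embedding as a weighted sum, gst_outputs = sum over k of amplifier_k * attention(query=0, key=tanh(token_k)). The snippet below replaces the GST module's multi-head attention with a constant stand-in so the weighting itself is easy to verify:

    import torch

    # Stand-in for self.gst_layer.style_token_layer.attention: any callable
    # that maps (query, key) -> a (1, 1, embedding_dim) tensor would do here.
    def fake_attention(query, key, embedding_dim=512):
        return torch.ones(1, 1, embedding_dim)

    style_input = {"0": 0.15, "1": 0.15, "5": -0.15}
    tokens = torch.randn(10, 1, 256)       # 10 style tokens (shapes illustrative)
    query = torch.zeros(1, 1, 512 // 2)

    gst_outputs = torch.zeros(1, 1, 512)
    for k_token, v_amplifier in style_input.items():
        key = torch.tanh(tokens[int(k_token)]).unsqueeze(0)
        gst_outputs = gst_outputs + fake_attention(query, key) * v_amplifier

    # With the constant fake attention, every channel ends up at
    # 0.15 + 0.15 - 0.15 = 0.15, which makes the weighted-sum semantics visible.
    print(gst_outputs.mean())  # tensor(0.1500)
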
diff --git a/synthesize.py b/synthesize.py
index 18048c2f..bd720123 100644
--- a/synthesize.py
+++ b/synthesize.py
@@ -27,7 +27,7 @@ def tts(model,
     t_1 = time.time()
     use_vocoder_model = vocoder_model is not None
     waveform, alignment, _, postnet_output, stop_tokens, _ = synthesis(
-        model, text, C, use_cuda, ap, speaker_id, style_wav=False,
+        model, text, C, use_cuda, ap, speaker_id, style_wav=C.gst['gst_style_input'],
         truncated=False, enable_eos_bos_chars=C.enable_eos_bos_chars,
         use_griffin_lim=(not use_vocoder_model), do_trim_silence=True)
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index c806bdf3..8b4b1f12 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -149,6 +149,9 @@ def setup_model(num_chars, num_speakers, c):
                         postnet_output_dim=int(c.audio['fft_size'] / 2 + 1),
                         decoder_output_dim=c.audio['num_mels'],
                         gst=c.use_gst,
+                        gst_embedding_dim=c.gst['gst_embedding_dim'],
+                        gst_num_heads=c.gst['gst_num_heads'],
+                        gst_style_tokens=c.gst['gst_style_tokens'],
                         memory_size=c.memory_size,
                         attn_type=c.attention_type,
                         attn_win=c.windowing,
@@ -171,6 +174,9 @@ def setup_model(num_chars, num_speakers, c):
                         postnet_output_dim=c.audio['num_mels'],
                         decoder_output_dim=c.audio['num_mels'],
                         gst=c.use_gst,
+                        gst_embedding_dim=c.gst['gst_embedding_dim'],
+                        gst_num_heads=c.gst['gst_num_heads'],
+                        gst_style_tokens=c.gst['gst_style_tokens'],
                         attn_type=c.attention_type,
                         attn_win=c.windowing,
                         attn_norm=c.attention_norm,
@@ -348,10 +354,16 @@ def check_config(c):
     # paths
     _check_argument('output_path', c, restricted=True, val_type=str)
 
-    # multi-speaker gst
+    # multi-speaker
     _check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
-    _check_argument('style_wav_for_test', c, restricted=True, val_type=str)
+
+    # GST
     _check_argument('use_gst', c, restricted=True, val_type=bool)
+    _check_argument('gst_style_input', c, restricted=True, val_type=str)
+    _check_argument('gst', c, restricted=True, val_type=dict)
+    _check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=1)
+    _check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=1)
+    _check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1)
 
     # datasets - checking only the first entry
     _check_argument('datasets', c, restricted=True, val_type=list)
diff --git a/utils/synthesis.py b/utils/synthesis.py
index ce76b0ec..e36a56e6 100644
--- a/utils/synthesis.py
+++ b/utils/synthesis.py
@@ -37,9 +37,11 @@ def numpy_to_tf(np_array, dtype):
     return tensor
 
 
-def compute_style_mel(style_wav, ap):
-    style_mel = ap.melspectrogram(
-        ap.load_wav(style_wav)).expand_dims(0)
+def compute_style_mel(style_wav, ap, cuda=False):
+    style_mel = torch.FloatTensor(ap.melspectrogram(
+        ap.load_wav(style_wav))).unsqueeze(0)
+    if cuda:
+        return style_mel.cuda()
     return style_mel
 
 
@@ -129,10 +131,12 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
     return wav
 
 
-def id_to_torch(speaker_id):
+def id_to_torch(speaker_id, cuda=False):
     if speaker_id is not None:
         speaker_id = np.asarray(speaker_id)
         speaker_id = torch.from_numpy(speaker_id).unsqueeze(0)
+    if cuda:
+        return speaker_id.cuda()
     return speaker_id
 
 
@@ -185,14 +189,19 @@ def synthesis(model,
     """
     # GST processing
     style_mel = None
-    if CONFIG.model == "TacotronGST" and style_wav is not None:
-        style_mel = compute_style_mel(style_wav, ap)
+    if CONFIG.use_gst and style_wav is not None:
+        if isinstance(style_wav, dict):
+            style_mel = style_wav
+        else:
+            style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
     # preprocess the given text
     inputs = text_to_seqvec(text, CONFIG)
     # pass tensors to backend
    if backend == 'torch':
-        speaker_id = id_to_torch(speaker_id)
-        style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
+        if speaker_id is not None:
+            speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
+        if not isinstance(style_mel, dict):
+            style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
         inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
         inputs = inputs.unsqueeze(0)
     elif backend == 'tf':
diff --git a/utils/text/cleaners.py b/utils/text/cleaners.py
index f0a66f57..dd329f9c 100644
--- a/utils/text/cleaners.py
+++ b/utils/text/cleaners.py
@@ -91,6 +91,13 @@ def transliteration_cleaners(text):
     return text
 
 
+def basic_german_cleaners(text):
+    '''Pipeline for Turkish text'''
+    text = lowercase(text)
+    text = collapse_whitespace(text)
+    return text
+
+
 # TODO: elaborate it
 def basic_turkish_cleaners(text):
     '''Pipeline for Turkish text'''
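
Reviewer note (usage sketch, not part of the patch): after this change, `style_wav` can carry either form of `gst_style_input` end to end. A dict passes through `synthesis()` untouched and is consumed by `compute_gst()`; a path string is turned into a mel via `compute_style_mel()`. The sketch assumes `model`, `C`, `ap` and `use_cuda` are prepared as in `synthesize.py`; the wav path is hypothetical:

    from utils.synthesis import synthesis

    # 1) Reference-audio conditioning: path -> mel -> model.inference(style_mel=...)
    waveform, alignment, _, postnet_output, stop_tokens, _ = synthesis(
        model, "Hello world.", C, use_cuda, ap, speaker_id=None,
        style_wav="/path/to/reference_style.wav", truncated=False,
        enable_eos_bos_chars=C.enable_eos_bos_chars,
        use_griffin_lim=True, do_trim_silence=True)

    # 2) Token-dictionary conditioning: the dict is forwarded as-is.
    waveform, alignment, _, postnet_output, stop_tokens, _ = synthesis(
        model, "Hello world.", C, use_cuda, ap, speaker_id=None,
        style_wav={"0": 0.15, "5": -0.15}, truncated=False,
        enable_eos_bos_chars=C.enable_eos_bos_chars,
        use_griffin_lim=True, do_trim_silence=True)
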
From 998f33a104652a2253e570b9de4431d55fb94511 Mon Sep 17 00:00:00 2001
From: SanjaESC
Date: Fri, 10 Jul 2020 12:46:43 +0200
Subject: [PATCH 2/9] No need to query every token when none were passed

---
 config.json                 | 2 +-
 models/tacotron_abstract.py | 6 ------
 2 files changed, 1 insertion(+), 7 deletions(-)

diff --git a/config.json b/config.json
index 9c4b2271..bf2ad383 100644
--- a/config.json
+++ b/config.json
@@ -136,7 +136,7 @@
         "gst_style_input": null,        // Condition the style input either on a
                                         // -> wave file [path to wave] or
                                         // -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15}
-                                        // with the dictionary being len(dict) == len(gst_style_tokens).
+                                        // with the dictionary being len(dict) <= len(gst_style_tokens).
         "gst_embedding_dim": 512,
         "gst_num_heads": 4,
         "gst_style_tokens": 10
diff --git a/models/tacotron_abstract.py b/models/tacotron_abstract.py
index b7d9faf2..c868a18a 100644
--- a/models/tacotron_abstract.py
+++ b/models/tacotron_abstract.py
@@ -175,13 +175,7 @@ class TacotronAbstract(ABC, nn.Module):
                 gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
                 gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
         elif style_input is None:
-            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
-            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
             gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
-            for k_token in range(self.gst_style_tokens):
-                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
-                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
-                gst_outputs = gst_outputs + gst_outputs_att * 0
         else:
             gst_outputs = self.gst_layer(style_input)
         embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
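
Reviewer note (equivalence check, not part of the patch): the deleted branch ran the style attention for every token and then multiplied each result by zero, so it could only ever produce the zero tensor that the surviving line allocates directly:

    import torch

    # The removed loop accumulated `attention_output * 0` over all tokens, so
    # the result was always exactly the zero tensor.
    gst_embedding_dim = 512
    acc = torch.zeros(1, 1, gst_embedding_dim)
    for _ in range(10):                       # one fake "token query" per iteration
        attention_output = torch.randn(1, 1, gst_embedding_dim)
        acc = acc + attention_output * 0
    assert torch.equal(acc, torch.zeros(1, 1, gst_embedding_dim))
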
"frame_shift_ms": null, // stft window hop-lengh in ms. @@ -51,5 +51,15 @@ "output_path": "result", "min_seq_len": 0, "max_seq_len": 300, - "log_dir": "tests/outputs/" + "log_dir": "tests/outputs/", + + "use_speaker_embedding": false, + "use_gst": false, + "gst": { + "gst_style_input": null, + "gst_embedding_dim": 512, + "gst_num_heads": 4, + "gst_style_tokens": 10 + }, + } From 564fc0aab456aef569b04f68e08d636c25dcfd34 Mon Sep 17 00:00:00 2001 From: SanjaESC Date: Sun, 12 Jul 2020 12:33:13 +0200 Subject: [PATCH 4/9] pylint --- models/tacotron_abstract.py | 3 +-- tests/test_config.json | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/models/tacotron_abstract.py b/models/tacotron_abstract.py index 3d7564b5..c8f71312 100644 --- a/models/tacotron_abstract.py +++ b/models/tacotron_abstract.py @@ -177,8 +177,7 @@ class TacotronAbstract(ABC, nn.Module): elif style_input is None: gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device) else: - # pylint: disable=not-callable - gst_outputs = self.gst_layer(style_input) + gst_outputs = self.gst_layer(style_input) # pylint: disable=not-callable embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1) return inputs, embedded_gst diff --git a/tests/test_config.json b/tests/test_config.json index b34a53a8..450cb23a 100644 --- a/tests/test_config.json +++ b/tests/test_config.json @@ -60,6 +60,5 @@ "gst_embedding_dim": 512, "gst_num_heads": 4, "gst_style_tokens": 10 - }, - + } } From 6d3ddae64e16933c3f2a26db928676afed41d550 Mon Sep 17 00:00:00 2001 From: SanjaESC Date: Sun, 12 Jul 2020 14:07:44 +0200 Subject: [PATCH 5/9] tacotrongst test + test fixes --- models/tacotron2.py | 12 +++--- tests/outputs/dummy_model_config.json | 10 ++++- tests/test_config.json | 11 +---- tests/test_tacotron2_model.py | 62 +++++++++++++++++++++++++-- tests/test_tacotron_model.py | 8 ++-- 5 files changed, 79 insertions(+), 24 deletions(-) diff --git a/models/tacotron2.py b/models/tacotron2.py index 23a40d4f..75ae9bef 100644 --- a/models/tacotron2.py +++ b/models/tacotron2.py @@ -93,14 +93,14 @@ class Tacotron2(TacotronAbstract): if self.num_speakers > 1: embedded_speakers = self.speaker_embedding(speaker_ids)[:, None] embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1) - if hasattr(self, 'gst'): + if self.gst: # B x gst_dim encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs) encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1) else: encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1) else: - if hasattr(self, 'gst'): + if self.gst: # B x gst_dim encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs) encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1) @@ -138,14 +138,14 @@ class Tacotron2(TacotronAbstract): if self.num_speakers > 1: embedded_speakers = self.speaker_embedding(speaker_ids)[:, None] embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1) - if hasattr(self, 'gst'): + if self.gst: # B x gst_dim encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel) encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1) else: encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1) else: - if hasattr(self, 'gst'): + if self.gst: # B x gst_dim encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel) encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1) @@ -168,14 +168,14 @@ class 
From 564fc0aab456aef569b04f68e08d636c25dcfd34 Mon Sep 17 00:00:00 2001
From: SanjaESC
Date: Sun, 12 Jul 2020 12:33:13 +0200
Subject: [PATCH 4/9] pylint

---
 models/tacotron_abstract.py | 3 +--
 tests/test_config.json      | 3 +--
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/models/tacotron_abstract.py b/models/tacotron_abstract.py
index 3d7564b5..c8f71312 100644
--- a/models/tacotron_abstract.py
+++ b/models/tacotron_abstract.py
@@ -177,8 +177,7 @@ class TacotronAbstract(ABC, nn.Module):
         elif style_input is None:
             gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
         else:
-            # pylint: disable=not-callable
-            gst_outputs = self.gst_layer(style_input)
+            gst_outputs = self.gst_layer(style_input)  # pylint: disable=not-callable
         embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
         return inputs, embedded_gst
diff --git a/tests/test_config.json b/tests/test_config.json
index b34a53a8..450cb23a 100644
--- a/tests/test_config.json
+++ b/tests/test_config.json
@@ -60,6 +60,5 @@
         "gst_embedding_dim": 512,
         "gst_num_heads": 4,
         "gst_style_tokens": 10
-    },
-
+    }
 }

From 6d3ddae64e16933c3f2a26db928676afed41d550 Mon Sep 17 00:00:00 2001
From: SanjaESC
Date: Sun, 12 Jul 2020 14:07:44 +0200
Subject: [PATCH 5/9] TacotronGST test + test fixes

---
 models/tacotron2.py                   | 12 +++---
 tests/outputs/dummy_model_config.json | 10 ++++-
 tests/test_config.json                | 11 +----
 tests/test_tacotron2_model.py         | 62 +++++++++++++++++++++++++--
 tests/test_tacotron_model.py          |  8 ++--
 5 files changed, 79 insertions(+), 24 deletions(-)

diff --git a/models/tacotron2.py b/models/tacotron2.py
index 23a40d4f..75ae9bef 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -93,14 +93,14 @@ class Tacotron2(TacotronAbstract):
         if self.num_speakers > 1:
             embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
             embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
-            if hasattr(self, 'gst'):
+            if self.gst:
                 # B x gst_dim
                 encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
                 encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
             else:
                 encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
         else:
-            if hasattr(self, 'gst'):
+            if self.gst:
                 # B x gst_dim
                 encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
                 encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
@@ -138,14 +138,14 @@ class Tacotron2(TacotronAbstract):
         if self.num_speakers > 1:
             embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
             embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
-            if hasattr(self, 'gst'):
+            if self.gst:
                 # B x gst_dim
                 encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
                 encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
             else:
                 encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
         else:
-            if hasattr(self, 'gst'):
+            if self.gst:
                 # B x gst_dim
                 encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
                 encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
@@ -168,14 +168,14 @@ class Tacotron2(TacotronAbstract):
         if self.num_speakers > 1:
             embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
             embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
-            if hasattr(self, 'gst'):
+            if self.gst:
                 # B x gst_dim
                 encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
                 encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
             else:
                 encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
         else:
-            if hasattr(self, 'gst'):
+            if self.gst:
                 # B x gst_dim
                 encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
                 encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
diff --git a/tests/outputs/dummy_model_config.json b/tests/outputs/dummy_model_config.json
index 36fac3e5..c35c7495 100644
--- a/tests/outputs/dummy_model_config.json
+++ b/tests/outputs/dummy_model_config.json
@@ -83,6 +83,14 @@
     "use_phonemes": false,        // use phonemes instead of raw characters. It is suggested for better pronounciation.
     "phoneme_language": "en-us",  // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
     "text_cleaner": "phoneme_cleaners",
-    "use_speaker_embedding": false // whether to use additional embeddings for separate speakers
+    "use_speaker_embedding": false, // whether to use additional embeddings for separate speakers
+    "use_gst": false,
+    "gst": {
+        "gst_style_input": null,
+        "gst_embedding_dim": 256,
+        "gst_num_heads": 4,
+        "gst_style_tokens": 10
+    }
 }
+
diff --git a/tests/test_config.json b/tests/test_config.json
index 450cb23a..31c2cd87 100644
--- a/tests/test_config.json
+++ b/tests/test_config.json
@@ -51,14 +51,5 @@
     "output_path": "result",
     "min_seq_len": 0,
     "max_seq_len": 300,
-    "log_dir": "tests/outputs/",
-
-    "use_speaker_embedding": false,
-    "use_gst": false,
-    "gst": {
-        "gst_style_input": null,
-        "gst_embedding_dim": 512,
-        "gst_num_heads": 4,
-        "gst_style_tokens": 10
-    }
+    "log_dir": "tests/outputs/"
 }
diff --git a/tests/test_tacotron2_model.py b/tests/test_tacotron2_model.py
index ae9f20a2..5dfd7759 100644
--- a/tests/test_tacotron2_model.py
+++ b/tests/test_tacotron2_model.py
@@ -22,7 +22,7 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
 
 class TacotronTrainTest(unittest.TestCase):
     def test_train_step(self):
-        input = torch.randint(0, 24, (8, 128)).long().to(device)
+        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 128, (8, )).long().to(device)
         input_lengths = torch.sort(input_lengths, descending=True)[0]
         mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
@@ -35,7 +35,7 @@ class TacotronTrainTest(unittest.TestCase):
         for idx in mel_lengths:
             stop_targets[:, int(idx.item()):, 0] = 1.0
 
-        stop_targets = stop_targets.view(input.shape[0],
+        stop_targets = stop_targets.view(input_dummy.shape[0],
                                          stop_targets.size(1) // c.r, -1)
         stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
@@ -52,7 +52,63 @@ class TacotronTrainTest(unittest.TestCase):
         optimizer = optim.Adam(model.parameters(), lr=c.lr)
         for i in range(5):
             mel_out, mel_postnet_out, align, stop_tokens = model.forward(
-                input, input_lengths, mel_spec, mel_lengths, speaker_ids)
+                input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
             assert torch.sigmoid(stop_tokens).data.max() <= 1.0
             assert torch.sigmoid(stop_tokens).data.min() >= 0.0
             optimizer.zero_grad()
             loss = criterion(mel_out, mel_spec, mel_lengths)
             stop_loss = criterion_st(stop_tokens, stop_targets)
+            loss = loss + criterion(mel_postnet_out, mel_postnet_spec, mel_lengths) + stop_loss
+            loss.backward()
+            optimizer.step()
+        # check parameter changes
+        count = 0
+        for param, param_ref in zip(model.parameters(),
+                                    model_ref.parameters()):
+            # ignore pre-higway layer since it works conditional
+            # if count not in [145, 59]:
+            assert (param != param_ref).any(
+            ), "param {} with shape {} not updated!! \n{}\n{}".format(
+                count, param.shape, param, param_ref)
+            count += 1
+
+
+class TacotronGSTTrainTest(unittest.TestCase):
+    def test_train_step(self):
+        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
+        input_lengths = torch.randint(100, 128, (8, )).long().to(device)
+        input_lengths = torch.sort(input_lengths, descending=True)[0]
+        mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
+        mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
+        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
+        mel_lengths[0] = 30
+        stop_targets = torch.zeros(8, 30, 1).float().to(device)
+        speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
+
+        for idx in mel_lengths:
+            stop_targets[:, int(idx.item()):, 0] = 1.0
+
+        stop_targets = stop_targets.view(input_dummy.shape[0],
+                                         stop_targets.size(1) // c.r, -1)
+        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
+
+        criterion = MSELossMasked(seq_len_norm=False).to(device)
+        criterion_st = nn.BCEWithLogitsLoss().to(device)
+        model = Tacotron2(num_chars=24,
+                          gst=True,
+                          r=c.r,
+                          num_speakers=5).to(device)
+        model.train()
+        model_ref = copy.deepcopy(model)
+        count = 0
+        for param, param_ref in zip(model.parameters(),
+                                    model_ref.parameters()):
+            assert (param - param_ref).sum() == 0, param
+            count += 1
+        optimizer = optim.Adam(model.parameters(), lr=c.lr)
+        for i in range(5):
+            mel_out, mel_postnet_out, align, stop_tokens = model.forward(
+                input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
             assert torch.sigmoid(stop_tokens).data.max() <= 1.0
             assert torch.sigmoid(stop_tokens).data.min() >= 0.0
             optimizer.zero_grad()
             loss = criterion(mel_out, mel_spec, mel_lengths)
             stop_loss = criterion_st(stop_tokens, stop_targets)
             loss = loss + criterion(mel_postnet_out, mel_postnet_spec, mel_lengths) + stop_loss
             loss.backward()
             optimizer.step()
         # check parameter changes
         count = 0
         for param, param_ref in zip(model.parameters(),
                                     model_ref.parameters()):
             # ignore pre-higway layer since it works conditional
             # if count not in [145, 59]:
             assert (param != param_ref).any(
             ), "param {} with shape {} not updated!! \n{}\n{}".format(
                 count, param.shape, param, param_ref)
             count += 1
diff --git a/tests/test_tacotron_model.py b/tests/test_tacotron_model.py
index 2bbb3c8d..00cc38df 100644
--- a/tests/test_tacotron_model.py
+++ b/tests/test_tacotron_model.py
@@ -31,7 +31,7 @@ class TacotronTrainTest(unittest.TestCase):
         input_lengths = torch.randint(100, 129, (8, )).long().to(device)
         input_lengths[-1] = 128
         mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
-        linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device)
+        linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device)
         mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
         stop_targets = torch.zeros(8, 30, 1).float().to(device)
         speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
@@ -49,7 +49,7 @@ class TacotronTrainTest(unittest.TestCase):
         model = Tacotron(
             num_chars=32,
             num_speakers=5,
-            postnet_output_dim=c.audio['num_freq'],
+            postnet_output_dim=c.audio['fft_size'],
             decoder_output_dim=c.audio['num_mels'],
             r=c.r,
             memory_size=c.memory_size
@@ -93,7 +93,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
         input_lengths = torch.randint(100, 129, (8, )).long().to(device)
         input_lengths[-1] = 128
         mel_spec = torch.rand(8, 120, c.audio['num_mels']).to(device)
-        linear_spec = torch.rand(8, 120, c.audio['num_freq']).to(device)
+        linear_spec = torch.rand(8, 120, c.audio['fft_size']).to(device)
         mel_lengths = torch.randint(20, 120, (8, )).long().to(device)
         mel_lengths[-1] = 120
         stop_targets = torch.zeros(8, 120, 1).float().to(device)
@@ -113,7 +113,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
             num_chars=32,
             num_speakers=5,
             gst=True,
-            postnet_output_dim=c.audio['num_freq'],
+            postnet_output_dim=c.audio['fft_size'],
             decoder_output_dim=c.audio['num_mels'],
             r=c.r,
             memory_size=c.memory_size
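
Reviewer note (minimal repro, not part of the patch): the `hasattr` to truthiness change in models/tacotron2.py matters because `TacotronAbstract.__init__` always assigns `self.gst`, so the attribute exists even when GST is disabled:

    class Base:
        def __init__(self, gst=False):
            self.gst = gst       # attribute exists whether or not GST is enabled

    m = Base(gst=False)
    assert hasattr(m, 'gst')     # True even with GST disabled -> wrong gate
    assert not m.gst             # truthiness is the correct gate, as patched
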
From 69367bd2aeca0723b0f23196be5c4c052bdb8b82 Mon Sep 17 00:00:00 2001
From: SanjaESC
Date: Mon, 13 Jul 2020 08:50:39 +0200
Subject: [PATCH 6/9] override compute_gst in tacotron2 model

---
 models/tacotron2.py         | 25 ++++++++++++++++++++++++-
 models/tacotron_abstract.py | 22 ++++++----------------
 2 files changed, 30 insertions(+), 17 deletions(-)

diff --git a/models/tacotron2.py b/models/tacotron2.py
index 75ae9bef..03f45034 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -47,17 +47,22 @@ class Tacotron2(TacotronAbstract):
         decoder_in_features = 512+speaker_embedding_dim+gst_embedding_dim
         encoder_in_features = 512 if num_speakers > 1 else 512
         proj_speaker_dim = 80 if num_speakers > 1 else 0
-        # base layers
+
+        # embedding layer
         self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
+
+        # speaker embedding layer
         if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
+
         self.encoder = Encoder(encoder_in_features)
         self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
                                forward_attn, trans_agent, forward_attn_mask,
                                location_attn, attn_K, separate_stopnet, proj_speaker_dim)
         self.postnet = Postnet(self.postnet_output_dim)
+
         # global style token layers
         if self.gst:
             self.gst_layer = GST(num_mel=80,
@@ -81,6 +86,24 @@ class Tacotron2(TacotronAbstract):
         mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
         return mel_outputs, mel_outputs_postnet, alignments
 
+    def compute_gst(self, inputs, style_input):
+        """ Compute global style token """
+        device = inputs.device
+        if isinstance(style_input, dict):
+            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
+            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
+            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+            for k_token, v_amplifier in style_input.items():
+                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
+                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
+                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
+        elif style_input is None:
+            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+        else:
+            gst_outputs = self.gst_layer(style_input)  # pylint: disable=not-callable
+        embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
+        return inputs, embedded_gst
+
     def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None):
         # compute mask for padding
         # B x T_in_max (boolean)
diff --git a/models/tacotron_abstract.py b/models/tacotron_abstract.py
index c8f71312..7038a5e9 100644
--- a/models/tacotron_abstract.py
+++ b/models/tacotron_abstract.py
@@ -164,22 +164,12 @@ class TacotronAbstract(ABC, nn.Module):
             self.speaker_embeddings_projected = self.speaker_project_mel(
                 self.speaker_embeddings).squeeze(1)
 
-    def compute_gst(self, inputs, style_input):
-        device = inputs.device
-        if isinstance(style_input, dict):
-            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
-            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
-            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
-            for k_token, v_amplifier in style_input.items():
-                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
-                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
-                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
-        elif style_input is None:
-            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
-        else:
-            gst_outputs = self.gst_layer(style_input)  # pylint: disable=not-callable
-        embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
-        return inputs, embedded_gst
+    def compute_gst(self, inputs, mel_specs):
+        """ Compute global style token """
+        # pylint: disable=not-callable
+        gst_outputs = self.gst_layer(mel_specs)
+        inputs = self._add_speaker_embedding(inputs, gst_outputs)
+        return inputs
 
     @staticmethod
     def _add_speaker_embedding(outputs, speaker_embeddings):
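
Reviewer note (pattern sketch, not part of the patch): after this commit the base class goes back to the original single-return `compute_gst()` (used by Tacotron), while Tacotron2 overrides it with a version returning two values. The divergent return arity only works because each model calls its own implementation, as the toy classes below show:

    from abc import ABC

    class TacotronAbstractSketch(ABC):
        def compute_gst(self, inputs, mel_specs):
            # base-class contract: the style embedding is folded into the
            # inputs and a single value comes back
            return inputs

    class Tacotron2Sketch(TacotronAbstractSketch):
        def compute_gst(self, inputs, style_input):
            # overridden contract: the style embedding is returned separately
            # so forward()/inference() can torch.cat() it themselves
            embedded_gst = f"style({style_input})"  # stand-in for the real tensor
            return inputs, embedded_gst

    out = Tacotron2Sketch().compute_gst("enc", {"0": 0.15})
    assert isinstance(out, tuple) and len(out) == 2  # two values, unlike the base
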
From 18007e389d5d015dc1ce86fbebb9958cfca5883d Mon Sep 17 00:00:00 2001
From: SanjaESC
Date: Mon, 13 Jul 2020 08:51:37 +0200
Subject: [PATCH 7/9] small gst config change

---
 tests/test_tacotron2_model.py | 60 ++---------------------------------
 utils/generic_utils.py        |  2 +-
 2 files changed, 3 insertions(+), 59 deletions(-)

diff --git a/tests/test_tacotron2_model.py b/tests/test_tacotron2_model.py
index 5dfd7759..43003227 100644
--- a/tests/test_tacotron2_model.py
+++ b/tests/test_tacotron2_model.py
@@ -2,7 +2,6 @@ import os
 import copy
 import torch
 import unittest
-import numpy as np
 
 from torch import optim
 from torch import nn
@@ -21,7 +20,8 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
 
 
 class TacotronTrainTest(unittest.TestCase):
-    def test_train_step(self):
+    @staticmethod
+    def test_train_step():
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 128, (8, )).long().to(device)
         input_lengths = torch.sort(input_lengths, descending=True)[0]
@@ -71,59 +71,3 @@ class TacotronTrainTest(unittest.TestCase):
             ), "param {} with shape {} not updated!! \n{}\n{}".format(
                 count, param.shape, param, param_ref)
             count += 1
-
-
-class TacotronGSTTrainTest(unittest.TestCase):
-    def test_train_step(self):
-        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
-        input_lengths = torch.randint(100, 128, (8, )).long().to(device)
-        input_lengths = torch.sort(input_lengths, descending=True)[0]
-        mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
-        mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
-        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
-        mel_lengths[0] = 30
-        stop_targets = torch.zeros(8, 30, 1).float().to(device)
-        speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
-
-        for idx in mel_lengths:
-            stop_targets[:, int(idx.item()):, 0] = 1.0
-
-        stop_targets = stop_targets.view(input_dummy.shape[0],
-                                         stop_targets.size(1) // c.r, -1)
-        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
-
-        criterion = MSELossMasked(seq_len_norm=False).to(device)
-        criterion_st = nn.BCEWithLogitsLoss().to(device)
-        model = Tacotron2(num_chars=24,
-                          gst=True,
-                          r=c.r,
-                          num_speakers=5).to(device)
-        model.train()
-        model_ref = copy.deepcopy(model)
-        count = 0
-        for param, param_ref in zip(model.parameters(),
-                                    model_ref.parameters()):
-            assert (param - param_ref).sum() == 0, param
-            count += 1
-        optimizer = optim.Adam(model.parameters(), lr=c.lr)
-        for i in range(5):
-            mel_out, mel_postnet_out, align, stop_tokens = model.forward(
-                input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
-            assert torch.sigmoid(stop_tokens).data.max() <= 1.0
-            assert torch.sigmoid(stop_tokens).data.min() >= 0.0
-            optimizer.zero_grad()
-            loss = criterion(mel_out, mel_spec, mel_lengths)
-            stop_loss = criterion_st(stop_tokens, stop_targets)
-            loss = loss + criterion(mel_postnet_out, mel_postnet_spec, mel_lengths) + stop_loss
-            loss.backward()
-            optimizer.step()
-        # check parameter changes
-        count = 0
-        for param, param_ref in zip(model.parameters(),
-                                    model_ref.parameters()):
-            # ignore pre-higway layer since it works conditional
-            # if count not in [145, 59]:
-            assert (param != param_ref).any(
-            ), "param {} with shape {} not updated!! \n{}\n{}".format(
-                count, param.shape, param, param_ref)
-            count += 1
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 8b4b1f12..3bb99e08 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -359,8 +359,8 @@ def check_config(c):
 
     # GST
     _check_argument('use_gst', c, restricted=True, val_type=bool)
-    _check_argument('gst_style_input', c, restricted=True, val_type=str)
     _check_argument('gst', c, restricted=True, val_type=dict)
+    _check_argument('gst_style_input', c['gst'], restricted=True, val_type=str)
     _check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=1)
     _check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=1)
    _check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1)
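
Reviewer note (approximate sketch, not part of the patch): the relocated check now validates `gst_style_input` inside the nested `c['gst']` dict, matching where config.json defines it. A simplified stand-in for what the `_check_argument` calls above enforce (the real helper supports options such as `restricted` and alternatives not modeled here):

    # Simplified approximation of the GST portion of check_config().
    def check_gst_config(c):
        assert isinstance(c['use_gst'], bool)
        assert isinstance(c['gst'], dict)
        gst = c['gst']
        assert gst['gst_style_input'] is None or isinstance(gst['gst_style_input'], str)
        for key in ('gst_embedding_dim', 'gst_num_heads', 'gst_style_tokens'):
            assert isinstance(gst[key], int) and gst[key] >= 1

    check_gst_config({"use_gst": False,
                      "gst": {"gst_style_input": None, "gst_embedding_dim": 512,
                              "gst_num_heads": 4, "gst_style_tokens": 10}})
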
\n{}\n{}".format( count, param.shape, param, param_ref) count += 1 - - -class TacotronGSTTrainTest(unittest.TestCase): - def test_train_step(self): - input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 128, (8, )).long().to(device) - input_lengths = torch.sort(input_lengths, descending=True)[0] - mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - mel_lengths = torch.randint(20, 30, (8, )).long().to(device) - mel_lengths[0] = 30 - stop_targets = torch.zeros(8, 30, 1).float().to(device) - speaker_ids = torch.randint(0, 5, (8, )).long().to(device) - - for idx in mel_lengths: - stop_targets[:, int(idx.item()):, 0] = 1.0 - - stop_targets = stop_targets.view(input_dummy.shape[0], - stop_targets.size(1) // c.r, -1) - stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() - - criterion = MSELossMasked(seq_len_norm=False).to(device) - criterion_st = nn.BCEWithLogitsLoss().to(device) - model = Tacotron2(num_chars=24, - gst=True, - r=c.r, - num_speakers=5).to(device) - model.train() - model_ref = copy.deepcopy(model) - count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): - assert (param - param_ref).sum() == 0, param - count += 1 - optimizer = optim.Adam(model.parameters(), lr=c.lr) - for i in range(5): - mel_out, mel_postnet_out, align, stop_tokens = model.forward( - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids) - assert torch.sigmoid(stop_tokens).data.max() <= 1.0 - assert torch.sigmoid(stop_tokens).data.min() >= 0.0 - optimizer.zero_grad() - loss = criterion(mel_out, mel_spec, mel_lengths) - stop_loss = criterion_st(stop_tokens, stop_targets) - loss = loss + criterion(mel_postnet_out, mel_postnet_spec, mel_lengths) + stop_loss - loss.backward() - optimizer.step() - # check parameter changes - count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): - # ignore pre-higway layer since it works conditional - # if count not in [145, 59]: - assert (param != param_ref).any( - ), "param {} with shape {} not updated!! 
\n{}\n{}".format( - count, param.shape, param, param_ref) - count += 1 diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 8b4b1f12..3bb99e08 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -359,8 +359,8 @@ def check_config(c): # GST _check_argument('use_gst', c, restricted=True, val_type=bool) - _check_argument('gst_style_input', c, restricted=True, val_type=str) _check_argument('gst', c, restricted=True, val_type=dict) + _check_argument('gst_style_input', c['gst'], restricted=True, val_type=str) _check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=1) _check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=1) _check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1) From c865dd86bc1c99264f37a1d913e1e413a538ea87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 13 Jul 2020 10:33:55 +0200 Subject: [PATCH 8/9] update comment --- utils/text/cleaners.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/text/cleaners.py b/utils/text/cleaners.py index dd329f9c..6d1ace08 100644 --- a/utils/text/cleaners.py +++ b/utils/text/cleaners.py @@ -92,7 +92,7 @@ def transliteration_cleaners(text): def basic_german_cleaners(text): - '''Pipeline for Turkish text''' + '''Pipeline for German text''' text = lowercase(text) text = collapse_whitespace(text) return text From 3efea6e827133f5b85e3c7ec750661377d7def3b Mon Sep 17 00:00:00 2001 From: thllwg Date: Thu, 13 Aug 2020 09:07:57 +0200 Subject: [PATCH 9/9] Import fix --- mozilla_voice_tts/bin/train_encoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mozilla_voice_tts/bin/train_encoder.py b/mozilla_voice_tts/bin/train_encoder.py index f9bfea7f..6f41a431 100644 --- a/mozilla_voice_tts/bin/train_encoder.py +++ b/mozilla_voice_tts/bin/train_encoder.py @@ -16,10 +16,10 @@ from mozilla_voice_tts.speaker_encoder.losses import GE2ELoss, AngleProtoLoss from mozilla_voice_tts.speaker_encoder.model import SpeakerEncoder from mozilla_voice_tts.speaker_encoder.visual import plot_embeddings from mozilla_voice_tts.tts.datasets.preprocess import load_meta_data -from mozilla_voice_tts.tts.utils.generic_utils import ( +from mozilla_voice_tts.utils.generic_utils import ( create_experiment_folder, get_git_branch, remove_experiment_folder, set_init_dict) -from mozilla_voice_tts.tts.utils.io import copy_config_file, load_config +from mozilla_voice_tts.utils.io import copy_config_file, load_config from mozilla_voice_tts.utils.audio import AudioProcessor from mozilla_voice_tts.utils.generic_utils import count_parameters from mozilla_voice_tts.utils.radam import RAdam