From d9540a5857d79dcfd260c776988cd03ad6d02b2a Mon Sep 17 00:00:00 2001 From: Edresson Date: Sun, 25 Oct 2020 15:08:28 -0300 Subject: [PATCH] add blank token in sequence to improve GlowTTS results --- TTS/bin/train_glow_tts.py | 1 + TTS/bin/train_tts.py | 1 + TTS/tts/configs/glow_tts_gated_conv.json | 2 ++ TTS/tts/configs/glow_tts_tdsep.json | 2 ++ TTS/tts/datasets/TTSDataset.py | 6 ++++-- TTS/tts/utils/synthesis.py | 7 +++++-- TTS/tts/utils/text/__init__.py | 14 +++++++++++--- 7 files changed, 26 insertions(+), 7 deletions(-) diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py index 7ffca36e..f4d04abb 100644 --- a/TTS/bin/train_glow_tts.py +++ b/TTS/bin/train_glow_tts.py @@ -47,6 +47,7 @@ def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None): meta_data=meta_data_eval if is_val else meta_data_train, ap=ap, tp=c.characters if 'characters' in c.keys() else None, + add_blank=c['add_blank'] if 'add_blank' in c.keys() else False, batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, min_seq_len=c.min_seq_len, diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 4c615b99..e4f8bf7a 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -51,6 +51,7 @@ def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None): meta_data=meta_data_eval if is_val else meta_data_train, ap=ap, tp=c.characters if 'characters' in c.keys() else None, + add_blank=c['add_blank'] if 'add_blank' in c.keys() else False, batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, min_seq_len=c.min_seq_len, diff --git a/TTS/tts/configs/glow_tts_gated_conv.json b/TTS/tts/configs/glow_tts_gated_conv.json index 696bdaf7..5c30e0bc 100644 --- a/TTS/tts/configs/glow_tts_gated_conv.json +++ b/TTS/tts/configs/glow_tts_gated_conv.json @@ -51,6 +51,8 @@ // "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" // }, + "add_blank": false, // if 
true add a new token after each token of the sentence. This increases the size of the input sequence, but has considerably improved the prosody of the GlowTTS model. + // DISTRIBUTED TRAINING "distributed":{ "backend": "nccl", diff --git a/TTS/tts/configs/glow_tts_tdsep.json b/TTS/tts/configs/glow_tts_tdsep.json index 67047523..25d41291 100644 --- a/TTS/tts/configs/glow_tts_tdsep.json +++ b/TTS/tts/configs/glow_tts_tdsep.json @@ -51,6 +51,8 @@ // "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" // }, + "add_blank": false, // if true add a new token after each token of the sentence. This increases the size of the input sequence, but has considerably improved the prosody of the GlowTTS model. + // DISTRIBUTED TRAINING "distributed":{ "backend": "nccl", diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index ab8f3f88..7b671397 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -17,6 +17,7 @@ class MyDataset(Dataset): ap, meta_data, tp=None, + add_blank=False, batch_group_size=0, min_seq_len=0, max_seq_len=float("inf"), @@ -55,6 +56,7 @@ class MyDataset(Dataset): self.max_seq_len = max_seq_len self.ap = ap self.tp = tp + self.add_blank = add_blank self.use_phonemes = use_phonemes self.phoneme_cache_path = phoneme_cache_path self.phoneme_language = phoneme_language @@ -88,7 +90,7 @@ class MyDataset(Dataset): phonemes = phoneme_to_sequence(text, [self.cleaners], language=self.phoneme_language, enable_eos_bos=False, - tp=self.tp) + tp=self.tp, add_blank=self.add_blank) phonemes = np.asarray(phonemes, dtype=np.int32) np.save(cache_path, phonemes) return phonemes @@ -127,7 +129,7 @@ class MyDataset(Dataset): text = self._load_or_generate_phoneme_sequence(wav_file, text) else: text = np.asarray(text_to_sequence(text, [self.cleaners], - tp=self.tp), + tp=self.tp, add_blank=self.add_blank), dtype=np.int32) assert text.size > 0, 
self.items[idx][1] diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 0dfea5cc..3d2dd13c 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -14,10 +14,13 @@ def text_to_seqvec(text, CONFIG): seq = np.asarray( phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, - tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), + tp=CONFIG.characters if 'characters' in CONFIG.keys() else None, + add_blank=CONFIG['add_blank'] if 'add_blank' in CONFIG.keys() else False), dtype=np.int32) else: - seq = np.asarray(text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), dtype=np.int32) + seq = np.asarray( + text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None, + add_blank=CONFIG['add_blank'] if 'add_blank' in CONFIG.keys() else False), dtype=np.int32) return seq diff --git a/TTS/tts/utils/text/__init__.py b/TTS/tts/utils/text/__init__.py index 33972f25..eab7a689 100644 --- a/TTS/tts/utils/text/__init__.py +++ b/TTS/tts/utils/text/__init__.py @@ -57,6 +57,10 @@ def text2phone(text, language): return ph +def intersperse(sequence, token): + result = [token] * (len(sequence) * 2 + 1) + result[1::2] = sequence + return result def pad_with_eos_bos(phoneme_sequence, tp=None): # pylint: disable=global-statement @@ -69,8 +73,7 @@ def pad_with_eos_bos(phoneme_sequence, tp=None): return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]] - -def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None): +def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False): # pylint: disable=global-statement global _phonemes_to_id if tp: @@ -88,6 +91,8 @@ def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp= # Append EOS char if enable_eos_bos: sequence = pad_with_eos_bos(sequence, tp=tp) + if 
add_blank: + sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes) return sequence @@ -107,7 +112,7 @@ def sequence_to_phoneme(sequence, tp=None): return result.replace('}{', ' ') -def text_to_sequence(text, cleaner_names, tp=None): +def text_to_sequence(text, cleaner_names, tp=None, add_blank=False): '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. The text can optionally have ARPAbet sequences enclosed in curly braces embedded @@ -137,6 +142,9 @@ def text_to_sequence(text, cleaner_names, tp=None): _clean_text(m.group(1), cleaner_names)) sequence += _arpabet_to_sequence(m.group(2)) text = m.group(3) + + if add_blank: + sequence = intersperse(sequence, len(_symbols)) # add a blank token (new), whose id number is len(_symbols) return sequence