From 165e5814af52801262ce972ed3d6d4a52d0e84e3 Mon Sep 17 00:00:00 2001 From: Katsuya Iida Date: Wed, 1 Sep 2021 16:33:15 +0900 Subject: [PATCH 01/52] Update Japanese phonemizer (#758) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update default ja vocoder * update * Japanese phonemizer test * Run make style Co-authored-by: Eren Gölge --- TTS/.models.json | 2 +- TTS/tts/utils/text/japanese/phonemizer.py | 89 +++++++++++++++++++- tests/text_tests/test_japanese_phonemizer.py | 6 +- 3 files changed, 93 insertions(+), 4 deletions(-) diff --git a/TTS/.models.json b/TTS/.models.json index 6ba67b01..d3c56b94 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -157,7 +157,7 @@ "kokoro": { "tacotron2-DDC": { "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.15/tts_models--jp--kokoro--tacotron2-DDC.zip", - "default_vocoder": "vocoder_models/universal/libri-tts/wavegrad", + "default_vocoder": "vocoder_models/ja/kokoro/hifigan_v1", "description": "Tacotron2 with Double Decoder Consistency trained with Kokoro Speech Dataset.", "author": "@kaiidams", "commit": "401fbd89" diff --git a/TTS/tts/utils/text/japanese/phonemizer.py b/TTS/tts/utils/text/japanese/phonemizer.py index a4629a30..969becfd 100644 --- a/TTS/tts/utils/text/japanese/phonemizer.py +++ b/TTS/tts/utils/text/japanese/phonemizer.py @@ -2,8 +2,10 @@ # compatible with Julius https://github.com/julius-speech/segmentation-kit import re +import unicodedata import MeCab +from num2words import num2words _CONVRULES = [ # Conversion of 2 letters @@ -373,8 +375,93 @@ def text2kata(text: str) -> str: return hira2kata("".join(res)) +_ALPHASYMBOL_YOMI = { + "#": "シャープ", + "%": "パーセント", + "&": "アンド", + "+": "プラス", + "-": "マイナス", + ":": "コロン", + ";": "セミコロン", + "<": "小なり", + "=": "イコール", + ">": "大なり", + "@": "アット", + "a": "エー", + "b": "ビー", + "c": "シー", + "d": "ディー", + "e": "イー", + "f": "エフ", + "g": "ジー", + "h": "エイチ", + "i": "アイ", + "j": "ジェー", + "k": "ケー", + "l": "エル", + "m": "エム", + "n": "エヌ", + "o": "オー", + "p": "ピー", + "q": "キュー", + "r": "アール", + "s": "エス", + "t": "ティー", + "u": "ユー", + "v": "ブイ", + "w": "ダブリュー", + "x": "エックス", + "y": "ワイ", + "z": "ゼット", + "α": "アルファ", + "β": "ベータ", + "γ": "ガンマ", + "δ": "デルタ", + "ε": "イプシロン", + "ζ": "ゼータ", + "η": "イータ", + "θ": "シータ", + "ι": "イオタ", + "κ": "カッパ", + "λ": "ラムダ", + "μ": "ミュー", + "ν": "ニュー", + "ξ": "クサイ", + "ο": "オミクロン", + "π": "パイ", + "ρ": "ロー", + "σ": "シグマ", + "τ": "タウ", + "υ": "ウプシロン", + "φ": "ファイ", + "χ": "カイ", + "ψ": "プサイ", + "ω": "オメガ", +} + + +_NUMBER_WITH_SEPARATOR_RX = re.compile("[0-9]{1,3}(,[0-9]{3})+") +_CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"} +_CURRENCY_RX = re.compile(r"([$¥£€])([0-9.]*[0-9])") +_NUMBER_RX = re.compile(r"[0-9]+(\.[0-9]+)?") + + +def japanese_convert_numbers_to_words(text: str) -> str: + res = _NUMBER_WITH_SEPARATOR_RX.sub(lambda m: m[0].replace(",", ""), text) + res = _CURRENCY_RX.sub(lambda m: m[2] + _CURRENCY_MAP.get(m[1], m[1]), res) + res = _NUMBER_RX.sub(lambda m: num2words(m[0], lang="ja"), res) + return res + + +def japanese_convert_alpha_symbols_to_words(text: str) -> str: + return "".join([_ALPHASYMBOL_YOMI.get(ch, ch) for ch in text.lower()]) + + def japanese_text_to_phonemes(text: str) -> str: """Convert Japanese text to phonemes.""" - res = text2kata(text) + res = unicodedata.normalize("NFKC", text) + res = japanese_convert_numbers_to_words(res) + res = japanese_convert_alpha_symbols_to_words(res) + res = text2kata(res) res = kata2phoneme(res) return res.replace(" ", "") diff --git a/tests/text_tests/test_japanese_phonemizer.py b/tests/text_tests/test_japanese_phonemizer.py index b3b1ece3..423b79b9 100644 --- a/tests/text_tests/test_japanese_phonemizer.py +++ b/tests/text_tests/test_japanese_phonemizer.py @@ -5,11 +5,13 @@ from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes _TEST_CASES = """ どちらに行きますか?/dochiraniikimasuka? 今日は温泉に、行きます。/kyo:waoNseNni,ikimasu. -「A」から「Z」までです。/AkaraZmadedesu. +「A」から「Z」までです。/e:karazeqtomadedesu. そうですね!/so:desune! クジラは哺乳類です。/kujirawahonyu:ruidesu. ヴィディオを見ます。/bidioomimasu. -ky o: w a o N s e N n i , i k i m a s u ./kyo:waoNseNni,ikimasu. +今日は8月22日です/kyo:wahachigatsuniju:ninichidesu +xyzとαβγ/eqkusuwaizeqtotoarufabe:tagaNma +値段は$12.34です/nedaNwaju:niteNsaNyoNdorudesu """ From fba257104d776b58ac0bfa86fc3d08727046d1fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 6 Jul 2021 09:48:00 +0200 Subject: [PATCH 02/52] Compute F0 using librosa --- TTS/tts/datasets/TTSDataset.py | 14 +++++++++++ TTS/utils/audio.py | 46 ++++++++++++++++++++++++++-------- tests/test_audio_processor.py | 7 ++++++ 3 files changed, 56 insertions(+), 11 deletions(-) diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index 5d38243e..7ad52797 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -22,6 +22,7 @@ class TTSDataset(Dataset): compute_linear_spec: bool, ap: AudioProcessor, meta_data: List[List], + compute_f0: bool = False, characters: Dict = None, custom_symbols: List = None, add_blank: bool = False, @@ -54,6 +55,8 @@ class TTSDataset(Dataset): meta_data (list): List of dataset instances. + compute_f0 (bool): compute f0 if True. Defaults to False. + characters (dict): `dict` of custom text characters used for converting texts to sequences. custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own @@ -103,6 +106,7 @@ class TTSDataset(Dataset): self.cleaners = text_cleaner self.compute_linear_spec = compute_linear_spec self.return_wav = return_wav + self.compute_f0 = compute_f0 self.min_seq_len = min_seq_len self.max_seq_len = max_seq_len self.ap = ap @@ -458,6 +462,16 @@ class TTSDataset(Dataset): wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w) wav_padded.transpose_(1, 2) + # compute f0 + # TODO: compare perf in collate_fn vs in load_data + pitch = None + if self.compute_f0: + pitch = [self.ap.compute_f0(w).astype("float32") for w in wav] + pitch = prepare_tensor(pitch, self.outputs_per_step) + pitch = pitch.transpose(0, 2, 1) + assert mel.shape[1] == pitch.shape[1] + pitch = torch.FloatTensor(pitch).contiguous() + # collate attention alignments if batch[0]["attn"] is not None: attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing] diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index 0a343fbf..e027b060 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -623,17 +623,41 @@ class AudioProcessor(object): return 0, pad return pad // 2, pad // 2 + pad % 2 - ### Compute F0 ### - # TODO: pw causes some dep issues - # def compute_f0(self, x): - # f0, t = pw.dio( - # x.astype(np.double), - # fs=self.sample_rate, - # f0_ceil=self.mel_fmax, - # frame_period=1000 * self.hop_length / self.sample_rate, - # ) - # f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate) - # return f0 + def compute_f0(self, x: np.ndarray) -> np.ndarray: + """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. + + Args: + x (np.ndarray): Waveform. + + Returns: + np.ndarray: Pitch. + + Examples: + >>> WAV_FILE = filename = librosa.util.example_audio_file() + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio import AudioProcessor + >>> conf = BaseAudioConfig(mel_fmax=8000) + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050] + >>> pitch = ap.compute_f0(wav) + """ + # f0, t = pw.dio( + # x.astype(np.double), + # fs=self.sample_rate, + # f0_ceil=self.mel_fmax, + # frame_period=1000 * self.hop_length / self.sample_rate, + # ) + # f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate) + # f0 = compute_yin(, self.sample_rate, self.hop_length, self.fft_size) + f0, _, _ = librosa.pyin( + x.astype(np.double), + fmin=65 if self.mel_fmin == 0 else self.mel_fmin, + fmax=self.mel_fmax, + frame_length=self.win_length, + sr=self.sample_rate, + fill_na=0.0, + ) + return f0 ### Audio Processing ### def find_endpoint(self, wav: np.ndarray, threshold_db=-40, min_silence_sec=0.8) -> int: diff --git a/tests/test_audio_processor.py b/tests/test_audio_processor.py index 22e965f0..d3414286 100644 --- a/tests/test_audio_processor.py +++ b/tests/test_audio_processor.py @@ -181,3 +181,10 @@ class TestAudio(unittest.TestCase): mel_norm = ap.melspectrogram(wav) mel_denorm = ap.denormalize(mel_norm) assert abs(mel_reference - mel_denorm).max() < 1e-4 + + def test_compute_f0(self): + ap = AudioProcessor(**conf) + wav = ap.load_wav(WAV_FILE) + pitch = ap.compute_f0(wav) + mel = ap.melspectrogram(wav) + assert pitch.shape[0] == mel.shape[1] From c8d999b0105281b08b4f6e14b7edea6565310d75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 8 Jul 2021 01:28:41 +0200 Subject: [PATCH 03/52] Add FastPitchLoss --- TTS/tts/layers/losses.py | 44 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 0ce4ada9..b4866df1 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -658,3 +658,47 @@ class VitsDiscriminatorLoss(nn.Module): loss = loss + return_dict["loss_disc"] return_dict["loss"] = loss return return_dict + + +class FastPitchLoss(nn.Module): + def __init__(self, c): + super().__init__() + self.spec_loss = MSELossMasked(False) + self.ssim = SSIMLoss() + self.dur_loss = MSELossMasked(False) + self.pitch_loss = MSELossMasked(False) + + self.spec_loss_alpha = c.spec_loss_alpha + self.ssim_loss_alpha = c.ssim_loss_alpha + self.dur_loss_alpha = c.dur_loss_alpha + self.pitch_loss_alpha = c.pitch_loss_alpha + + def forward( + self, + decoder_output, + decoder_target, + decoder_output_lens, + dur_output, + dur_target, + pitch_output, + pitch_target, + input_lens, + ): + + l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens) + ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) + huber_loss = self.huber(dur_output, dur_target, input_lens) + pitch_loss = self.pitch_loss(pitch_output, pitch_target, input_lens) + loss = ( + self.l1_alpha * l1_loss + + self.ssim_alpha * ssim_loss + + self.huber_alpha * huber_loss + + self.pitch_alpha * pitch_loss + ) + return { + "loss": loss, + "loss_l1": l1_loss, + "loss_ssim": ssim_loss, + "loss_dur": huber_loss, + "loss_pitch": pitch_loss, + } From 994f2be2c11241bf981489fb01cf58eca0014405 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 8 Jul 2021 01:29:00 +0200 Subject: [PATCH 04/52] Add comput_f0 field --- TTS/tts/models/base_tts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 922761cb..d39473c7 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -199,6 +199,7 @@ class BaseTTS(BaseModel): outputs_per_step=config.r if "r" in config else 1, text_cleaner=config.text_cleaner, compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, + comnpute_f0=config.get("compute_f0", False), meta_data=data_items, ap=ap, characters=config.characters, From 0f19f8c911bb367c936b949d3940cc0c23f97767 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 12 Jul 2021 12:28:10 +0200 Subject: [PATCH 05/52] Fix `compute_attention_masks.py` --- TTS/bin/compute_attention_masks.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index 3a5c067e..7de3989d 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -8,12 +8,12 @@ import torch from torch.utils.data import DataLoader from tqdm import tqdm +from TTS.config import load_config from TTS.tts.datasets.TTSDataset import TTSDataset from TTS.tts.models import setup_model -from TTS.tts.utils.io import load_checkpoint from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.utils.audio import AudioProcessor -from TTS.utils.io import load_config +from TTS.utils.io import load_checkpoint if __name__ == "__main__": # pylint: disable=bad-option-value @@ -27,7 +27,7 @@ Example run: CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json - --dataset_metafile /root/LJSpeech-1.1/metadata.csv + --dataset_metafile metadata.csv --data_path /root/LJSpeech-1.1/ --batch_size 32 --dataset ljspeech @@ -76,8 +76,7 @@ Example run: num_chars = len(phonemes) if C.use_phonemes else len(symbols) # TODO: handle multi-speaker model = setup_model(C) - model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda) - model.eval() + model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True) # data loader preprocessor = importlib.import_module("TTS.tts.datasets.formatters") @@ -127,9 +126,9 @@ Example run: mel_input = mel_input.cuda() mel_lengths = mel_lengths.cuda() - mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input) + model_outputs = model.forward(text_input, text_lengths, mel_input) - alignments = alignments.detach() + alignments = model_outputs["alignments"].detach() for idx, alignment in enumerate(alignments): item_idx = item_idxs[idx] # interpolate if r > 1 From 94e8e0d416ae16e5e77535f9fe13780d5a344d78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 12 Jul 2021 12:29:02 +0200 Subject: [PATCH 06/52] Fix configs --- TTS/config/shared_configs.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py index af054346..0de3795c 100644 --- a/TTS/config/shared_configs.py +++ b/TTS/config/shared_configs.py @@ -12,60 +12,89 @@ class BaseAudioConfig(Coqpit): Args: fft_size (int): Number of STFT frequency levels aka.size of the linear spectogram frame. Defaults to 1024. + win_length (int): Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match ```fft_size```. Defaults to 1024. + hop_length (int): Number of audio samples between adjacent STFT columns. Defaults to 1024. + frame_shift_ms (int): Set ```hop_length``` based on milliseconds and sampling rate. + frame_length_ms (int): Set ```win_length``` based on milliseconds and sampling rate. + stft_pad_mode (str): Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'. + sample_rate (int): Audio sampling rate. Defaults to 22050. + resample (bool): Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```. + preemphasis (float): Preemphasis coefficient. Defaults to 0.0. + ref_level_db (int): 20 Reference Db level to rebase the audio signal and ignore the level below. 20Db is assumed the sound of air. Defaults to 20. + do_sound_norm (bool): Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False. + + log_func (str): + Numpy log function used for amplitude to DB conversion. Defaults to 'np.log10'. + do_trim_silence (bool): Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```. + do_amp_to_db_linear (bool, optional): enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True. + do_amp_to_db_mel (bool, optional): enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + trim_db (int): Silence threshold used for silence trimming. Defaults to 45. + power (float): Exponent used for expanding spectrogra levels before running Griffin Lim. It helps to reduce the artifacts in the synthesized voice. Defaults to 1.5. + griffin_lim_iters (int): Number of Griffing Lim iterations. Defaults to 60. + num_mels (int): Number of mel-basis frames that defines the frame lengths of each mel-spectrogram frame. Defaults to 80. + mel_fmin (float): Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices. It needs to be adjusted for a dataset. Defaults to 0. + mel_fmax (float): Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset. + spec_gain (int): Gain applied when converting amplitude to DB. Defaults to 20. + signal_norm (bool): enable/disable signal normalization. Defaults to True. + min_level_db (int): minimum db threshold for the computed melspectrograms. Defaults to -100. + symmetric_norm (bool): enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to True. + max_norm (float): ```k``` defining the normalization range. Defaults to 4.0. + clip_norm (bool): enable/disable clipping the our of range values in the normalized audio signal. Defaults to True. + stats_path (str): Path to the computed stats file. Defaults to None. """ @@ -298,7 +327,7 @@ class BaseTrainingConfig(Coqpit): keep_all_best: bool = False keep_after: int = 10000 # dataloading - num_loader_workers: int = None + num_loader_workers: int = 0 num_eval_loader_workers: int = 0 use_noise_augment: bool = False # paths From db32162eaead6ae5f6ee9f34bafbdfebe83acb44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 12 Jul 2021 12:30:27 +0200 Subject: [PATCH 07/52] Fix `FastPitchLoss` --- TTS/tts/layers/losses.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index b4866df1..8a50c811 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -685,20 +685,20 @@ class FastPitchLoss(nn.Module): input_lens, ): - l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens) + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) - huber_loss = self.huber(dur_output, dur_target, input_lens) - pitch_loss = self.pitch_loss(pitch_output, pitch_target, input_lens) + dur_loss = self.dur_loss(dur_output[:, : ,None], dur_target[:, :, None], input_lens) + pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens) loss = ( - self.l1_alpha * l1_loss - + self.ssim_alpha * ssim_loss - + self.huber_alpha * huber_loss - + self.pitch_alpha * pitch_loss + self.spec_loss_alpha * spec_loss + + self.ssim_loss_alpha * ssim_loss + + self.dur_loss_alpha * dur_loss + + self.pitch_loss_alpha * pitch_loss ) return { "loss": loss, - "loss_l1": l1_loss, + "loss_spec": spec_loss, "loss_ssim": ssim_loss, - "loss_dur": huber_loss, + "loss_dur": dur_loss, "loss_pitch": pitch_loss, } From 7590c7db7acf884b16e17e17da6db162d71ea2b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 12 Jul 2021 12:31:05 +0200 Subject: [PATCH 08/52] Fix `base_tacotron` `aux_input` handling --- TTS/tts/models/base_tacotron.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py index 66842305..01291775 100644 --- a/TTS/tts/models/base_tacotron.py +++ b/TTS/tts/models/base_tacotron.py @@ -78,7 +78,9 @@ class BaseTacotron(BaseTTS): @staticmethod def _format_aux_input(aux_input: Dict) -> Dict: - return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input) + if aux_input: + return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input) + return None ############################# # INIT FUNCTIONS From d085642ac1b13d296731b1a84f6e3b06f4888f32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 12 Jul 2021 18:23:19 +0200 Subject: [PATCH 09/52] Cache pitch features Cache the features at the beginning of `BaseTTS` training. --- TTS/tts/datasets/TTSDataset.py | 95 +++++++++++++++++++++++++++++----- TTS/tts/models/base_tts.py | 11 ++++ TTS/utils/audio.py | 57 ++++++++++++++++---- 3 files changed, 140 insertions(+), 23 deletions(-) diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index 7ad52797..9b841034 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -9,7 +9,7 @@ import torch import tqdm from torch.utils.data import Dataset -from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor +from TTS.tts.utils.data import _pad_data, prepare_data, prepare_stop_target, prepare_tensor from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence from TTS.utils.audio import AudioProcessor @@ -23,6 +23,7 @@ class TTSDataset(Dataset): ap: AudioProcessor, meta_data: List[List], compute_f0: bool = False, + f0_cache_path: str = None, characters: Dict = None, custom_symbols: List = None, add_blank: bool = False, @@ -41,8 +42,7 @@ class TTSDataset(Dataset): ): """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. - If you need something different, you can either override or create a new class as the dataset is - initialized by the model. + If you need something different, you can inherit and override. Args: outputs_per_step (int): Number of time frames predicted per step. @@ -57,6 +57,8 @@ class TTSDataset(Dataset): compute_f0 (bool): compute f0 if True. Defaults to False. + f0_cache_path (str): Path to store f0 cache. Defaults to None. + characters (dict): `dict` of custom text characters used for converting texts to sequences. custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own @@ -81,8 +83,8 @@ class TTSDataset(Dataset): use_phonemes (bool): If true, input text converted to phonemes. Defaults to false. - phoneme_cache_path (str): Path to cache phoneme features. It writes computed phonemes to files to use in - the coming iterations. Defaults to None. + phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a + separate file. Defaults to None. phoneme_language (str): One the languages from supported by the phonemizer interface. Defaults to `en-us`. @@ -107,6 +109,7 @@ class TTSDataset(Dataset): self.compute_linear_spec = compute_linear_spec self.return_wav = return_wav self.compute_f0 = compute_f0 + self.f0_cache_path = f0_cache_path self.min_seq_len = min_seq_len self.max_seq_len = max_seq_len self.ap = ap @@ -123,6 +126,7 @@ class TTSDataset(Dataset): self.verbose = verbose self.input_seq_computed = False self.rescue_item_idx = 1 + self.pitch_computed = False if use_phonemes and not os.path.isdir(phoneme_cache_path): os.makedirs(phoneme_cache_path, exist_ok=True) if self.verbose: @@ -240,10 +244,14 @@ class TTSDataset(Dataset): # TODO: find a better fix return self.load_data(self.rescue_item_idx) + if self.compute_f0: + pitch = self._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path) + sample = { "raw_text": raw_text, "text": text, "wav": wav, + "pitch": pitch, "attn": attn, "item_idx": self.items[idx][1], "speaker_name": speaker_name, @@ -260,8 +268,8 @@ class TTSDataset(Dataset): return phonemes def compute_input_seq(self, num_workers=0): - """compute input sequences separately. Call it before - passing dataset to data loader.""" + """Compute the input sequences with multi-processing. + Call it before passing dataset to the data loader to cache the input sequences for faster data loading.""" if not self.use_phonemes: if self.verbose: print(" | > Computing input sequences ...") @@ -306,6 +314,64 @@ class TTSDataset(Dataset): for idx, p in enumerate(phonemes): self.items[idx][0] = p + @staticmethod + def create_pitch_file_path(wav_file, cache_path): + file_name = os.path.splitext(os.path.basename(wav_file))[0] + pitch_file = os.path.join(cache_path, file_name + "_pitch.npy") + return pitch_file + + @staticmethod + def _compute_and_save_pitch(ap, wav_file, pitch_file=None): + wav = ap.load_wav(wav_file) + pitch = ap.compute_f0(wav) + if pitch_file: + np.save(pitch_file, pitch) + return pitch + + @staticmethod + def _load_or_compute_pitch(ap, wav_file, cache_path): + """ + compute pitch and return a numpy array of pitch values + """ + pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path) + if not os.path.exists(pitch_file): + pitch = TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file) + else: + pitch = np.load(pitch_file) + return pitch + + @staticmethod + def _pitch_worker(args): + item = args[0] + ap = args[1] + cache_path = args[2] + _, wav_file, *_ = item + pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path) + if not os.path.exists(pitch_file): + TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file) + return True + return False + + def compute_pitch(self, cache_path, num_workers=0): + """Compute the input sequences with multi-processing. + Call it before passing dataset to the data loader to cache the input sequences for faster data loading.""" + if not os.path.exists(cache_path): + os.makedirs(cache_path, exist_ok=True) + + if self.verbose: + print(" | > Computing pitch features ...") + if num_workers == 0: + for idx, item in enumerate(tqdm.tqdm(self.items)): + self._pitch_worker([item, self.ap, cache_path]) + else: + with Pool(num_workers) as p: + _ = list( + tqdm.tqdm( + p.imap(TTSDataset._pitch_worker, [[item, self.ap, cache_path] for item in self.items]), + total=len(self.items), + ) + ) + def sort_and_filter_items(self, by_audio_len=False): r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length range. @@ -367,7 +433,7 @@ class TTSDataset(Dataset): r""" Perform preprocessing and create a final data batch: 1. Sort batch instances by text-length - 2. Convert Audio signal to Spectrograms. + 2. Convert Audio signal to features. 3. PAD sequences wrt r. 4. Load to Torch. """ @@ -466,11 +532,12 @@ class TTSDataset(Dataset): # TODO: compare perf in collate_fn vs in load_data pitch = None if self.compute_f0: - pitch = [self.ap.compute_f0(w).astype("float32") for w in wav] - pitch = prepare_tensor(pitch, self.outputs_per_step) - pitch = pitch.transpose(0, 2, 1) - assert mel.shape[1] == pitch.shape[1] - pitch = torch.FloatTensor(pitch).contiguous() + pitch = [b["pitch"] for b in batch] + pitch = prepare_data(pitch) + assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}" + pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT + else: + pitch = None # collate attention alignments if batch[0]["attn"] is not None: @@ -478,6 +545,7 @@ class TTSDataset(Dataset): for idx, attn in enumerate(attns): pad2 = mel.shape[1] - attn.shape[1] pad1 = text.shape[1] - attn.shape[0] + assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}" attn = np.pad(attn, [[0, pad1], [0, pad2]]) attns[idx] = attn attns = prepare_tensor(attns, self.outputs_per_step) @@ -499,6 +567,7 @@ class TTSDataset(Dataset): attns, wav_padded, raw_text, + pitch, ) raise TypeError( diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index d39473c7..3a6957f3 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -116,6 +116,7 @@ class BaseTTS(BaseModel): speaker_ids = batch[9] attn_mask = batch[10] waveform = batch[11] + pitch = batch[13] max_text_length = torch.max(text_lengths.float()) max_spec_length = torch.max(mel_lengths.float()) @@ -162,6 +163,7 @@ class BaseTTS(BaseModel): "max_spec_length": float(max_spec_length), "item_idx": item_idx, "waveform": waveform, + "pitch": pitch, } def get_data_loader( @@ -200,6 +202,7 @@ class BaseTTS(BaseModel): text_cleaner=config.text_cleaner, compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, comnpute_f0=config.get("compute_f0", False), + f0_cache_path=config.get("f0_cache_path", None), meta_data=data_items, ap=ap, characters=config.characters, @@ -246,6 +249,14 @@ class BaseTTS(BaseModel): # sort input sequences from short to long dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False)) + # compute pitch frames and write to files. + if config.compute_f0 and not os.path.exists(config.f0_cache_path) and rank in [None, 0]: + dataset.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers) + + # halt DDP processes for the main process to finish computing the F0 cache + if num_gpus > 1: + dist.barrier() + # sampler for DDP sampler = DistributedSampler(dataset) if num_gpus > 1 else None diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index e027b060..3d45b325 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -9,8 +9,7 @@ import torch from torch import nn from TTS.tts.utils.data import StandardScaler - -# import pyworld as pw +from TTS.utils.yin import compute_yin class TorchSTFT(nn.Module): # pylint: disable=abstract-method @@ -648,15 +647,53 @@ class AudioProcessor(object): # frame_period=1000 * self.hop_length / self.sample_rate, # ) # f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate) - # f0 = compute_yin(, self.sample_rate, self.hop_length, self.fft_size) - f0, _, _ = librosa.pyin( - x.astype(np.double), - fmin=65 if self.mel_fmin == 0 else self.mel_fmin, - fmax=self.mel_fmax, - frame_length=self.win_length, - sr=self.sample_rate, - fill_na=0.0, + f0, _, _, _ = compute_yin( + x, + self.sample_rate, + self.win_length, + self.hop_length, + 65 if self.mel_fmin == 0 else self.mel_fmin, + self.mel_fmax, ) + # import pyworld as pw + # f0, _ = pw.dio(x.astype(np.float64), self.sample_rate, + # frame_period=self.hop_length / self.sample_rate * 1000) + pad = int((self.win_length / self.hop_length) / 2) + f0 = [0.0] * pad + f0 + [0.0] * pad + f0 = np.array(f0, dtype=np.float32) + + # f01, _, _ = librosa.pyin( + # x, + # fmin=65 if self.mel_fmin == 0 else self.mel_fmin, + # fmax=self.mel_fmax, + # frame_length=self.win_length, + # sr=self.sample_rate, + # fill_na=0.0, + # ) + + # f02 = librosa.yin( + # x, + # fmin=65 if self.mel_fmin == 0 else self.mel_fmin, + # fmax=self.mel_fmax, + # frame_length=self.win_length, + # sr=self.sample_rate + # ) + + # spec = self.melspectrogram(x) + + # from matplotlib import pyplot as plt + # plt.figure() + # plt.plot(f0, linewidth=2.5, color='red') + # plt.plot(f01, linewidth=2.5, linestyle='-.') + # plt.plot(f02, linewidth=2.5) + # plt.xlabel('time') + # plt.ylabel('F0') + # plt.savefig('save_img.png') + + # # plt.figure() + # plt.imshow(spec, aspect="auto", origin="lower") + # plt.savefig('save_img2.png') + # breakpoint() return f0 ### Audio Processing ### From d63a6bb690ca2c912df03d76302193bf5d90f1b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 12 Jul 2021 18:26:25 +0200 Subject: [PATCH 10/52] Set BaseDatasetConfig for tests --- tests/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/__init__.py b/tests/__init__.py index c7930ef9..a7878132 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,6 @@ import os +from TTS.config import BaseDatasetConfig from TTS.utils.generic_utils import get_cuda @@ -30,3 +31,7 @@ def get_tests_output_path(): def run_cli(command): exit_status = os.system(command) assert exit_status == 0, f" [!] command `{command}` failed." + + +def get_test_data_config(): + return BaseDatasetConfig(name="ljspeech", path="tests/data/ljspeech/", meta_file_train="metadata.csv") From 8fffd4e8132e660a54d0cf6ac205c61c8dfc84bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 13 Jul 2021 10:59:05 +0200 Subject: [PATCH 11/52] Don't print computed phonemes It causes noise in logs --- TTS/tts/datasets/TTSDataset.py | 1 + TTS/tts/layers/losses.py | 2 +- TTS/tts/utils/text/__init__.py | 3 --- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index 9b841034..3533dede 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -244,6 +244,7 @@ class TTSDataset(Dataset): # TODO: find a better fix return self.load_data(self.rescue_item_idx) + pitch = None if self.compute_f0: pitch = self._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 8a50c811..71e7e4fc 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -687,7 +687,7 @@ class FastPitchLoss(nn.Module): spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) - dur_loss = self.dur_loss(dur_output[:, : ,None], dur_target[:, :, None], input_lens) + dur_loss = self.dur_loss(dur_output[:, :, None], dur_target[:, :, None], input_lens) pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens) loss = ( self.spec_loss_alpha * spec_loss diff --git a/TTS/tts/utils/text/__init__.py b/TTS/tts/utils/text/__init__.py index 20712f1d..66f518b4 100644 --- a/TTS/tts/utils/text/__init__.py +++ b/TTS/tts/utils/text/__init__.py @@ -45,12 +45,10 @@ def text2phone(text, language, use_espeak_phonemes=False): # TO REVIEW : How to have a good implementation for this? if language == "zh-CN": ph = chinese_text_to_phonemes(text) - print(" > Phonemes: {}".format(ph)) return ph if language == "ja-jp": ph = japanese_text_to_phonemes(text) - print(" > Phonemes: {}".format(ph)) return ph if gruut.is_language_supported(language): @@ -80,7 +78,6 @@ def text2phone(text, language, use_espeak_phonemes=False): # Fix a few phonemes ph = ph.translate(GRUUT_TRANS_TABLE) - return ph raise ValueError(f" [!] Language {language} is not supported for phonemization.") From e802b24ad0c5e4f10796ca2e68847c246fc9331a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 14 Jul 2021 14:33:57 +0200 Subject: [PATCH 12/52] Compute mean and std pitch --- TTS/tts/datasets/TTSDataset.py | 46 +++++++++++++++++++++++++++++----- TTS/tts/models/base_tts.py | 6 +++-- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index 3533dede..f6bd7038 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -127,6 +127,7 @@ class TTSDataset(Dataset): self.input_seq_computed = False self.rescue_item_idx = 1 self.pitch_computed = False + if use_phonemes and not os.path.isdir(phoneme_cache_path): os.makedirs(phoneme_cache_path, exist_ok=True) if self.verbose: @@ -247,6 +248,7 @@ class TTSDataset(Dataset): pitch = None if self.compute_f0: pitch = self._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path) + pitch = self.normalize_pitch(pitch) sample = { "raw_text": raw_text, @@ -315,6 +317,11 @@ class TTSDataset(Dataset): for idx, p in enumerate(phonemes): self.items[idx][0] = p + ################ + # Pitch Methods + ############### + # TODO: Refactor Pitch methods into a separate class + @staticmethod def create_pitch_file_path(wav_file, cache_path): file_name = os.path.splitext(os.path.basename(wav_file))[0] @@ -329,6 +336,19 @@ class TTSDataset(Dataset): np.save(pitch_file, pitch) return pitch + @staticmethod + def compute_pitch_stats(pitch_vecs): + nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs]) + mean, std = np.mean(nonzeros), np.std(nonzeros) + return mean, std + + def normalize_pitch(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch -= self.mean + pitch /= self.std + pitch[zero_idxs] = 0.0 + return pitch + @staticmethod def _load_or_compute_pitch(ap, wav_file, cache_path): """ @@ -349,9 +369,9 @@ class TTSDataset(Dataset): _, wav_file, *_ = item pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path) if not os.path.exists(pitch_file): - TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file) - return True - return False + pitch = TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file) + return pitch + return None def compute_pitch(self, cache_path, num_workers=0): """Compute the input sequences with multi-processing. @@ -362,16 +382,30 @@ class TTSDataset(Dataset): if self.verbose: print(" | > Computing pitch features ...") if num_workers == 0: - for idx, item in enumerate(tqdm.tqdm(self.items)): - self._pitch_worker([item, self.ap, cache_path]) + pitch_vecs = [] + for _, item in enumerate(tqdm.tqdm(self.items)): + pitch_vecs += [self._pitch_worker([item, self.ap, cache_path])] else: with Pool(num_workers) as p: - _ = list( + pitch_vecs = list( tqdm.tqdm( p.imap(TTSDataset._pitch_worker, [[item, self.ap, cache_path] for item in self.items]), total=len(self.items), ) ) + pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs) + pitch_stats = {"mean": pitch_mean, "std": pitch_std} + np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) + + def load_pitch_stats(self, cache_path): + stats_path = os.path.join(cache_path, "pitch_stats.npy") + stats = np.load(stats_path, allow_pickle=True).item() + self.mean = stats["mean"] + self.std = stats["std"] + + ################### + # End Pitch Methods + ################### def sort_and_filter_items(self, by_audio_len=False): r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 3a6957f3..9e0bf41e 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -250,8 +250,10 @@ class BaseTTS(BaseModel): dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False)) # compute pitch frames and write to files. - if config.compute_f0 and not os.path.exists(config.f0_cache_path) and rank in [None, 0]: - dataset.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers) + if config.compute_f0 and rank in [None, 0]: + if not os.path.exists(config.f0_cache_path): + dataset.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers) + dataset.load_pitch_stats(config.get("f0_cache_path", None)) # halt DDP processes for the main process to finish computing the F0 cache if num_gpus > 1: From c448571c3c2ec41a92f49dc94412dedbb247c57b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 14 Jul 2021 14:54:01 +0200 Subject: [PATCH 13/52] Add FastPitch LJSpeech recipe --- .../ljspeech/fast_pitch/train_fast_pitch.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 recipes/ljspeech/fast_pitch/train_fast_pitch.py diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py new file mode 100644 index 00000000..e3bd131e --- /dev/null +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -0,0 +1,45 @@ +import os + +from TTS.config import BaseAudioConfig, BaseDatasetConfig +from TTS.trainer import Trainer, TrainingArgs, init_training +from TTS.tts.configs import FastPitchConfig + +output_path = os.path.dirname(os.path.abspath(__file__)) +dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), path=os.path.join(output_path, "../LJSpeech-1.1/")) +audio_config = BaseAudioConfig( + sample_rate=22050, + do_trim_silence=False, + trim_db=0.0, + signal_norm=False, + mel_fmin=0.0, + mel_fmax=8000, + spec_gain=1.0, + log_func="np.log", + ref_level_db=20, + preemphasis=0.0, +) +config = FastPitchConfig( + run_name="fast_pitch_ljspeech", + audio=audio_config, + batch_size=32, + eval_batch_size=16, + num_loader_workers=8, + num_eval_loader_workers=4, + compute_f0=True, + f0_cache_path=os.path.join(output_path, "f0_cache"), + run_eval=True, + test_delay_epochs=-1, + epochs=1000, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + print_step=25, + print_eval=True, + mixed_precision=False, + output_path=output_path, + datasets=[dataset_config] +) +args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) +trainer = Trainer(args, config, output_path, c_logger, tb_logger) +trainer.fit() From 5a6ffaee08399b3d10dfd02d0dfff328b84475b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 14 Jul 2021 14:55:54 +0200 Subject: [PATCH 14/52] Add yin based pitch computation --- TTS/utils/yin.py | 118 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 TTS/utils/yin.py diff --git a/TTS/utils/yin.py b/TTS/utils/yin.py new file mode 100644 index 00000000..3d8bf64b --- /dev/null +++ b/TTS/utils/yin.py @@ -0,0 +1,118 @@ +# adapted from https://github.com/patriceguyot/Yin + +import numpy as np + + +def differenceFunction(x, N, tau_max): + """ + Compute difference function of data x. This corresponds to equation (6) in [1] + This solution is implemented directly with Numpy fft. + + + :param x: audio data + :param N: length of data + :param tau_max: integration window size + :return: difference function + :rtype: list + """ + + x = np.array(x, np.float64) + w = x.size + tau_max = min(tau_max, w) + x_cumsum = np.concatenate((np.array([0.0]), (x * x).cumsum())) + size = w + tau_max + p2 = (size // 32).bit_length() + nice_numbers = (16, 18, 20, 24, 25, 27, 30, 32) + size_pad = min(x * 2 ** p2 for x in nice_numbers if x * 2 ** p2 >= size) + fc = np.fft.rfft(x, size_pad) + conv = np.fft.irfft(fc * fc.conjugate())[:tau_max] + return x_cumsum[w : w - tau_max : -1] + x_cumsum[w] - x_cumsum[:tau_max] - 2 * conv + + +def cumulativeMeanNormalizedDifferenceFunction(df, N): + """ + Compute cumulative mean normalized difference function (CMND). + + This corresponds to equation (8) in [1] + + :param df: Difference function + :param N: length of data + :return: cumulative mean normalized difference function + :rtype: list + """ + + cmndf = df[1:] * range(1, N) / np.cumsum(df[1:]).astype(float) # scipy method + return np.insert(cmndf, 0, 1) + + +def getPitch(cmdf, tau_min, tau_max, harmo_th=0.1): + """ + Return fundamental period of a frame based on CMND function. + + :param cmdf: Cumulative Mean Normalized Difference function + :param tau_min: minimum period for speech + :param tau_max: maximum period for speech + :param harmo_th: harmonicity threshold to determine if it is necessary to compute pitch frequency + :return: fundamental period if there is values under threshold, 0 otherwise + :rtype: float + """ + tau = tau_min + while tau < tau_max: + if cmdf[tau] < harmo_th: + while tau + 1 < tau_max and cmdf[tau + 1] < cmdf[tau]: + tau += 1 + return tau + tau += 1 + + return 0 # if unvoiced + + +def compute_yin(sig, sr, w_len=512, w_step=256, f0_min=100, f0_max=500, harmo_thresh=0.1): + """ + + Compute the Yin Algorithm. Return fundamental frequency and harmonic rate. + + :param sig: Audio signal (list of float) + :param sr: sampling rate (int) + :param w_len: size of the analysis window (samples) + :param w_step: size of the lag between two consecutives windows (samples) + :param f0_min: Minimum fundamental frequency that can be detected (hertz) + :param f0_max: Maximum fundamental frequency that can be detected (hertz) + :param harmo_tresh: Threshold of detection. The yalgorithmù return the first minimum of the CMND function below this treshold. + + :returns: + + * pitches: list of fundamental frequencies, + * harmonic_rates: list of harmonic rate values for each fundamental frequency value (= confidence value) + * argmins: minimums of the Cumulative Mean Normalized DifferenceFunction + * times: list of time of each estimation + :rtype: tuple + """ + + tau_min = int(sr / f0_max) + tau_max = int(sr / f0_min) + + timeScale = range(0, len(sig) - w_len, w_step) # time values for each analysis window + times = [t / float(sr) for t in timeScale] + frames = [sig[t : t + w_len] for t in timeScale] + + pitches = [0.0] * len(timeScale) + harmonic_rates = [0.0] * len(timeScale) + argmins = [0.0] * len(timeScale) + + for i, frame in enumerate(frames): + # Compute YIN + df = differenceFunction(frame, w_len, tau_max) + cmdf = cumulativeMeanNormalizedDifferenceFunction(df, tau_max) + p = getPitch(cmdf, tau_min, tau_max, harmo_thresh) + + # Get results + if np.argmin(cmdf) > tau_min: + argmins[i] = float(sr / np.argmin(cmdf)) + if p != 0: # A pitch was found + pitches[i] = float(sr / p) + harmonic_rates[i] = cmdf[p] + else: # No pitch, but we compute a value of the harmonic rate + harmonic_rates[i] = min(cmdf) + + return pitches, harmonic_rates, argmins, times From bc396c393f32cb07874bb573bca4d2244bc80696 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 14 Jul 2021 15:04:07 +0200 Subject: [PATCH 15/52] Add FastPitch model and FastPitchconfig --- TTS/tts/configs/fast_pitch_config.py | 98 +++++++ TTS/tts/models/fast_pitch.py | 377 +++++++++++++++++++++++++++ 2 files changed, 475 insertions(+) create mode 100644 TTS/tts/configs/fast_pitch_config.py create mode 100644 TTS/tts/models/fast_pitch.py diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py new file mode 100644 index 00000000..88bbd192 --- /dev/null +++ b/TTS/tts/configs/fast_pitch_config.py @@ -0,0 +1,98 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.fast_pitch import FastPitchArgs + + +@dataclass +class FastPitchConfig(BaseTTSConfig): + """Defines parameters for Speedy Speech (feed-forward encoder-decoder) based models. + + Example: + + >>> from TTS.tts.configs import FastPitchConfig + >>> config = FastPitchConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + model_args (Coqpit): + Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + noam_schedule (bool): + enable / disable the use of Noam LR scheduler. Defaults to False. + warmup_steps (int): + Number of warm-up steps for the Noam scheduler. Defaults 4000. + lr (float): + Initial learning rate. Defaults to `1e-3`. + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + huber_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + pitch_loss_alpha (float): + Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "fast_pitch" + # model specific params + model_args: FastPitchArgs = field(default_factory=FastPitchArgs) + + # multi-speaker settings + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "RAdam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + ssim_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = True + f0_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py new file mode 100644 index 00000000..9b826c3f --- /dev/null +++ b/TTS/tts/models/fast_pitch.py @@ -0,0 +1,377 @@ +from dataclasses import dataclass, field + +import torch +import torch.nn.functional as F +from coqpit import Coqpit +from torch import nn + +from TTS.tts.layers.feed_forward.decoder import Decoder +from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.layers.glow_tts.monotonic_align import generate_path +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.data import sequence_mask +from TTS.tts.utils.measures import alignment_diagonal_score +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.audio import AudioProcessor + + +@dataclass +class FastPitchArgs(Coqpit): + num_chars: int = None + out_channels: int = 80 + hidden_channels: int = 256 + num_speakers: int = 0 + duration_predictor_hidden_channels: int = 256 + duration_predictor_dropout: float = 0.1 + duration_predictor_kernel_size: int = 3 + duration_predictor_dropout_p: float = 0.1 + pitch_predictor_hidden_channels: int = 256 + pitch_predictor_dropout: float = 0.1 + pitch_predictor_kernel_size: int = 3 + pitch_predictor_dropout_p: float = 0.1 + pitch_embedding_kernel_size: int = 3 + positional_encoding: bool = True + length_scale: int = 1 + encoder_type: str = "fftransformer" + encoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + decoder_type: str = "fftransformer" + decoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + use_d_vector: bool = False + d_vector_dim: int = 0 + + +class FastPitch(BaseTTS): + """FastPitch model. Very similart to SpeedySpeech model but with pitch prediction. + + Paper abstract: + We present FastPitch, a fully-parallel text-to-speech model based on FastSpeech, conditioned on fundamental + frequency contours. The model predicts pitch contours during inference. By altering these predictions, + the generated speech can be more expressive, better match the semantic of the utterance, and in the end + more engaging to the listener. Uniformly increasing or decreasing pitch with FastPitch generates speech + that resembles the voluntary modulation of voice. Conditioning on frequency contours improves the overall + quality of synthesized speech, making it comparable to state-of-the-art. It does not introduce an overhead, + and FastPitch retains the favorable, fully-parallel Transformer architecture, with over 900x real-time + factor for mel-spectrogram synthesis of a typical utterance." + + Notes: + TODO + + Args: + config (Coqpit): Model coqpit class. + + Examples: + >>> from TTS.tts.models.fast_pitch import FastPitch, FastPitchArgs + >>> config = FastPitchArgs() + >>> model = FastPitch(config) + """ + + # pylint: disable=dangerous-default-value + def __init__(self, config: Coqpit): + + super().__init__() + + _, self.config, num_chars = self.get_characters(config) + config.model_args.num_chars = num_chars + + self.length_scale = ( + float(config.model_args.length_scale) + if isinstance(config.model_args.length_scale, int) + else config.model_args.length_scale + ) + + self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels) + + self.encoder = Encoder( + config.model_args.hidden_channels, + config.model_args.hidden_channels, + config.model_args.encoder_type, + config.model_args.encoder_params, + config.model_args.d_vector_dim, + ) + + if config.model_args.positional_encoding: + self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels) + + self.decoder = Decoder( + config.model_args.out_channels, + config.model_args.hidden_channels, + config.model_args.decoder_type, + config.model_args.decoder_params, + ) + + self.duration_predictor = DurationPredictor( + config.model_args.hidden_channels + config.model_args.d_vector_dim, + config.model_args.duration_predictor_hidden_channels, + config.model_args.duration_predictor_kernel_size, + config.model_args.duration_predictor_dropout_p, + ) + + self.pitch_predictor = DurationPredictor( + config.model_args.hidden_channels + config.model_args.d_vector_dim, + config.model_args.pitch_predictor_hidden_channels, + config.model_args.pitch_predictor_kernel_size, + config.model_args.pitch_predictor_dropout_p, + ) + + self.pitch_emb = nn.Conv1d( + 1, + config.model_args.hidden_channels, + kernel_size=config.model_args.pitch_embedding_kernel_size, + padding=int((config.model_args.pitch_embedding_kernel_size - 1) / 2), + ) + + self.register_buffer("pitch_mean", torch.zeros(1)) + self.register_buffer("pitch_std", torch.zeros(1)) + + if config.model_args.num_speakers > 1 and not config.model_args.use_d_vector: + # speaker embedding layer + self.emb_g = nn.Embedding(config.model_args.num_speakers, config.model_args.d_vector_dim) + nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + + if config.model_args.d_vector_dim > 0 and config.model_args.d_vector_dim != config.model_args.hidden_channels: + self.proj_g = nn.Conv1d(config.model_args.d_vector_dim, config.model_args.hidden_channels, 1) + + @staticmethod + def expand_encoder_outputs(en, dr, x_mask, y_mask): + """Generate attention alignment map from durations and + expand encoder outputs + + Example: + encoder output: [a,b,c,d] + durations: [1, 3, 2, 1] + + expanded: [a, b, b, b, c, c, d] + attention map: [[0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 1, 0], + [0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0]] + """ + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) + return o_en_ex, attn + + def format_durations(self, o_dr_log, x_mask): + o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale + o_dr[o_dr < 1] = 1.0 + o_dr = torch.round(o_dr) + return o_dr + + @staticmethod + def _concat_speaker_embedding(o_en, g): + g_exp = g.expand(-1, -1, o_en.size(-1)) # [B, C, T_en] + o_en = torch.cat([o_en, g_exp], 1) + return o_en + + def _sum_speaker_embedding(self, x, g): + # project g to decoder dim. + if hasattr(self, "proj_g"): + g = self.proj_g(g) + return x + g + + def _forward_encoder(self, x, x_lengths, g=None): + if hasattr(self, "emb_g"): + g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1] + + if g is not None: + g = g.unsqueeze(-1) + + # [B, T, C] + x_emb = self.emb(x) + # [B, C, T] + x_emb = torch.transpose(x_emb, 1, -1) + + # compute sequence masks + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) + + # encoder pass + o_en = self.encoder(x_emb, x_mask) + + # speaker conditioning for duration predictor + if g is not None: + o_en_dp = self._concat_speaker_embedding(o_en, g) + else: + o_en_dp = o_en + return o_en, o_en_dp, x_mask, g + + def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) + # expand o_en with durations + o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + # positional encoding + if hasattr(self, "pos_encoder"): + o_en_ex = self.pos_encoder(o_en_ex, y_mask) + # speaker embedding + if g is not None: + o_en_ex = self._sum_speaker_embedding(o_en_ex, g) + # decoder pass + o_de = self.decoder(o_en_ex, y_mask, g=g) + return o_de, attn.transpose(1, 2) + + def _forward_pitch_predictor(self, o_en, x_mask, pitch=None, dr=None): + o_pitch = self.pitch_predictor(o_en, x_mask) + if pitch is not None: + avg_pitch = average_pitch(pitch, dr) + o_pitch_emb = self.pitch_emb(avg_pitch) + return o_pitch_emb, o_pitch, avg_pitch + o_pitch_emb = self.pitch_emb(o_pitch) + return o_pitch_emb, o_pitch + + def forward( + self, x, x_lengths, y_lengths, dr, pitch, aux_input={"d_vectors": None, "speaker_ids": None} + ): # pylint: disable=unused-argument + """ + Shapes: + x: :math:`[B, T_max]` + x_lengths: :math:`[B]` + y_lengths: :math:`[B]` + dr: :math:`[B, T_max]` + g: :math:`[B, C]` + pitch: :math:`[B, 1, T]` + """ + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr) + o_en = o_en + o_pitch_emb + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g) + outputs = { + "model_outputs": o_de.transpose(1, 2), + "durations_log": o_dr_log.squeeze(1), + "pitch": o_pitch, + "pitch_gt": avg_pitch, + "alignments": attn, + } + return outputs + + def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument + """ + Shapes: + x: [B, T_max] + x_lengths: [B] + g: [B, C] + """ + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + # input sequence should be greated than the max convolution size + inference_padding = 5 + if x.shape[1] < 13: + inference_padding += 13 - x.shape[1] + # pad input to prevent dropping the last word + x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0) + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + # duration predictor pass + o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + # pitch predictor pass + o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask) + # if pitch_transform is not None: + # if self.pitch_std[0] == 0.0: + # # XXX LJSpeech-1.1 defaults + # mean, std = 218.14, 67.24 + # else: + # mean, std = self.pitch_mean[0], self.pitch_std[0] + # pitch_pred = pitch_transform(pitch_pred, enc_mask.sum(dim=(1,2)), mean, std) + + # if pitch_tgt is None: + # pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2) + # else: + # pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2) + o_en = o_en + o_pitch_emb + y_lengths = o_dr.sum(1) + o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) + outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn, "pitch": o_pitch, "durations_log": None} + return outputs + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + pitch = batch["pitch"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + durations = batch["durations"] + + aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} + outputs = self.forward(text_input, text_lengths, mel_lengths, durations, pitch, aux_input) + + # compute loss + loss_dict = criterion( + outputs["model_outputs"], + mel_input, + mel_lengths, + outputs["durations_log"], + torch.log(1 + durations), + outputs["pitch"], + outputs["pitch_gt"], + text_lengths, + ) + + # compute alignment error (the lower the better ) + align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True) + loss_dict["align_error"] = align_error + return outputs, loss_dict + + def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use + model_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict): + return self.train_log(ap, batch, outputs) + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_criterion(self): + from TTS.tts.layers.losses import FastPitchLoss # pylint: disable=import-outside-toplevel + + return FastPitchLoss(self.config) + + +def average_pitch(pitch, durs): + durs_cums_ends = torch.cumsum(durs, dim=1).long() + durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0)) + pitch_nonzero_cums = torch.nn.functional.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0)) + pitch_cums = torch.nn.functional.pad(torch.cumsum(pitch, dim=2), (1, 0)) + + bs, l = durs_cums_ends.size() + n_formants = pitch.size(1) + dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, l) + dce = durs_cums_ends[:, None, :].expand(bs, n_formants, l) + + pitch_sums = (torch.gather(pitch_cums, 2, dce) - torch.gather(pitch_cums, 2, dcs)).float() + pitch_nelems = (torch.gather(pitch_nonzero_cums, 2, dce) - torch.gather(pitch_nonzero_cums, 2, dcs)).float() + + pitch_avg = torch.where(pitch_nelems == 0.0, pitch_nelems, pitch_sums / pitch_nelems) + return pitch_avg From 545a00fc040d952b076ac8fcea89ab1ba8e4fd28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 16 Jul 2021 12:13:33 +0200 Subject: [PATCH 16/52] Use absolute paths of the attention masks --- TTS/bin/compute_attention_masks.py | 6 ++++-- TTS/tts/datasets/__init__.py | 19 +++++++++++++++---- .../ljspeech/fast_pitch/train_fast_pitch.py | 11 +++++++++++ 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index 7de3989d..fc8c6629 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -148,10 +148,12 @@ Example run: alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy() # set file paths wav_file_name = os.path.basename(item_idx) - align_file_name = os.path.splitext(wav_file_name)[0] + ".npy" + align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy" file_path = item_idx.replace(wav_file_name, align_file_name) # save output - file_paths.append([item_idx, file_path]) + wav_file_abs_path = os.path.abspath(item_idx) + file_abs_path = os.path.abspath(file_path) + file_paths.append([wav_file_abs_path, file_abs_path]) np.save(file_path, alignment) # ourput metafile diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index a2520751..2e315963 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -1,6 +1,7 @@ import sys from collections import Counter from pathlib import Path +from typing import Dict, List, Tuple import numpy as np @@ -30,7 +31,17 @@ def split_dataset(items): return items[:eval_split_size], items[eval_split_size:] -def load_meta_data(datasets, eval_split=True): +def load_meta_data(datasets: List[Dict], eval_split=True) -> Tuple[List[List], List[List]]: + """Parse the dataset, load the samples as a list and load the attention alignments if provided. + + Args: + datasets (List[Dict]): A list of dataset dictionaries or dataset configs. + eval_split (bool, optional): If true, create a evaluation split. If an eval split provided explicitly, generate + an eval split automatically. Defaults to True. + + Returns: + Tuple[List[List], List[List]: training and evaluation splits of the dataset. + """ meta_data_train_all = [] meta_data_eval_all = [] if eval_split else None for dataset in datasets: @@ -51,15 +62,15 @@ def load_meta_data(datasets, eval_split=True): meta_data_eval, meta_data_train = split_dataset(meta_data_train) meta_data_eval_all += meta_data_eval meta_data_train_all += meta_data_train - # load attention masks for duration predictor training + # load attention masks for the duration predictor training if dataset.meta_file_attn_mask: meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) for idx, ins in enumerate(meta_data_train_all): - attn_file = meta_data[ins[1]].strip() + attn_file = meta_data[os.path.abspath(ins[1])].strip() meta_data_train_all[idx].append(attn_file) if meta_data_eval_all: for idx, ins in enumerate(meta_data_eval_all): - attn_file = meta_data[ins[1]].strip() + attn_file = meta_data[os.path.abspath(ins[1])].strip() meta_data_eval_all[idx].append(attn_file) return meta_data_train_all, meta_data_eval_all diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index e3bd131e..91fe4bd2 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -3,8 +3,11 @@ import os from TTS.config import BaseAudioConfig, BaseDatasetConfig from TTS.trainer import Trainer, TrainingArgs, init_training from TTS.tts.configs import FastPitchConfig +from TTS.utils.manage import ModelManager output_path = os.path.dirname(os.path.abspath(__file__)) + +# init configs dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), path=os.path.join(output_path, "../LJSpeech-1.1/")) audio_config = BaseAudioConfig( sample_rate=22050, @@ -40,6 +43,14 @@ config = FastPitchConfig( output_path=output_path, datasets=[dataset_config] ) + +# compute alignments +manager = ModelManager() +model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA") +# TODO: make compute_attention python callable +os.system(f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true") + +# train the model args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) trainer = Trainer(args, config, output_path, c_logger, tb_logger) trainer.fit() From aacbb3ed77b1320a8fd47468c3809c67dc5632a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 16 Jul 2021 12:14:25 +0200 Subject: [PATCH 17/52] Fix SpeakerManager usage in `synthesize.py` --- TTS/utils/synthesizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 531523a4..236e78a9 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -232,7 +232,7 @@ class Synthesizer(object): # compute a new d_vector from the given clip. if speaker_wav is not None: - speaker_embedding = self.speaker_manager.compute_d_vector_from_clip(speaker_wav) + speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav) use_gl = self.vocoder_model is None From 9af42f78865ec2757dbad9db1386d5540c970244 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 17 Jul 2021 01:02:47 +0200 Subject: [PATCH 18/52] Restore `last_epoch` of the scheduler --- TTS/trainer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/TTS/trainer.py b/TTS/trainer.py index 68b45fe2..0b4ad308 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -268,9 +268,13 @@ class Trainer: self.config, args.restore_path, self.model, self.optimizer, self.scaler ) - # setup scheduler + + # setup scheduler self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer) + if self.args.continue_path: + self.scheduler.last_epoch = self.restore_step + # DISTRUBUTED if self.num_gpus > 1: self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank) @@ -291,7 +295,6 @@ class Trainer: Returns: nn.Module: initialized model. """ - # TODO: better model setup try: model = setup_vocoder_model(config) except ModuleNotFoundError: From b7caad39e09b664101a584d90ac1b0f31a6a59e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 20 Jul 2021 14:47:12 +0200 Subject: [PATCH 19/52] Make optional to detach duration predictor input --- TTS/tts/models/fast_pitch.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py index 9b826c3f..b6c0e60f 100644 --- a/TTS/tts/models/fast_pitch.py +++ b/TTS/tts/models/fast_pitch.py @@ -44,6 +44,7 @@ class FastPitchArgs(Coqpit): ) use_d_vector: bool = False d_vector_dim: int = 0 + detach_duration_predictor: bool = False class FastPitch(BaseTTS): @@ -237,7 +238,10 @@ class FastPitch(BaseTTS): """ g = aux_input["d_vectors"] if "d_vectors" in aux_input else None o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) - o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + if self.config.model_args.detach_duration_predictor: + o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + else: + o_dr_log = self.duration_predictor(o_en_dp, x_mask) o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr) o_en = o_en + o_pitch_emb o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g) @@ -250,6 +254,7 @@ class FastPitch(BaseTTS): } return outputs + @torch.no_grad() def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument """ Shapes: @@ -267,7 +272,7 @@ class FastPitch(BaseTTS): x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0) o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) # duration predictor pass - o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + o_dr_log = self.duration_predictor(o_en_dp, x_mask) o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) # pitch predictor pass o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask) From 8584f2b82d78c385c70ad52c7ebc754817fbf3c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 22 Jul 2021 14:19:10 +0200 Subject: [PATCH 20/52] Update docstring format --- TTS/config/shared_configs.py | 5 +++++ TTS/trainer.py | 3 +-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py index 0de3795c..d91bf2b6 100644 --- a/TTS/config/shared_configs.py +++ b/TTS/config/shared_configs.py @@ -176,15 +176,20 @@ class BaseDatasetConfig(Coqpit): Args: name (str): Dataset name that defines the preprocessor in use. Defaults to None. + path (str): Root path to the dataset files. Defaults to None. + meta_file_train (str): Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets. Defaults to None. + unused_speakers (List): List of speakers IDs that are not used at the training. Default None. + meta_file_val (str): Name of the dataset meta file that defines the instances used at validation. + meta_file_attn_mask (str): Path to the file that lists the attention mask files used with models that require attention masks to train the duration predictor. diff --git a/TTS/trainer.py b/TTS/trainer.py index 0b4ad308..9bb5b096 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -268,8 +268,7 @@ class Trainer: self.config, args.restore_path, self.model, self.optimizer, self.scaler ) - - # setup scheduler + # setup scheduler self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer) if self.args.continue_path: From 7692bfe7f85ca9bfdb1a4d019a4ae6846ec249dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 22 Jul 2021 14:20:11 +0200 Subject: [PATCH 21/52] Update FastPitch config --- TTS/tts/configs/fast_pitch_config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py index 88bbd192..d02e54c9 100644 --- a/TTS/tts/configs/fast_pitch_config.py +++ b/TTS/tts/configs/fast_pitch_config.py @@ -63,15 +63,15 @@ class FastPitchConfig(BaseTTSConfig): d_vector_dim: int = 0 # optimizer parameters - optimizer: str = "RAdam" - optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.98], "weight_decay": 1e-6}) lr_scheduler: str = "NoamLR" lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) lr: float = 1e-4 - grad_clip: float = 5.0 + grad_clip: float = 1000.0 # loss params - ssim_loss_alpha: float = 1.0 + ssim_loss_alpha: float = 0.0 dur_loss_alpha: float = 1.0 spec_loss_alpha: float = 1.0 pitch_loss_alpha: float = 1.0 From 57b3aec1b9b65cd0989009ecbc74ef314e8e60e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 22 Jul 2021 14:20:25 +0200 Subject: [PATCH 22/52] Update docstring format --- TTS/tts/configs/shared_configs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index 8511b1bc..52e337f9 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -103,6 +103,7 @@ class BaseTTSConfig(BaseTrainingConfig): """Shared parameters among all the tts models. Args: + audio (BaseAudioConfig): Audio processor config object instance. From b81560607b0d8f761df46b0f87dc9a496811dee8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 22 Jul 2021 14:20:43 +0200 Subject: [PATCH 23/52] Update docstrings --- TTS/tts/layers/glow_tts/monotonic_align/__init__.py | 6 ++++-- TTS/tts/layers/losses.py | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/TTS/tts/layers/glow_tts/monotonic_align/__init__.py b/TTS/tts/layers/glow_tts/monotonic_align/__init__.py index 5cbfd8fc..7757ecf8 100644 --- a/TTS/tts/layers/glow_tts/monotonic_align/__init__.py +++ b/TTS/tts/layers/glow_tts/monotonic_align/__init__.py @@ -21,8 +21,10 @@ def convert_pad_shape(pad_shape): def generate_path(duration, mask): """ - duration: [b, t_x] - mask: [b, t_x, t_y] + Shapes: + - duration: :math:`[B, T_en]` + - mask: :math:'[B, T_en, T_de]` + - path: :math:`[B, T_en, T_de]` """ device = duration.device b, t_x, t_y = mask.shape diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 71e7e4fc..efe64e2b 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -69,9 +69,9 @@ class MSELossMasked(nn.Module): length: A Variable containing a LongTensor of size (batch,) which contains the length of each data in a batch. Shapes: - x: B x T X D - target: B x T x D - length: B + - x: :math:`[B, T, D]` + - target: :math:`[B, T, D]` + - length: :math:`B` Returns: loss: An average loss value in range [0, 1] masked by the length. """ From fac9dbe6619f87444a540b943b650ce96b7fab6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 22 Jul 2021 14:20:54 +0200 Subject: [PATCH 24/52] Update FastPitchLoss --- TTS/tts/layers/losses.py | 41 +++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index efe64e2b..fefaec9a 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -684,21 +684,28 @@ class FastPitchLoss(nn.Module): pitch_target, input_lens, ): + loss = 0 + return_dict = {} + if self.ssim_loss_alpha > 0: + ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) + loss += self.ssim_loss_alpha * ssim_loss + return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss - spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) - ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) - dur_loss = self.dur_loss(dur_output[:, :, None], dur_target[:, :, None], input_lens) - pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens) - loss = ( - self.spec_loss_alpha * spec_loss - + self.ssim_loss_alpha * ssim_loss - + self.dur_loss_alpha * dur_loss - + self.pitch_loss_alpha * pitch_loss - ) - return { - "loss": loss, - "loss_spec": spec_loss, - "loss_ssim": ssim_loss, - "loss_dur": dur_loss, - "loss_pitch": pitch_loss, - } + if self.spec_loss_alpha > 0: + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) + loss += self.spec_loss_alpha * spec_loss + return_dict["loss_spec"] = self.spec_loss_alpha * spec_loss + + if self.dur_loss_alpha > 0: + log_dur_tgt = torch.log(dur_target.float() + 1) + dur_loss = self.dur_loss(dur_output[:, :, None], log_dur_tgt[:, :, None], input_lens) + loss += self.dur_loss_alpha * dur_loss + return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss + + if self.pitch_loss_alpha > 0: + pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens) + loss += self.pitch_loss_alpha * pitch_loss + return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss + + return_dict["loss"] = loss + return return_dict From 5d59100a883def5563397a78bdc60a3938cb8f72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 22 Jul 2021 14:21:49 +0200 Subject: [PATCH 25/52] Don't use align_score for models with duration predictor --- TTS/tts/models/align_tts.py | 4 ---- TTS/tts/models/glow_tts.py | 4 ---- 2 files changed, 8 deletions(-) diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index 2aa84cb2..2c3bed3d 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -13,7 +13,6 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask -from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_fsspec @@ -355,9 +354,6 @@ class AlignTTS(BaseTTS): phase=self.phase, ) - # compute alignment error (the lower the better ) - align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True) - loss_dict["align_error"] = align_error return outputs, loss_dict def train_log( diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 92c42fa7..e6541871 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -10,7 +10,6 @@ from TTS.tts.layers.glow_tts.encoder import Encoder from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask -from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.speakers import get_speaker_manager from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -341,9 +340,6 @@ class GlowTTS(BaseTTS): text_lengths, ) - # compute alignment error (the lower the better ) - align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True) - loss_dict["align_error"] = align_error return outputs, loss_dict def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use From 42862f7fdb39d228dd863a458b0a41b056d5fc94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 22 Jul 2021 14:24:04 +0200 Subject: [PATCH 26/52] Format style of the recipes --- TTS/tts/utils/measures.py | 2 +- TTS/utils/audio.py | 1 - recipes/ljspeech/fast_pitch/train_fast_pitch.py | 7 ++++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/TTS/tts/utils/measures.py b/TTS/tts/utils/measures.py index fdd31242..90e862e1 100644 --- a/TTS/tts/utils/measures.py +++ b/TTS/tts/utils/measures.py @@ -7,7 +7,7 @@ def alignment_diagonal_score(alignments, binary=False): binary (bool): if True, ignore scores and consider attention as a binary mask. Shape: - alignments : batch x decoder_steps x encoder_steps + - alignments : :math:`[B, T_de, T_en]` """ maxs = alignments.max(dim=1)[0] if binary: diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index 3d45b325..96b9a1a1 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -693,7 +693,6 @@ class AudioProcessor(object): # # plt.figure() # plt.imshow(spec, aspect="auto", origin="lower") # plt.savefig('save_img2.png') - # breakpoint() return f0 ### Audio Processing ### diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index 91fe4bd2..5bc5f448 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -8,7 +8,12 @@ from TTS.utils.manage import ModelManager output_path = os.path.dirname(os.path.abspath(__file__)) # init configs -dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), path=os.path.join(output_path, "../LJSpeech-1.1/")) +dataset_config = BaseDatasetConfig( + name="ljspeech", + meta_file_train="metadata.csv", + meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), + path=os.path.join(output_path, "../LJSpeech-1.1/"), +) audio_config = BaseAudioConfig( sample_rate=22050, do_trim_silence=False, From ca29033ef4ad15090f6337af4546b679ab3532e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 22 Jul 2021 14:24:34 +0200 Subject: [PATCH 27/52] Refactor FastPitch model --- TTS/tts/models/fast_pitch.py | 630 +++++++++++++++++++++++++---------- 1 file changed, 462 insertions(+), 168 deletions(-) diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py index b6c0e60f..989866ae 100644 --- a/TTS/tts/models/fast_pitch.py +++ b/TTS/tts/models/fast_pitch.py @@ -1,50 +1,384 @@ from dataclasses import dataclass, field import torch +import torch.nn as nn import torch.nn.functional as F from coqpit import Coqpit -from torch import nn -from TTS.tts.layers.feed_forward.decoder import Decoder -from TTS.tts.layers.feed_forward.encoder import Encoder -from TTS.tts.layers.generic.pos_encoding import PositionalEncoding -from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.layers.glow_tts.monotonic_align import generate_path from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask -from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor +# pylint: disable=dangerous-default-value + + +def mask_from_lens(lens, max_len: int = None): + if max_len is None: + max_len = lens.max() + ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype) + mask = torch.lt(ids, lens.unsqueeze(1)) + return mask + + +class LinearNorm(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"): + super(LinearNorm, self).__init__() + self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + + torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, x): + return self.linear_layer(x) + + +class ConvNorm(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=None, + dilation=1, + bias=True, + w_init_gain="linear", + batch_norm=False, + ): + super(ConvNorm, self).__init__() + if padding is None: + assert kernel_size % 2 == 1 + padding = int(dilation * (kernel_size - 1) / 2) + + self.conv = torch.nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + self.norm = torch.nn.BatchNorm1D(out_channels) if batch_norm else None + + torch.nn.init.xavier_uniform_(self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, signal): + if self.norm is None: + return self.conv(signal) + else: + return self.norm(self.conv(signal)) + + +class ConvReLUNorm(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, dropout=0.0): + super(ConvReLUNorm, self).__init__() + self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=(kernel_size // 2)) + self.norm = torch.nn.LayerNorm(out_channels) + self.dropout = torch.nn.Dropout(dropout) + + def forward(self, signal): + out = F.relu(self.conv(signal)) + out = self.norm(out.transpose(1, 2)).transpose(1, 2).to(signal.dtype) + return self.dropout(out) + + +class PositionalEmbedding(nn.Module): + def __init__(self, demb): + super(PositionalEmbedding, self).__init__() + self.demb = demb + inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, pos_seq, bsz=None): + sinusoid_inp = torch.matmul(torch.unsqueeze(pos_seq, -1), torch.unsqueeze(self.inv_freq, 0)) + pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1) + if bsz is not None: + return pos_emb[None, :, :].expand(bsz, -1, -1) + else: + return pos_emb[None, :, :] + + +class PositionwiseConvFF(nn.Module): + def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False): + super(PositionwiseConvFF, self).__init__() + + self.d_model = d_model + self.d_inner = d_inner + self.dropout = dropout + + self.CoreNet = nn.Sequential( + nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)), + nn.ReLU(), + # nn.Dropout(dropout), # worse convergence + nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)), + nn.Dropout(dropout), + ) + self.layer_norm = nn.LayerNorm(d_model) + self.pre_lnorm = pre_lnorm + + def forward(self, inp): + return self._forward(inp) + + def _forward(self, inp): + if self.pre_lnorm: + # layer normalization + positionwise feed-forward + core_out = inp.transpose(1, 2) + core_out = self.CoreNet(self.layer_norm(core_out).to(inp.dtype)) + core_out = core_out.transpose(1, 2) + + # residual connection + output = core_out + inp + else: + # positionwise feed-forward + core_out = inp.transpose(1, 2) + core_out = self.CoreNet(core_out) + core_out = core_out.transpose(1, 2) + + # residual connection + layer normalization + output = self.layer_norm(inp + core_out).to(inp.dtype) + + return output + + +class MultiHeadAttn(nn.Module): + def __init__(self, num_heads, d_model, hidden_channels_head, dropout, dropout_attn=0.1, pre_lnorm=False): + super(MultiHeadAttn, self).__init__() + + self.num_heads = num_heads + self.d_model = d_model + self.hidden_channels_head = hidden_channels_head + self.scale = 1 / (hidden_channels_head ** 0.5) + self.pre_lnorm = pre_lnorm + + self.qkv_net = nn.Linear(d_model, 3 * num_heads * hidden_channels_head) + self.drop = nn.Dropout(dropout) + self.dropout_attn = nn.Dropout(dropout_attn) + self.o_net = nn.Linear(num_heads * hidden_channels_head, d_model, bias=False) + self.layer_norm = nn.LayerNorm(d_model) + + def forward(self, inp, attn_mask=None): + return self._forward(inp, attn_mask) + + def _forward(self, inp, attn_mask=None): + residual = inp + + if self.pre_lnorm: + # layer normalization + inp = self.layer_norm(inp) + + num_heads, hidden_channels_head = self.num_heads, self.hidden_channels_head + + head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2) + head_q = head_q.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head) + head_k = head_k.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head) + head_v = head_v.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head) + + q = head_q.permute(0, 2, 1, 3).reshape(-1, inp.size(1), hidden_channels_head) + k = head_k.permute(0, 2, 1, 3).reshape(-1, inp.size(1), hidden_channels_head) + v = head_v.permute(0, 2, 1, 3).reshape(-1, inp.size(1), hidden_channels_head) + + attn_score = torch.bmm(q, k.transpose(1, 2)) + attn_score.mul_(self.scale) + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype) + attn_mask = attn_mask.repeat(num_heads, attn_mask.size(2), 1) + attn_score.masked_fill_(attn_mask.to(torch.bool), -float("inf")) + + attn_prob = F.softmax(attn_score, dim=2) + attn_prob = self.dropout_attn(attn_prob) + attn_vec = torch.bmm(attn_prob, v) + + attn_vec = attn_vec.view(num_heads, inp.size(0), inp.size(1), hidden_channels_head) + attn_vec = ( + attn_vec.permute(1, 2, 0, 3).contiguous().view(inp.size(0), inp.size(1), num_heads * hidden_channels_head) + ) + + # linear projection + attn_out = self.o_net(attn_vec) + attn_out = self.drop(attn_out) + + if self.pre_lnorm: + # residual connection + output = residual + attn_out + else: + # residual connection + layer normalization + output = self.layer_norm(residual + attn_out) + + output = output.to(attn_out.dtype) + + return output + + +class TransformerLayer(nn.Module): + def __init__( + self, num_heads, hidden_channels, hidden_channels_head, hidden_channels_ffn, kernel_size, dropout, **kwargs + ): + super(TransformerLayer, self).__init__() + + self.dec_attn = MultiHeadAttn(num_heads, hidden_channels, hidden_channels_head, dropout, **kwargs) + self.pos_ff = PositionwiseConvFF( + hidden_channels, hidden_channels_ffn, kernel_size, dropout, pre_lnorm=kwargs.get("pre_lnorm") + ) + + def forward(self, dec_inp, mask=None): + output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2)) + output *= mask + output = self.pos_ff(output) + output *= mask + return output + + +class FFTransformer(nn.Module): + def __init__( + self, + num_layers, + num_heads, + hidden_channels, + hidden_channels_head, + hidden_channels_ffn, + kernel_size, + dropout, + dropout_attn, + dropemb=0.0, + pre_lnorm=False, + ): + super(FFTransformer, self).__init__() + self.hidden_channels = hidden_channels + self.num_heads = num_heads + self.hidden_channels_head = hidden_channels_head + + self.pos_emb = PositionalEmbedding(self.hidden_channels) + self.drop = nn.Dropout(dropemb) + self.layers = nn.ModuleList() + + for _ in range(num_layers): + self.layers.append( + TransformerLayer( + num_heads, + hidden_channels, + hidden_channels_head, + hidden_channels_ffn, + kernel_size, + dropout, + dropout_attn=dropout_attn, + pre_lnorm=pre_lnorm, + ) + ) + + def forward(self, x, x_lengths, conditioning=0): + mask = mask_from_lens(x_lengths).unsqueeze(2) + + pos_seq = torch.arange(x.size(1), device=x.device).to(x.dtype) + pos_emb = self.pos_emb(pos_seq) * mask + + if conditioning is None: + conditioning = 0 + + out = self.drop(x + pos_emb + conditioning) + + for layer in self.layers: + out = layer(out, mask=mask) + + # out = self.drop(out) + return out, mask + + +def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None): + """If target=None, then predicted durations are applied""" + dtype = enc_out.dtype + reps = durations.float() / pace + reps = (reps + 0.5).long() + dec_lens = reps.sum(dim=1) + + max_len = dec_lens.max() + reps_cumsum = torch.cumsum(F.pad(reps, (1, 0, 0, 0), value=0.0), dim=1)[:, None, :] + reps_cumsum = reps_cumsum.to(dtype) + + range_ = torch.arange(max_len).to(enc_out.device)[None, :, None] + mult = (reps_cumsum[:, :, :-1] <= range_) & (reps_cumsum[:, :, 1:] > range_) + mult = mult.to(dtype) + en_ex = torch.matmul(mult, enc_out) + + if mel_max_len: + en_ex = en_ex[:, :mel_max_len] + dec_lens = torch.clamp_max(dec_lens, mel_max_len) + return en_ex, dec_lens + + +class TemporalPredictor(nn.Module): + """Predicts a single float per each temporal location""" + + def __init__(self, input_size, filter_size, kernel_size, dropout, num_layers=2): + super(TemporalPredictor, self).__init__() + + self.layers = nn.Sequential( + *[ + ConvReLUNorm( + input_size if i == 0 else filter_size, filter_size, kernel_size=kernel_size, dropout=dropout + ) + for i in range(num_layers) + ] + ) + self.fc = nn.Linear(filter_size, 1, bias=True) + + def forward(self, enc_out, enc_out_mask): + out = enc_out * enc_out_mask + out = self.layers(out.transpose(1, 2)).transpose(1, 2) + out = self.fc(out) * enc_out_mask + return out.squeeze(-1) + @dataclass class FastPitchArgs(Coqpit): - num_chars: int = None + num_chars: int = 100 out_channels: int = 80 - hidden_channels: int = 256 + hidden_channels: int = 384 num_speakers: int = 0 duration_predictor_hidden_channels: int = 256 duration_predictor_dropout: float = 0.1 duration_predictor_kernel_size: int = 3 duration_predictor_dropout_p: float = 0.1 + duration_predictor_num_layers: int = 2 pitch_predictor_hidden_channels: int = 256 pitch_predictor_dropout: float = 0.1 pitch_predictor_kernel_size: int = 3 pitch_predictor_dropout_p: float = 0.1 pitch_embedding_kernel_size: int = 3 + pitch_predictor_num_layers: int = 2 positional_encoding: bool = True length_scale: int = 1 encoder_type: str = "fftransformer" encoder_params: dict = field( - default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + default_factory=lambda: { + "hidden_channels_head": 64, + "hidden_channels_ffn": 1536, + "num_heads": 1, + "num_layers": 6, + "kernel_size": 3, + "dropout": 0.1, + "dropout_attn": 0.1, + } ) decoder_type: str = "fftransformer" decoder_params: dict = field( - default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + default_factory=lambda: { + "hidden_channels_head": 64, + "hidden_channels_ffn": 1536, + "num_heads": 1, + "num_layers": 6, + "kernel_size": 3, + "dropout": 0.1, + "dropout_attn": 0.1, + } ) use_d_vector: bool = False d_vector_dim: int = 0 detach_duration_predictor: bool = False + max_duration: int = 75 + use_gt_duration: bool = True class FastPitch(BaseTTS): @@ -72,71 +406,62 @@ class FastPitch(BaseTTS): >>> model = FastPitch(config) """ - # pylint: disable=dangerous-default-value def __init__(self, config: Coqpit): - super().__init__() - _, self.config, num_chars = self.get_characters(config) - config.model_args.num_chars = num_chars + if "characters" in config: + # loading from FasrPitchConfig + _, self.config, num_chars = self.get_characters(config) + config.model_args.num_chars = num_chars + args = self.config.model_args + else: + # loading from FastPitchArgs + self.config = config + args = config - self.length_scale = ( - float(config.model_args.length_scale) - if isinstance(config.model_args.length_scale, int) - else config.model_args.length_scale + self.max_duration = args.max_duration + self.use_gt_duration = args.use_gt_duration + + self.length_scale = float(args.length_scale) if isinstance(args.length_scale, int) else args.length_scale + + self.encoder = FFTransformer( + hidden_channels=args.hidden_channels, + **args.encoder_params, ) - self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels) + # if n_speakers > 1: + # self.speaker_emb = nn.Embedding(n_speakers, symbols_embedding_dim) + # else: + # self.speaker_emb = None + # self.speaker_emb_weight = speaker_emb_weight + self.emb = nn.Embedding(args.num_chars, args.hidden_channels) - self.encoder = Encoder( - config.model_args.hidden_channels, - config.model_args.hidden_channels, - config.model_args.encoder_type, - config.model_args.encoder_params, - config.model_args.d_vector_dim, + self.duration_predictor = TemporalPredictor( + args.hidden_channels, + filter_size=args.duration_predictor_hidden_channels, + kernel_size=args.duration_predictor_kernel_size, + dropout=args.duration_predictor_dropout_p, + num_layers=args.duration_predictor_num_layers, ) - if config.model_args.positional_encoding: - self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels) + self.decoder = FFTransformer(hidden_channels=args.hidden_channels, **args.decoder_params) - self.decoder = Decoder( - config.model_args.out_channels, - config.model_args.hidden_channels, - config.model_args.decoder_type, - config.model_args.decoder_params, - ) - - self.duration_predictor = DurationPredictor( - config.model_args.hidden_channels + config.model_args.d_vector_dim, - config.model_args.duration_predictor_hidden_channels, - config.model_args.duration_predictor_kernel_size, - config.model_args.duration_predictor_dropout_p, - ) - - self.pitch_predictor = DurationPredictor( - config.model_args.hidden_channels + config.model_args.d_vector_dim, - config.model_args.pitch_predictor_hidden_channels, - config.model_args.pitch_predictor_kernel_size, - config.model_args.pitch_predictor_dropout_p, + self.pitch_predictor = TemporalPredictor( + args.hidden_channels, + filter_size=args.pitch_predictor_hidden_channels, + kernel_size=args.pitch_predictor_kernel_size, + dropout=args.pitch_predictor_dropout_p, + num_layers=args.pitch_predictor_num_layers, ) self.pitch_emb = nn.Conv1d( 1, - config.model_args.hidden_channels, - kernel_size=config.model_args.pitch_embedding_kernel_size, - padding=int((config.model_args.pitch_embedding_kernel_size - 1) / 2), + args.hidden_channels, + kernel_size=args.pitch_embedding_kernel_size, + padding=int((args.pitch_embedding_kernel_size - 1) / 2), ) - self.register_buffer("pitch_mean", torch.zeros(1)) - self.register_buffer("pitch_std", torch.zeros(1)) - - if config.model_args.num_speakers > 1 and not config.model_args.use_d_vector: - # speaker embedding layer - self.emb_g = nn.Embedding(config.model_args.num_speakers, config.model_args.d_vector_dim) - nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) - - if config.model_args.d_vector_dim > 0 and config.model_args.d_vector_dim != config.model_args.hidden_channels: - self.proj_g = nn.Conv1d(config.model_args.d_vector_dim, config.model_args.hidden_channels, 1) + self.proj = nn.Linear(args.hidden_channels, args.out_channels, bias=True) @staticmethod def expand_encoder_outputs(en, dr, x_mask, y_mask): @@ -155,99 +480,52 @@ class FastPitch(BaseTTS): """ attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype) - o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) - return o_en_ex, attn + o_en_ex = torch.matmul(attn.transpose(1, 2), en) + return o_en_ex, attn.transpose(1, 2) - def format_durations(self, o_dr_log, x_mask): - o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale - o_dr[o_dr < 1] = 1.0 - o_dr = torch.round(o_dr) - return o_dr - - @staticmethod - def _concat_speaker_embedding(o_en, g): - g_exp = g.expand(-1, -1, o_en.size(-1)) # [B, C, T_en] - o_en = torch.cat([o_en, g_exp], 1) - return o_en - - def _sum_speaker_embedding(self, x, g): - # project g to decoder dim. - if hasattr(self, "proj_g"): - g = self.proj_g(g) - return x + g - - def _forward_encoder(self, x, x_lengths, g=None): - if hasattr(self, "emb_g"): - g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1] - - if g is not None: - g = g.unsqueeze(-1) - - # [B, T, C] - x_emb = self.emb(x) - # [B, C, T] - x_emb = torch.transpose(x_emb, 1, -1) - - # compute sequence masks + def forward(self, x, x_lengths, y_lengths, dr, pitch, aux_input={"d_vectors": 0, "speaker_ids": None}): + speaker_embedding = aux_input["d_vectors"] if "d_vectors" in aux_input else 0 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype) x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) - # encoder pass - o_en = self.encoder(x_emb, x_mask) + # Calculate speaker embedding + # if self.speaker_emb is None: + # speaker_embedding = 0 + # else: + # speaker_embedding = self.speaker_emb(speaker).unsqueeze(1) + # speaker_embedding.mul_(self.speaker_emb_weight) - # speaker conditioning for duration predictor - if g is not None: - o_en_dp = self._concat_speaker_embedding(o_en, g) - else: - o_en_dp = o_en - return o_en, o_en_dp, x_mask, g + # character embedding + embedding = self.emb(x) - def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) - # expand o_en with durations + # Input FFT + o_en, mask_en = self.encoder(embedding, x_lengths, conditioning=speaker_embedding) + + # Embedded for predictors + o_en_dr, mask_en_dr = o_en, mask_en + + # Predict durations + o_dr_log = self.duration_predictor(o_en_dr.detach(), mask_en_dr) + o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) + + # TODO: move this to the dataset + avg_pitch = average_pitch(pitch, dr) + + # Predict pitch + o_pitch = self.pitch_predictor(o_en, mask_en).unsqueeze(1) + pitch_emb = self.pitch_emb(avg_pitch) + o_en = o_en + pitch_emb.transpose(1, 2) + + # len_regulated, dec_lens = regulate_len(dr, o_en, self.length_scale, mel_max_len) o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) - # positional encoding - if hasattr(self, "pos_encoder"): - o_en_ex = self.pos_encoder(o_en_ex, y_mask) - # speaker embedding - if g is not None: - o_en_ex = self._sum_speaker_embedding(o_en_ex, g) - # decoder pass - o_de = self.decoder(o_en_ex, y_mask, g=g) - return o_de, attn.transpose(1, 2) - def _forward_pitch_predictor(self, o_en, x_mask, pitch=None, dr=None): - o_pitch = self.pitch_predictor(o_en, x_mask) - if pitch is not None: - avg_pitch = average_pitch(pitch, dr) - o_pitch_emb = self.pitch_emb(avg_pitch) - return o_pitch_emb, o_pitch, avg_pitch - o_pitch_emb = self.pitch_emb(o_pitch) - return o_pitch_emb, o_pitch - - def forward( - self, x, x_lengths, y_lengths, dr, pitch, aux_input={"d_vectors": None, "speaker_ids": None} - ): # pylint: disable=unused-argument - """ - Shapes: - x: :math:`[B, T_max]` - x_lengths: :math:`[B]` - y_lengths: :math:`[B]` - dr: :math:`[B, T_max]` - g: :math:`[B, C]` - pitch: :math:`[B, 1, T]` - """ - g = aux_input["d_vectors"] if "d_vectors" in aux_input else None - o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) - if self.config.model_args.detach_duration_predictor: - o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) - else: - o_dr_log = self.duration_predictor(o_en_dp, x_mask) - o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr) - o_en = o_en + o_pitch_emb - o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g) + # Output FFT + o_de, _ = self.decoder(o_en_ex, y_lengths) + o_de = self.proj(o_de) outputs = { - "model_outputs": o_de.transpose(1, 2), + "model_outputs": o_de, "durations_log": o_dr_log.squeeze(1), + "durations": o_dr.squeeze(1), "pitch": o_pitch, "pitch_gt": avg_pitch, "alignments": attn, @@ -255,43 +533,58 @@ class FastPitch(BaseTTS): return outputs @torch.no_grad() - def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument - """ - Shapes: - x: [B, T_max] - x_lengths: [B] - g: [B, C] - """ - g = aux_input["d_vectors"] if "d_vectors" in aux_input else None - x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + def inference(self, x, aux_input={"d_vectors": 0, "speaker_ids": None}): # pylint: disable=unused-argument + speaker_embedding = aux_input["d_vectors"] if "d_vectors" in aux_input else 0 + # input sequence should be greated than the max convolution size inference_padding = 5 if x.shape[1] < 13: inference_padding += 13 - x.shape[1] + # pad input to prevent dropping the last word x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0) - o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) - # duration predictor pass - o_dr_log = self.duration_predictor(o_en_dp, x_mask) - o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) - # pitch predictor pass - o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask) + x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + + # character embedding + embedding = self.emb(x) + + # if self.speaker_emb is None: + # else: + # speaker = torch.ones(inputs.size(0)).long().to(inputs.device) * speaker + # spk_emb = self.speaker_emb(speaker).unsqueeze(1) + # spk_emb.mul_(self.speaker_emb_weight) + + # Input FFT + o_en, mask_en = self.encoder(embedding, x_lengths, conditioning=speaker_embedding) + + # Predict durations + o_dr_log = self.duration_predictor(o_en, mask_en) + o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) + + # Pitch over chars + o_pitch = self.pitch_predictor(o_en, mask_en).unsqueeze(1) + # if pitch_transform is not None: # if self.pitch_std[0] == 0.0: # # XXX LJSpeech-1.1 defaults # mean, std = 218.14, 67.24 # else: # mean, std = self.pitch_mean[0], self.pitch_std[0] - # pitch_pred = pitch_transform(pitch_pred, enc_mask.sum(dim=(1,2)), mean, std) + # pitch_pred = pitch_transform(pitch_pred, mask_en.sum(dim=(1, 2)), mean, std) + + o_pitch_emb = self.pitch_emb(o_pitch).transpose(1, 2) - # if pitch_tgt is None: - # pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2) - # else: - # pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2) o_en = o_en + o_pitch_emb + y_lengths = o_dr.sum(1) - o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) - outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn, "pitch": o_pitch, "durations_log": None} + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype) + + o_en_ex, attn = self.expand_encoder_outputs(o_en, o_dr, x_mask, y_mask) + o_de, _ = self.decoder(o_en_ex, y_lengths) + o_de = self.proj(o_de) + + outputs = {"model_outputs": o_de, "alignments": attn, "pitch": o_pitch, "durations_log": o_dr_log} return outputs def train_step(self, batch: dict, criterion: nn.Module): @@ -319,9 +612,10 @@ class FastPitch(BaseTTS): text_lengths, ) - # compute alignment error (the lower the better ) - align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True) - loss_dict["align_error"] = align_error + # compute duration error + durations_pred = outputs["durations"] + duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum() + loss_dict["duration_error"] = duration_error return outputs, loss_dict def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use From 81c228a2d87a6384489bd20e9ac6e1682aaa75d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 24 Jul 2021 16:36:52 +0000 Subject: [PATCH 28/52] Update FastPitch don't detach duration network inputs --- TTS/tts/configs/fast_pitch_config.py | 6 +++--- TTS/tts/models/fast_pitch.py | 5 +++-- recipes/ljspeech/fast_pitch/train_fast_pitch.py | 4 ++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py index d02e54c9..7e294896 100644 --- a/TTS/tts/configs/fast_pitch_config.py +++ b/TTS/tts/configs/fast_pitch_config.py @@ -64,14 +64,14 @@ class FastPitchConfig(BaseTTSConfig): # optimizer parameters optimizer: str = "Adam" - optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.98], "weight_decay": 1e-6}) + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) lr_scheduler: str = "NoamLR" lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) lr: float = 1e-4 - grad_clip: float = 1000.0 + grad_clip: float = 5.0 # loss params - ssim_loss_alpha: float = 0.0 + ssim_loss_alpha: float = 1.0 dur_loss_alpha: float = 1.0 spec_loss_alpha: float = 1.0 pitch_loss_alpha: float = 1.0 diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py index 989866ae..60a1654a 100644 --- a/TTS/tts/models/fast_pitch.py +++ b/TTS/tts/models/fast_pitch.py @@ -505,7 +505,7 @@ class FastPitch(BaseTTS): o_en_dr, mask_en_dr = o_en, mask_en # Predict durations - o_dr_log = self.duration_predictor(o_en_dr.detach(), mask_en_dr) + o_dr_log = self.duration_predictor(o_en_dr, mask_en_dr) o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) # TODO: move this to the dataset @@ -560,6 +560,7 @@ class FastPitch(BaseTTS): # Predict durations o_dr_log = self.duration_predictor(o_en, mask_en) o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) + o_dr = o_dr * self.length_scale # Pitch over chars o_pitch = self.pitch_predictor(o_en, mask_en).unsqueeze(1) @@ -606,7 +607,7 @@ class FastPitch(BaseTTS): mel_input, mel_lengths, outputs["durations_log"], - torch.log(1 + durations), + durations, outputs["pitch"], outputs["pitch_gt"], text_lengths, diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index 5bc5f448..4b852d12 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -42,8 +42,8 @@ config = FastPitchConfig( use_phonemes=True, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), - print_step=25, - print_eval=True, + print_step=50, + print_eval=False, mixed_precision=False, output_path=output_path, datasets=[dataset_config] From e429afbce4362bba5861eda5e2aa64fbcdb1e7a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 25 Jul 2021 01:24:17 +0200 Subject: [PATCH 29/52] Enable aligner for FastPitch --- TTS/tts/configs/fast_pitch_config.py | 1 + .../glow_tts/monotonic_align/__init__.py | 5 +- TTS/tts/layers/losses.py | 47 +++++++- TTS/tts/models/fast_pitch.py | 110 +++++++++++++++++- .../ljspeech/fast_pitch/train_fast_pitch.py | 15 ++- 5 files changed, 163 insertions(+), 15 deletions(-) diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py index 7e294896..2c54803a 100644 --- a/TTS/tts/configs/fast_pitch_config.py +++ b/TTS/tts/configs/fast_pitch_config.py @@ -76,6 +76,7 @@ class FastPitchConfig(BaseTTSConfig): spec_loss_alpha: float = 1.0 pitch_loss_alpha: float = 1.0 dur_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 # overrides min_seq_len: int = 13 diff --git a/TTS/tts/layers/glow_tts/monotonic_align/__init__.py b/TTS/tts/layers/glow_tts/monotonic_align/__init__.py index 7757ecf8..ee058095 100644 --- a/TTS/tts/layers/glow_tts/monotonic_align/__init__.py +++ b/TTS/tts/layers/glow_tts/monotonic_align/__init__.py @@ -47,8 +47,9 @@ def maximum_path(value, mask): def maximum_path_cython(value, mask): """Cython optimised version. - value: [b, t_x, t_y] - mask: [b, t_x, t_y] + Shapes: + - value: :math:`[B, T_en, T_de]` + - mask: :math:`[B, T_en, T_de]` """ value = value * mask device = value.device diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index fefaec9a..6ca010dd 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -660,6 +660,36 @@ class VitsDiscriminatorLoss(nn.Module): return return_dict +class ForwardSumLoss(nn.Module): + def __init__(self, blank_logprob=-1): + super().__init__() + self.log_softmax = torch.nn.LogSoftmax(dim=3) + self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True) + self.blank_logprob = blank_logprob + + def forward(self, attn_logprob, in_lens, out_lens): + key_lens = in_lens + query_lens = out_lens + attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob) + + total_loss = 0.0 + for bid in range(attn_logprob.shape[0]): + target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0) + curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1] + + curr_logprob = self.log_softmax(curr_logprob[None])[0] + loss = self.ctc_loss( + curr_logprob, + target_seq, + input_lengths=query_lens[bid : bid + 1], + target_lengths=key_lens[bid : bid + 1], + ) + total_loss = total_loss + loss + + total_loss = total_loss / attn_logprob.shape[0] + return total_loss + + class FastPitchLoss(nn.Module): def __init__(self, c): super().__init__() @@ -667,11 +697,14 @@ class FastPitchLoss(nn.Module): self.ssim = SSIMLoss() self.dur_loss = MSELossMasked(False) self.pitch_loss = MSELossMasked(False) + if c.model_args.use_aligner: + self.aligner_loss = ForwardSumLoss() self.spec_loss_alpha = c.spec_loss_alpha self.ssim_loss_alpha = c.ssim_loss_alpha self.dur_loss_alpha = c.dur_loss_alpha self.pitch_loss_alpha = c.pitch_loss_alpha + self.aligner_loss_alpha = c.aligner_loss_alpha def forward( self, @@ -683,29 +716,35 @@ class FastPitchLoss(nn.Module): pitch_output, pitch_target, input_lens, + alignment_logprob=None, ): loss = 0 return_dict = {} if self.ssim_loss_alpha > 0: ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) - loss += self.ssim_loss_alpha * ssim_loss + loss = loss + self.ssim_loss_alpha * ssim_loss return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss if self.spec_loss_alpha > 0: spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) - loss += self.spec_loss_alpha * spec_loss + loss = loss + self.spec_loss_alpha * spec_loss return_dict["loss_spec"] = self.spec_loss_alpha * spec_loss if self.dur_loss_alpha > 0: log_dur_tgt = torch.log(dur_target.float() + 1) dur_loss = self.dur_loss(dur_output[:, :, None], log_dur_tgt[:, :, None], input_lens) - loss += self.dur_loss_alpha * dur_loss + loss = loss + self.dur_loss_alpha * dur_loss return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss if self.pitch_loss_alpha > 0: pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens) - loss += self.pitch_loss_alpha * pitch_loss + loss = loss + self.pitch_loss_alpha * pitch_loss return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss + if self.aligner_loss_alpha > 0: + aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens) + loss += self.aligner_loss_alpha * aligner_loss + return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss + return_dict["loss"] = loss return return_dict diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py index 60a1654a..6f9cee36 100644 --- a/TTS/tts/models/fast_pitch.py +++ b/TTS/tts/models/fast_pitch.py @@ -4,8 +4,9 @@ import torch import torch.nn as nn import torch.nn.functional as F from coqpit import Coqpit +from matplotlib.pyplot import plot -from TTS.tts.layers.glow_tts.monotonic_align import generate_path +from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -14,6 +15,75 @@ from TTS.utils.audio import AudioProcessor # pylint: disable=dangerous-default-value +class AlignmentEncoder(torch.nn.Module): + """Module for alignment text and mel spectrogram.""" + + def __init__( + self, + in_query_channels=80, + in_key_channels=512, + attn_channels=80, + temperature=0.0005, + ): + super().__init__() + self.temperature = temperature + self.softmax = torch.nn.Softmax(dim=3) + self.log_softmax = torch.nn.LogSoftmax(dim=3) + + self.key_proj = nn.Sequential( + ConvNorm( + in_key_channels, in_key_channels * 2, kernel_size=3, bias=True, w_init_gain="relu", batch_norm=False + ), + torch.nn.ReLU(), + ConvNorm(in_key_channels * 2, attn_channels, kernel_size=1, bias=True, batch_norm=False), + ) + + self.query_proj = nn.Sequential( + ConvNorm( + in_query_channels, in_query_channels * 2, kernel_size=3, bias=True, w_init_gain="relu", batch_norm=False + ), + torch.nn.ReLU(), + ConvNorm(in_query_channels * 2, in_query_channels, kernel_size=1, bias=True, batch_norm=False), + torch.nn.ReLU(), + ConvNorm(in_query_channels, attn_channels, kernel_size=1, bias=True, batch_norm=False), + ) + + def forward( + self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None + ): + """Forward pass of the aligner encoder. + Args: + queries (torch.tensor): query tensor. + keys (torch.tensor): key tensor. + mask (torch.tensor): uint8 binary mask for variable length entries (should be in the T2 domain). + attn_prior (torch.tensor): prior for attention matrix. + Shapes: + - queries: :math:`(B, C, T_de)` + - keys: :math:`(B, C_emb, T_en)` + - mask: :math:`(B, T_de)` + Output: + attn (torch.tensor): B x 1 x T1 x T2 attention mask. Final dim T2 should sum to 1. + attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask. + """ + keys_enc = self.key_proj(keys) # B x n_attn_dims x T2 + queries_enc = self.query_proj(queries) + + # Simplistic Gaussian Isotopic Attention + attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2 # B x n_attn_dims x T1 x T2 + attn = -self.temperature * attn.sum(1, keepdim=True) + + if attn_prior is not None: + attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8) + + attn_logprob = attn.clone() + + if mask is not None: + attn.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf")) + + attn = self.softmax(attn) # softmax along T2 + return attn, attn_logprob + + def mask_from_lens(lens, max_len: int = None): if max_len is None: max_len = lens.max() @@ -379,6 +449,7 @@ class FastPitchArgs(Coqpit): detach_duration_predictor: bool = False max_duration: int = 75 use_gt_duration: bool = True + use_aligner: bool = True class FastPitch(BaseTTS): @@ -421,6 +492,7 @@ class FastPitch(BaseTTS): self.max_duration = args.max_duration self.use_gt_duration = args.use_gt_duration + self.use_aligner = args.use_aligner self.length_scale = float(args.length_scale) if isinstance(args.length_scale, int) else args.length_scale @@ -463,6 +535,9 @@ class FastPitch(BaseTTS): self.proj = nn.Linear(args.hidden_channels, args.out_channels, bias=True) + if args.use_aligner: + self.aligner = AlignmentEncoder(args.out_channels, args.hidden_channels) + @staticmethod def expand_encoder_outputs(en, dr, x_mask, y_mask): """Generate attention alignment map from durations and @@ -483,10 +558,16 @@ class FastPitch(BaseTTS): o_en_ex = torch.matmul(attn.transpose(1, 2), en) return o_en_ex, attn.transpose(1, 2) - def forward(self, x, x_lengths, y_lengths, dr, pitch, aux_input={"d_vectors": 0, "speaker_ids": None}): + def forward( + self, x, x_lengths, y_lengths, y=None, dr=None, pitch=None, aux_input={"d_vectors": 0, "speaker_ids": None} + ): speaker_embedding = aux_input["d_vectors"] if "d_vectors" in aux_input else 0 y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype) x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + o_alignment_dur = None + alignment_logprob = None + alignment_mas = None # Calculate speaker embedding # if self.speaker_emb is None: @@ -508,6 +589,16 @@ class FastPitch(BaseTTS): o_dr_log = self.duration_predictor(o_en_dr, mask_en_dr) o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) + # Aligner + if self.use_aligner: + alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), embedding.transpose(1, 2), x_mask, None) + alignment_mas = maximum_path( + alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() + ) + o_alignment_dur = torch.log(1 + torch.sum(alignment_mas, -1)) + avg_pitch = average_pitch(pitch, o_alignment_dur) + dr = o_alignment_dur + # TODO: move this to the dataset avg_pitch = average_pitch(pitch, dr) @@ -529,6 +620,9 @@ class FastPitch(BaseTTS): "pitch": o_pitch, "pitch_gt": avg_pitch, "alignments": attn, + "alignment_mas": alignment_mas, + "o_alignment_dur": o_alignment_dur, + "alignment_logprob": alignment_logprob, } return outputs @@ -599,7 +693,12 @@ class FastPitch(BaseTTS): durations = batch["durations"] aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} - outputs = self.forward(text_input, text_lengths, mel_lengths, durations, pitch, aux_input) + outputs = self.forward( + text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input + ) + + if self.use_aligner: + durations = outputs["o_alignment_dur"] # compute loss loss_dict = criterion( @@ -611,6 +710,7 @@ class FastPitch(BaseTTS): outputs["pitch"], outputs["pitch_gt"], text_lengths, + outputs["alignment_logprob"], ) # compute duration error @@ -634,6 +734,10 @@ class FastPitch(BaseTTS): "alignment": plot_alignment(align_img, output_fig=False), } + if self.config.model_args.use_aligner and self.training: + alignment_mas = outputs["alignment_mas"] + figures["alignment_mas"] = plot_alignment(alignment_mas, ap, output_fig=False) + # Sample audio train_audio = ap.inv_melspectrogram(pred_spec.T) return figures, {"audio": train_audio} diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index 4b852d12..63f50dd9 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -11,7 +11,7 @@ output_path = os.path.dirname(os.path.abspath(__file__)) dataset_config = BaseDatasetConfig( name="ljspeech", meta_file_train="metadata.csv", - meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), + # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), path=os.path.join(output_path, "../LJSpeech-1.1/"), ) audio_config = BaseAudioConfig( @@ -46,14 +46,17 @@ config = FastPitchConfig( print_eval=False, mixed_precision=False, output_path=output_path, - datasets=[dataset_config] + datasets=[dataset_config], ) # compute alignments -manager = ModelManager() -model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA") -# TODO: make compute_attention python callable -os.system(f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true") +if not config.model_args.use_aligner: + manager = ModelManager() + model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA") + # TODO: make compute_attention python callable + os.system( + f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true" + ) # train the model args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) From 98a7271ce8e17ebc9533b0c9e7af8e05b70b4fb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 26 Jul 2021 22:50:34 +0200 Subject: [PATCH 30/52] Refactor FastPitchv2 --- TTS/tts/models/fast_pitch.py | 664 ++++++++++------------------------- TTS/utils/audio.py | 42 +-- 2 files changed, 210 insertions(+), 496 deletions(-) diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py index 6f9cee36..c218535e 100644 --- a/TTS/tts/models/fast_pitch.py +++ b/TTS/tts/models/fast_pitch.py @@ -1,23 +1,23 @@ from dataclasses import dataclass, field import torch -import torch.nn as nn import torch.nn.functional as F from coqpit import Coqpit -from matplotlib.pyplot import plot +from torch import nn +from TTS.tts.layers.feed_forward.decoder import Decoder +from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask +from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor -# pylint: disable=dangerous-default-value - class AlignmentEncoder(torch.nn.Module): - """Module for alignment text and mel spectrogram.""" - def __init__( self, in_query_channels=80, @@ -31,32 +31,35 @@ class AlignmentEncoder(torch.nn.Module): self.log_softmax = torch.nn.LogSoftmax(dim=3) self.key_proj = nn.Sequential( - ConvNorm( - in_key_channels, in_key_channels * 2, kernel_size=3, bias=True, w_init_gain="relu", batch_norm=False + nn.Conv1d( + in_key_channels, + in_key_channels * 2, + kernel_size=3, + padding=1, + bias=True, ), torch.nn.ReLU(), - ConvNorm(in_key_channels * 2, attn_channels, kernel_size=1, bias=True, batch_norm=False), + nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True), ) self.query_proj = nn.Sequential( - ConvNorm( - in_query_channels, in_query_channels * 2, kernel_size=3, bias=True, w_init_gain="relu", batch_norm=False + nn.Conv1d( + in_query_channels, + in_query_channels * 2, + kernel_size=3, + padding=1, + bias=True, ), torch.nn.ReLU(), - ConvNorm(in_query_channels * 2, in_query_channels, kernel_size=1, bias=True, batch_norm=False), + nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True), torch.nn.ReLU(), - ConvNorm(in_query_channels, attn_channels, kernel_size=1, bias=True, batch_norm=False), + nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True), ) def forward( self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None ): """Forward pass of the aligner encoder. - Args: - queries (torch.tensor): query tensor. - keys (torch.tensor): key tensor. - mask (torch.tensor): uint8 binary mask for variable length entries (should be in the T2 domain). - attn_prior (torch.tensor): prior for attention matrix. Shapes: - queries: :math:`(B, C, T_de)` - keys: :math:`(B, C_emb, T_en)` @@ -84,365 +87,30 @@ class AlignmentEncoder(torch.nn.Module): return attn, attn_logprob -def mask_from_lens(lens, max_len: int = None): - if max_len is None: - max_len = lens.max() - ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype) - mask = torch.lt(ids, lens.unsqueeze(1)) - return mask - - -class LinearNorm(torch.nn.Module): - def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"): - super(LinearNorm, self).__init__() - self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) - - torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) - - def forward(self, x): - return self.linear_layer(x) - - -class ConvNorm(torch.nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=None, - dilation=1, - bias=True, - w_init_gain="linear", - batch_norm=False, - ): - super(ConvNorm, self).__init__() - if padding is None: - assert kernel_size % 2 == 1 - padding = int(dilation * (kernel_size - 1) / 2) - - self.conv = torch.nn.Conv1d( - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - bias=bias, - ) - self.norm = torch.nn.BatchNorm1D(out_channels) if batch_norm else None - - torch.nn.init.xavier_uniform_(self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) - - def forward(self, signal): - if self.norm is None: - return self.conv(signal) - else: - return self.norm(self.conv(signal)) - - -class ConvReLUNorm(torch.nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=1, dropout=0.0): - super(ConvReLUNorm, self).__init__() - self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=(kernel_size // 2)) - self.norm = torch.nn.LayerNorm(out_channels) - self.dropout = torch.nn.Dropout(dropout) - - def forward(self, signal): - out = F.relu(self.conv(signal)) - out = self.norm(out.transpose(1, 2)).transpose(1, 2).to(signal.dtype) - return self.dropout(out) - - -class PositionalEmbedding(nn.Module): - def __init__(self, demb): - super(PositionalEmbedding, self).__init__() - self.demb = demb - inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) - self.register_buffer("inv_freq", inv_freq) - - def forward(self, pos_seq, bsz=None): - sinusoid_inp = torch.matmul(torch.unsqueeze(pos_seq, -1), torch.unsqueeze(self.inv_freq, 0)) - pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1) - if bsz is not None: - return pos_emb[None, :, :].expand(bsz, -1, -1) - else: - return pos_emb[None, :, :] - - -class PositionwiseConvFF(nn.Module): - def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False): - super(PositionwiseConvFF, self).__init__() - - self.d_model = d_model - self.d_inner = d_inner - self.dropout = dropout - - self.CoreNet = nn.Sequential( - nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)), - nn.ReLU(), - # nn.Dropout(dropout), # worse convergence - nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)), - nn.Dropout(dropout), - ) - self.layer_norm = nn.LayerNorm(d_model) - self.pre_lnorm = pre_lnorm - - def forward(self, inp): - return self._forward(inp) - - def _forward(self, inp): - if self.pre_lnorm: - # layer normalization + positionwise feed-forward - core_out = inp.transpose(1, 2) - core_out = self.CoreNet(self.layer_norm(core_out).to(inp.dtype)) - core_out = core_out.transpose(1, 2) - - # residual connection - output = core_out + inp - else: - # positionwise feed-forward - core_out = inp.transpose(1, 2) - core_out = self.CoreNet(core_out) - core_out = core_out.transpose(1, 2) - - # residual connection + layer normalization - output = self.layer_norm(inp + core_out).to(inp.dtype) - - return output - - -class MultiHeadAttn(nn.Module): - def __init__(self, num_heads, d_model, hidden_channels_head, dropout, dropout_attn=0.1, pre_lnorm=False): - super(MultiHeadAttn, self).__init__() - - self.num_heads = num_heads - self.d_model = d_model - self.hidden_channels_head = hidden_channels_head - self.scale = 1 / (hidden_channels_head ** 0.5) - self.pre_lnorm = pre_lnorm - - self.qkv_net = nn.Linear(d_model, 3 * num_heads * hidden_channels_head) - self.drop = nn.Dropout(dropout) - self.dropout_attn = nn.Dropout(dropout_attn) - self.o_net = nn.Linear(num_heads * hidden_channels_head, d_model, bias=False) - self.layer_norm = nn.LayerNorm(d_model) - - def forward(self, inp, attn_mask=None): - return self._forward(inp, attn_mask) - - def _forward(self, inp, attn_mask=None): - residual = inp - - if self.pre_lnorm: - # layer normalization - inp = self.layer_norm(inp) - - num_heads, hidden_channels_head = self.num_heads, self.hidden_channels_head - - head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2) - head_q = head_q.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head) - head_k = head_k.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head) - head_v = head_v.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head) - - q = head_q.permute(0, 2, 1, 3).reshape(-1, inp.size(1), hidden_channels_head) - k = head_k.permute(0, 2, 1, 3).reshape(-1, inp.size(1), hidden_channels_head) - v = head_v.permute(0, 2, 1, 3).reshape(-1, inp.size(1), hidden_channels_head) - - attn_score = torch.bmm(q, k.transpose(1, 2)) - attn_score.mul_(self.scale) - - if attn_mask is not None: - attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype) - attn_mask = attn_mask.repeat(num_heads, attn_mask.size(2), 1) - attn_score.masked_fill_(attn_mask.to(torch.bool), -float("inf")) - - attn_prob = F.softmax(attn_score, dim=2) - attn_prob = self.dropout_attn(attn_prob) - attn_vec = torch.bmm(attn_prob, v) - - attn_vec = attn_vec.view(num_heads, inp.size(0), inp.size(1), hidden_channels_head) - attn_vec = ( - attn_vec.permute(1, 2, 0, 3).contiguous().view(inp.size(0), inp.size(1), num_heads * hidden_channels_head) - ) - - # linear projection - attn_out = self.o_net(attn_vec) - attn_out = self.drop(attn_out) - - if self.pre_lnorm: - # residual connection - output = residual + attn_out - else: - # residual connection + layer normalization - output = self.layer_norm(residual + attn_out) - - output = output.to(attn_out.dtype) - - return output - - -class TransformerLayer(nn.Module): - def __init__( - self, num_heads, hidden_channels, hidden_channels_head, hidden_channels_ffn, kernel_size, dropout, **kwargs - ): - super(TransformerLayer, self).__init__() - - self.dec_attn = MultiHeadAttn(num_heads, hidden_channels, hidden_channels_head, dropout, **kwargs) - self.pos_ff = PositionwiseConvFF( - hidden_channels, hidden_channels_ffn, kernel_size, dropout, pre_lnorm=kwargs.get("pre_lnorm") - ) - - def forward(self, dec_inp, mask=None): - output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2)) - output *= mask - output = self.pos_ff(output) - output *= mask - return output - - -class FFTransformer(nn.Module): - def __init__( - self, - num_layers, - num_heads, - hidden_channels, - hidden_channels_head, - hidden_channels_ffn, - kernel_size, - dropout, - dropout_attn, - dropemb=0.0, - pre_lnorm=False, - ): - super(FFTransformer, self).__init__() - self.hidden_channels = hidden_channels - self.num_heads = num_heads - self.hidden_channels_head = hidden_channels_head - - self.pos_emb = PositionalEmbedding(self.hidden_channels) - self.drop = nn.Dropout(dropemb) - self.layers = nn.ModuleList() - - for _ in range(num_layers): - self.layers.append( - TransformerLayer( - num_heads, - hidden_channels, - hidden_channels_head, - hidden_channels_ffn, - kernel_size, - dropout, - dropout_attn=dropout_attn, - pre_lnorm=pre_lnorm, - ) - ) - - def forward(self, x, x_lengths, conditioning=0): - mask = mask_from_lens(x_lengths).unsqueeze(2) - - pos_seq = torch.arange(x.size(1), device=x.device).to(x.dtype) - pos_emb = self.pos_emb(pos_seq) * mask - - if conditioning is None: - conditioning = 0 - - out = self.drop(x + pos_emb + conditioning) - - for layer in self.layers: - out = layer(out, mask=mask) - - # out = self.drop(out) - return out, mask - - -def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None): - """If target=None, then predicted durations are applied""" - dtype = enc_out.dtype - reps = durations.float() / pace - reps = (reps + 0.5).long() - dec_lens = reps.sum(dim=1) - - max_len = dec_lens.max() - reps_cumsum = torch.cumsum(F.pad(reps, (1, 0, 0, 0), value=0.0), dim=1)[:, None, :] - reps_cumsum = reps_cumsum.to(dtype) - - range_ = torch.arange(max_len).to(enc_out.device)[None, :, None] - mult = (reps_cumsum[:, :, :-1] <= range_) & (reps_cumsum[:, :, 1:] > range_) - mult = mult.to(dtype) - en_ex = torch.matmul(mult, enc_out) - - if mel_max_len: - en_ex = en_ex[:, :mel_max_len] - dec_lens = torch.clamp_max(dec_lens, mel_max_len) - return en_ex, dec_lens - - -class TemporalPredictor(nn.Module): - """Predicts a single float per each temporal location""" - - def __init__(self, input_size, filter_size, kernel_size, dropout, num_layers=2): - super(TemporalPredictor, self).__init__() - - self.layers = nn.Sequential( - *[ - ConvReLUNorm( - input_size if i == 0 else filter_size, filter_size, kernel_size=kernel_size, dropout=dropout - ) - for i in range(num_layers) - ] - ) - self.fc = nn.Linear(filter_size, 1, bias=True) - - def forward(self, enc_out, enc_out_mask): - out = enc_out * enc_out_mask - out = self.layers(out.transpose(1, 2)).transpose(1, 2) - out = self.fc(out) * enc_out_mask - return out.squeeze(-1) - - @dataclass class FastPitchArgs(Coqpit): - num_chars: int = 100 + num_chars: int = None out_channels: int = 80 - hidden_channels: int = 384 + hidden_channels: int = 256 num_speakers: int = 0 duration_predictor_hidden_channels: int = 256 duration_predictor_dropout: float = 0.1 duration_predictor_kernel_size: int = 3 duration_predictor_dropout_p: float = 0.1 - duration_predictor_num_layers: int = 2 pitch_predictor_hidden_channels: int = 256 pitch_predictor_dropout: float = 0.1 pitch_predictor_kernel_size: int = 3 pitch_predictor_dropout_p: float = 0.1 pitch_embedding_kernel_size: int = 3 - pitch_predictor_num_layers: int = 2 positional_encoding: bool = True length_scale: int = 1 encoder_type: str = "fftransformer" encoder_params: dict = field( - default_factory=lambda: { - "hidden_channels_head": 64, - "hidden_channels_ffn": 1536, - "num_heads": 1, - "num_layers": 6, - "kernel_size": 3, - "dropout": 0.1, - "dropout_attn": 0.1, - } + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} ) decoder_type: str = "fftransformer" decoder_params: dict = field( - default_factory=lambda: { - "hidden_channels_head": 64, - "hidden_channels_ffn": 1536, - "num_heads": 1, - "num_layers": 6, - "kernel_size": 3, - "dropout": 0.1, - "dropout_attn": 0.1, - } + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} ) use_d_vector: bool = False d_vector_dim: int = 0 @@ -477,7 +145,9 @@ class FastPitch(BaseTTS): >>> model = FastPitch(config) """ + # pylint: disable=dangerous-default-value def __init__(self, config: Coqpit): + super().__init__() if "characters" in config: @@ -496,44 +166,54 @@ class FastPitch(BaseTTS): self.length_scale = float(args.length_scale) if isinstance(args.length_scale, int) else args.length_scale - self.encoder = FFTransformer( - hidden_channels=args.hidden_channels, - **args.encoder_params, + self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels) + + self.encoder = Encoder( + config.model_args.hidden_channels, + config.model_args.hidden_channels, + config.model_args.encoder_type, + config.model_args.encoder_params, + config.model_args.d_vector_dim, ) - # if n_speakers > 1: - # self.speaker_emb = nn.Embedding(n_speakers, symbols_embedding_dim) - # else: - # self.speaker_emb = None - # self.speaker_emb_weight = speaker_emb_weight - self.emb = nn.Embedding(args.num_chars, args.hidden_channels) + if config.model_args.positional_encoding: + self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels) - self.duration_predictor = TemporalPredictor( - args.hidden_channels, - filter_size=args.duration_predictor_hidden_channels, - kernel_size=args.duration_predictor_kernel_size, - dropout=args.duration_predictor_dropout_p, - num_layers=args.duration_predictor_num_layers, + self.decoder = Decoder( + config.model_args.out_channels, + config.model_args.hidden_channels, + config.model_args.decoder_type, + config.model_args.decoder_params, ) - self.decoder = FFTransformer(hidden_channels=args.hidden_channels, **args.decoder_params) + self.duration_predictor = DurationPredictor( + config.model_args.hidden_channels + config.model_args.d_vector_dim, + config.model_args.duration_predictor_hidden_channels, + config.model_args.duration_predictor_kernel_size, + config.model_args.duration_predictor_dropout_p, + ) - self.pitch_predictor = TemporalPredictor( - args.hidden_channels, - filter_size=args.pitch_predictor_hidden_channels, - kernel_size=args.pitch_predictor_kernel_size, - dropout=args.pitch_predictor_dropout_p, - num_layers=args.pitch_predictor_num_layers, + self.pitch_predictor = DurationPredictor( + config.model_args.hidden_channels + config.model_args.d_vector_dim, + config.model_args.pitch_predictor_hidden_channels, + config.model_args.pitch_predictor_kernel_size, + config.model_args.pitch_predictor_dropout_p, ) self.pitch_emb = nn.Conv1d( 1, - args.hidden_channels, - kernel_size=args.pitch_embedding_kernel_size, - padding=int((args.pitch_embedding_kernel_size - 1) / 2), + config.model_args.hidden_channels, + kernel_size=config.model_args.pitch_embedding_kernel_size, + padding=int((config.model_args.pitch_embedding_kernel_size - 1) / 2), ) - self.proj = nn.Linear(args.hidden_channels, args.out_channels, bias=True) + if config.model_args.num_speakers > 1 and not config.model_args.use_d_vector: + # speaker embedding layer + self.emb_g = nn.Embedding(config.model_args.num_speakers, config.model_args.d_vector_dim) + nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + + if config.model_args.d_vector_dim > 0 and config.model_args.d_vector_dim != config.model_args.hidden_channels: + self.proj_g = nn.Conv1d(config.model_args.d_vector_dim, config.model_args.hidden_channels, 1) if args.use_aligner: self.aligner = AlignmentEncoder(args.out_channels, args.hidden_channels) @@ -555,64 +235,109 @@ class FastPitch(BaseTTS): """ attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype) - o_en_ex = torch.matmul(attn.transpose(1, 2), en) - return o_en_ex, attn.transpose(1, 2) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) + return o_en_ex, attn + + def format_durations(self, o_dr_log, x_mask): + o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale + o_dr[o_dr < 1] = 1.0 + o_dr = torch.round(o_dr) + return o_dr + + @staticmethod + def _concat_speaker_embedding(o_en, g): + g_exp = g.expand(-1, -1, o_en.size(-1)) # [B, C, T_en] + o_en = torch.cat([o_en, g_exp], 1) + return o_en + + def _sum_speaker_embedding(self, x, g): + # project g to decoder dim. + if hasattr(self, "proj_g"): + g = self.proj_g(g) + return x + g + + def _forward_encoder(self, x, x_lengths, g=None): + if hasattr(self, "emb_g"): + g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1] + + if g is not None: + g = g.unsqueeze(-1) + + # [B, T, C] + x_emb = self.emb(x) + + # compute sequence masks + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) + + # encoder pass + o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask) + + # speaker conditioning for duration predictor + if g is not None: + o_en_dp = self._concat_speaker_embedding(o_en, g) + else: + o_en_dp = o_en + return o_en, o_en_dp, x_mask, g, x_emb + + def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) + # expand o_en with durations + o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + # positional encoding + if hasattr(self, "pos_encoder"): + o_en_ex = self.pos_encoder(o_en_ex, y_mask) + # speaker embedding + if g is not None: + o_en_ex = self._sum_speaker_embedding(o_en_ex, g) + # decoder pass + o_de = self.decoder(o_en_ex, y_mask, g=g) + return o_de.transpose(1, 2), attn.transpose(1, 2) + + def _forward_pitch_predictor(self, o_en, x_mask, pitch=None, dr=None): + o_pitch = self.pitch_predictor(o_en, x_mask) + if pitch is not None: + avg_pitch = average_pitch(pitch, dr) + o_pitch_emb = self.pitch_emb(avg_pitch) + return o_pitch_emb, o_pitch, avg_pitch + o_pitch_emb = self.pitch_emb(o_pitch) + return o_pitch_emb, o_pitch + + def _forward_aligner(self, y, embedding, x_mask, y_mask): + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), embedding.transpose(1, 2), x_mask, None) + alignment_mas = maximum_path( + alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() + ) + o_alignment_dur = torch.sum(alignment_mas, -1) + return o_alignment_dur, alignment_logprob, alignment_mas def forward( self, x, x_lengths, y_lengths, y=None, dr=None, pitch=None, aux_input={"d_vectors": 0, "speaker_ids": None} - ): - speaker_embedding = aux_input["d_vectors"] if "d_vectors" in aux_input else 0 + ): # pylint: disable=unused-argument + """ + Shapes: + x: :math:`[B, T_max]` + x_lengths: :math:`[B]` + y_lengths: :math:`[B]` + y: :math:`[B, T_max2]` + dr: :math:`[B, T_max]` + g: :math:`[B, C]` + pitch: :math:`[B, 1, T]` + """ + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype) - x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) - attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - o_alignment_dur = None - alignment_logprob = None - alignment_mas = None - - # Calculate speaker embedding - # if self.speaker_emb is None: - # speaker_embedding = 0 - # else: - # speaker_embedding = self.speaker_emb(speaker).unsqueeze(1) - # speaker_embedding.mul_(self.speaker_emb_weight) - - # character embedding - embedding = self.emb(x) - - # Input FFT - o_en, mask_en = self.encoder(embedding, x_lengths, conditioning=speaker_embedding) - - # Embedded for predictors - o_en_dr, mask_en_dr = o_en, mask_en - - # Predict durations - o_dr_log = self.duration_predictor(o_en_dr, mask_en_dr) + o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_lengths, g) + if self.config.model_args.detach_duration_predictor: + o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + else: + o_dr_log = self.duration_predictor(o_en_dp, x_mask) o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) - - # Aligner if self.use_aligner: - alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), embedding.transpose(1, 2), x_mask, None) - alignment_mas = maximum_path( - alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() - ) - o_alignment_dur = torch.log(1 + torch.sum(alignment_mas, -1)) - avg_pitch = average_pitch(pitch, o_alignment_dur) + o_alignment_dur, alignment_logprob, alignment_mas = self._forward_aligner(y, x_emb, x_mask, y_mask) dr = o_alignment_dur - - # TODO: move this to the dataset - avg_pitch = average_pitch(pitch, dr) - - # Predict pitch - o_pitch = self.pitch_predictor(o_en, mask_en).unsqueeze(1) - pitch_emb = self.pitch_emb(avg_pitch) - o_en = o_en + pitch_emb.transpose(1, 2) - - # len_regulated, dec_lens = regulate_len(dr, o_en, self.length_scale, mel_max_len) - o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) - - # Output FFT - o_de, _ = self.decoder(o_en_ex, y_lengths) - o_de = self.proj(o_de) + o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr) + o_en = o_en + o_pitch_emb + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g) outputs = { "model_outputs": o_de, "durations_log": o_dr_log.squeeze(1), @@ -620,66 +345,55 @@ class FastPitch(BaseTTS): "pitch": o_pitch, "pitch_gt": avg_pitch, "alignments": attn, - "alignment_mas": alignment_mas, + "alignment_mas": alignment_mas.transpose(1, 2), "o_alignment_dur": o_alignment_dur, "alignment_logprob": alignment_logprob, } return outputs @torch.no_grad() - def inference(self, x, aux_input={"d_vectors": 0, "speaker_ids": None}): # pylint: disable=unused-argument - speaker_embedding = aux_input["d_vectors"] if "d_vectors" in aux_input else 0 - + def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument + """ + Shapes: + x: [B, T_max] + x_lengths: [B] + g: [B, C] + """ + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + x_lengths = torch.tensor(x.shape[1:2]).to(x.device) # input sequence should be greated than the max convolution size inference_padding = 5 if x.shape[1] < 13: inference_padding += 13 - x.shape[1] - # pad input to prevent dropping the last word x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0) - x_lengths = torch.tensor(x.shape[1:2]).to(x.device) - - # character embedding - embedding = self.emb(x) - - # if self.speaker_emb is None: - # else: - # speaker = torch.ones(inputs.size(0)).long().to(inputs.device) * speaker - # spk_emb = self.speaker_emb(speaker).unsqueeze(1) - # spk_emb.mul_(self.speaker_emb_weight) - - # Input FFT - o_en, mask_en = self.encoder(embedding, x_lengths, conditioning=speaker_embedding) - - # Predict durations - o_dr_log = self.duration_predictor(o_en, mask_en) - o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) - o_dr = o_dr * self.length_scale - - # Pitch over chars - o_pitch = self.pitch_predictor(o_en, mask_en).unsqueeze(1) - + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + # duration predictor pass + o_dr_log = self.duration_predictor(o_en_dp, x_mask) + o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + # pitch predictor pass + o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask) # if pitch_transform is not None: # if self.pitch_std[0] == 0.0: # # XXX LJSpeech-1.1 defaults # mean, std = 218.14, 67.24 # else: # mean, std = self.pitch_mean[0], self.pitch_std[0] - # pitch_pred = pitch_transform(pitch_pred, mask_en.sum(dim=(1, 2)), mean, std) - - o_pitch_emb = self.pitch_emb(o_pitch).transpose(1, 2) + # pitch_pred = pitch_transform(pitch_pred, enc_mask.sum(dim=(1,2)), mean, std) + # if pitch_tgt is None: + # pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2) + # else: + # pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2) o_en = o_en + o_pitch_emb - y_lengths = o_dr.sum(1) - x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype) - - o_en_ex, attn = self.expand_encoder_outputs(o_en, o_dr, x_mask, y_mask) - o_de, _ = self.decoder(o_en_ex, y_lengths) - o_de = self.proj(o_de) - - outputs = {"model_outputs": o_de, "alignments": attn, "pitch": o_pitch, "durations_log": o_dr_log} + o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) + outputs = { + "model_outputs": o_de.transpose(1, 2), + "alignments": attn, + "pitch": o_pitch, + "durations_log": o_dr_log, + } return outputs def train_step(self, batch: dict, criterion: nn.Module): @@ -735,8 +449,8 @@ class FastPitch(BaseTTS): } if self.config.model_args.use_aligner and self.training: - alignment_mas = outputs["alignment_mas"] - figures["alignment_mas"] = plot_alignment(alignment_mas, ap, output_fig=False) + alignment_mas = outputs["alignment_mas"][0].data.cpu().numpy() + figures["alignment_mas"] = plot_alignment(alignment_mas, output_fig=False) # Sample audio train_audio = ap.inv_melspectrogram(pred_spec.T) diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index 96b9a1a1..6a74b3c8 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -647,29 +647,29 @@ class AudioProcessor(object): # frame_period=1000 * self.hop_length / self.sample_rate, # ) # f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate) - f0, _, _, _ = compute_yin( - x, - self.sample_rate, - self.win_length, - self.hop_length, - 65 if self.mel_fmin == 0 else self.mel_fmin, - self.mel_fmax, - ) - # import pyworld as pw - # f0, _ = pw.dio(x.astype(np.float64), self.sample_rate, - # frame_period=self.hop_length / self.sample_rate * 1000) - pad = int((self.win_length / self.hop_length) / 2) - f0 = [0.0] * pad + f0 + [0.0] * pad - f0 = np.array(f0, dtype=np.float32) - - # f01, _, _ = librosa.pyin( + # f0, _, _, _ = compute_yin( # x, - # fmin=65 if self.mel_fmin == 0 else self.mel_fmin, - # fmax=self.mel_fmax, - # frame_length=self.win_length, - # sr=self.sample_rate, - # fill_na=0.0, + # self.sample_rate, + # self.win_length, + # self.hop_length, + # 65 if self.mel_fmin == 0 else self.mel_fmin, + # self.mel_fmax, # ) + # # import pyworld as pw + # # f0, _ = pw.dio(x.astype(np.float64), self.sample_rate, + # # frame_period=self.hop_length / self.sample_rate * 1000) + # pad = int((self.win_length / self.hop_length) / 2) + # f0 = [0.0] * pad + f0 + [0.0] * pad + # f0 = np.array(f0, dtype=np.float32) + + f0, _, _ = librosa.pyin( + x, + fmin=65 if self.mel_fmin == 0 else self.mel_fmin, + fmax=self.mel_fmax, + frame_length=self.win_length, + sr=self.sample_rate, + fill_na=0.0, + ) # f02 = librosa.yin( # x, From 59d52a4cd8659fea158aeec925bc0ffa6f505694 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 1 Sep 2021 08:31:33 +0000 Subject: [PATCH 31/52] Disable autcast for criterions --- TTS/tts/models/fast_pitch.py | 76 +++++++++++++++++------------------- 1 file changed, 35 insertions(+), 41 deletions(-) diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py index c218535e..b8f346c7 100644 --- a/TTS/tts/models/fast_pitch.py +++ b/TTS/tts/models/fast_pitch.py @@ -1,9 +1,10 @@ from dataclasses import dataclass, field +from typing import Tuple import torch -import torch.nn.functional as F from coqpit import Coqpit from torch import nn +from torch.cuda.amp.autocast_mode import autocast from TTS.tts.layers.feed_forward.decoder import Decoder from TTS.tts.layers.feed_forward.encoder import Encoder @@ -12,7 +13,6 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask -from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor @@ -30,7 +30,7 @@ class AlignmentEncoder(torch.nn.Module): self.softmax = torch.nn.Softmax(dim=3) self.log_softmax = torch.nn.LogSoftmax(dim=3) - self.key_proj = nn.Sequential( + self.key_layer = nn.Sequential( nn.Conv1d( in_key_channels, in_key_channels * 2, @@ -42,7 +42,7 @@ class AlignmentEncoder(torch.nn.Module): nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True), ) - self.query_proj = nn.Sequential( + self.query_layer = nn.Sequential( nn.Conv1d( in_query_channels, in_query_channels * 2, @@ -58,33 +58,26 @@ class AlignmentEncoder(torch.nn.Module): def forward( self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None - ): + ) -> Tuple[torch.tensor, torch.tensor]: """Forward pass of the aligner encoder. Shapes: - - queries: :math:`(B, C, T_de)` - - keys: :math:`(B, C_emb, T_en)` - - mask: :math:`(B, T_de)` + - queries: :math:`[B, C, T_de]` + - keys: :math:`[B, C_emb, T_en]` + - mask: :math:`[B, T_de]` Output: - attn (torch.tensor): B x 1 x T1 x T2 attention mask. Final dim T2 should sum to 1. - attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask. + attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask. + attn_logp (torch.tensor): :math:`[ßB, 1, T_en , T_de]` log probabilities. """ - keys_enc = self.key_proj(keys) # B x n_attn_dims x T2 - queries_enc = self.query_proj(queries) - - # Simplistic Gaussian Isotopic Attention - attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2 # B x n_attn_dims x T1 x T2 - attn = -self.temperature * attn.sum(1, keepdim=True) - + key_out = self.key_layer(keys) + query_out = self.query_layer(queries) + attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2 + attn_factor = -self.temperature * attn_factor.sum(1, keepdim=True) if attn_prior is not None: - attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8) - - attn_logprob = attn.clone() - + attn_logp = self.log_softmax(attn_factor) + torch.log(attn_prior[:, None] + 1e-8) if mask is not None: - attn.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf")) - - attn = self.softmax(attn) # softmax along T2 - return attn, attn_logprob + attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf")) + attn = self.softmax(attn_logp) + return attn, attn_logp @dataclass @@ -414,23 +407,24 @@ class FastPitch(BaseTTS): if self.use_aligner: durations = outputs["o_alignment_dur"] - # compute loss - loss_dict = criterion( - outputs["model_outputs"], - mel_input, - mel_lengths, - outputs["durations_log"], - durations, - outputs["pitch"], - outputs["pitch_gt"], - text_lengths, - outputs["alignment_logprob"], - ) + with autocast(enabled=False): # use float32 for the criterion + # compute loss + loss_dict = criterion( + outputs["model_outputs"], + mel_input, + mel_lengths, + outputs["durations_log"], + durations, + outputs["pitch"], + outputs["pitch_gt"], + text_lengths, + outputs["alignment_logprob"], + ) - # compute duration error - durations_pred = outputs["durations"] - duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum() - loss_dict["duration_error"] = duration_error + # compute duration error + durations_pred = outputs["durations"] + duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum() + loss_dict["duration_error"] = duration_error return outputs, loss_dict def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use From 6e9d4062f2e37240fefc5e04add8663c443c06bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 3 Sep 2021 13:20:26 +0000 Subject: [PATCH 32/52] Add `sort_by_audio_len` option --- TTS/tts/configs/shared_configs.py | 8 ++++++-- TTS/tts/configs/vits_config.py | 7 +++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index 52e337f9..3dc70786 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -141,11 +141,14 @@ class BaseTTSConfig(BaseTrainingConfig): loss_masking (bool): enable / disable masking loss values against padded segments of samples in a batch. + sort_by_audio_len (bool): + If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`. + min_seq_len (int): - Minimum input sequence length to be used at training. + Minimum sequence length to be used at training. max_seq_len (int): - Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + Maximum sequence length to be used at training. Larger values result in more VRAM usage. compute_f0 (int): (Not in use yet). @@ -198,6 +201,7 @@ class BaseTTSConfig(BaseTrainingConfig): batch_group_size: int = 0 loss_masking: bool = None # dataloading + sort_by_audio_len: bool = True min_seq_len: int = 1 max_seq_len: int = float("inf") compute_f0: bool = False diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py index 58fc66ee..39479231 100644 --- a/TTS/tts/configs/vits_config.py +++ b/TTS/tts/configs/vits_config.py @@ -67,11 +67,14 @@ class VitsConfig(BaseTTSConfig): compute_linear_spec (bool): If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. + sort_by_audio_len (bool): + If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`. + min_seq_len (int): - Minimum text length to be considered for training. Defaults to `13`. + Minimum sequnce length to be considered for training. Defaults to `0`. max_seq_len (int): - Maximum text length to be considered for training. Defaults to `500`. + Maximum sequnce length to be considered for training. Defaults to `500000`. r (int): Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. From debf772ec5128963a0c3bec8e4f6bcaafe981221 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 3 Sep 2021 13:23:22 +0000 Subject: [PATCH 33/52] Implement binary alignment loss --- TTS/tts/configs/fast_pitch_config.py | 23 +++++++++++++++++++++++ TTS/tts/layers/losses.py | 17 ++++++++++++++++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py index 2c54803a..873f298e 100644 --- a/TTS/tts/configs/fast_pitch_config.py +++ b/TTS/tts/configs/fast_pitch_config.py @@ -17,37 +17,58 @@ class FastPitchConfig(BaseTTSConfig): Args: model (str): Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + model_args (Coqpit): Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + data_dep_init_steps (int): Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses Activation Normalization that pre-computes normalization stats at the beginning and use the same values for the rest. Defaults to 10. + use_speaker_embedding (bool): enable / disable using speaker embeddings for multi-speaker models. If set True, the model is in the multi-speaker mode. Defaults to False. + use_d_vector_file (bool): enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + d_vector_file (str): Path to the file including pre-computed speaker embeddings. Defaults to None. + noam_schedule (bool): enable / disable the use of Noam LR scheduler. Defaults to False. + warmup_steps (int): Number of warm-up steps for the Noam scheduler. Defaults 4000. + lr (float): Initial learning rate. Defaults to `1e-3`. + wd (float): Weight decay coefficient. Defaults to `1e-7`. + ssim_loss_alpha (float): Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + huber_loss_alpha (float): Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + spec_loss_alpha (float): Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + pitch_loss_alpha (float): Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + + binary_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_align_loss_start_step (int): + Start binary alignment loss after this many steps. Defaults to 20000. + min_seq_len (int): Minimum input sequence length to be used at training. + max_seq_len (int): Maximum input sequence length to be used at training. Larger values result in more VRAM usage. """ @@ -77,6 +98,8 @@ class FastPitchConfig(BaseTTSConfig): pitch_loss_alpha: float = 1.0 dur_loss_alpha: float = 1.0 aligner_loss_alpha: float = 1.0 + binary_align_loss_alpha: float = 1.0 + binary_align_loss_start_step: int = 20000 # overrides min_seq_len: int = 13 diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 6ca010dd..805f36d6 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -705,6 +705,14 @@ class FastPitchLoss(nn.Module): self.dur_loss_alpha = c.dur_loss_alpha self.pitch_loss_alpha = c.pitch_loss_alpha self.aligner_loss_alpha = c.aligner_loss_alpha + self.binary_alignment_loss_alpha = c.binary_align_loss_alpha + + def _binary_alignment_loss(self, alignment_hard, alignment_soft): + """Binary loss that forces soft alignments to match the hard alignments as + explained in `https://arxiv.org/pdf/2108.10447.pdf`. + """ + log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() + return -log_sum / alignment_hard.sum() def forward( self, @@ -717,6 +725,8 @@ class FastPitchLoss(nn.Module): pitch_target, input_lens, alignment_logprob=None, + alignment_hard=None, + alignment_soft=None, ): loss = 0 return_dict = {} @@ -743,8 +753,13 @@ class FastPitchLoss(nn.Module): if self.aligner_loss_alpha > 0: aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens) - loss += self.aligner_loss_alpha * aligner_loss + loss = loss + self.aligner_loss_alpha * aligner_loss return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss + if self.binary_alignment_loss_alpha > 0: + binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft) + loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss + return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss + return_dict["loss"] = loss return return_dict From 648655fa0366917e078d5a52396c792578662a2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 3 Sep 2021 13:25:57 +0000 Subject: [PATCH 34/52] Add `PitchExtractor` and return dict by `collate` --- TTS/tts/datasets/TTSDataset.py | 228 +++++++++++++++++---------------- TTS/tts/models/base_tts.py | 32 ++--- TTS/tts/models/glow_tts.py | 4 +- 3 files changed, 138 insertions(+), 126 deletions(-) diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index f6bd7038..74cb8de1 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -130,6 +130,8 @@ class TTSDataset(Dataset): if use_phonemes and not os.path.isdir(phoneme_cache_path): os.makedirs(phoneme_cache_path, exist_ok=True) + if compute_f0: + self.pitch_extractor = PitchExtractor(self.items, verbose=verbose) if self.verbose: print("\n > DataLoader initialization") print(" | > Use phonemes: {}".format(self.use_phonemes)) @@ -247,8 +249,8 @@ class TTSDataset(Dataset): pitch = None if self.compute_f0: - pitch = self._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path) - pitch = self.normalize_pitch(pitch) + pitch = self.pitch_extractor._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path) + pitch = self.pitch_extractor.normalize_pitch(pitch) sample = { "raw_text": raw_text, @@ -317,96 +319,6 @@ class TTSDataset(Dataset): for idx, p in enumerate(phonemes): self.items[idx][0] = p - ################ - # Pitch Methods - ############### - # TODO: Refactor Pitch methods into a separate class - - @staticmethod - def create_pitch_file_path(wav_file, cache_path): - file_name = os.path.splitext(os.path.basename(wav_file))[0] - pitch_file = os.path.join(cache_path, file_name + "_pitch.npy") - return pitch_file - - @staticmethod - def _compute_and_save_pitch(ap, wav_file, pitch_file=None): - wav = ap.load_wav(wav_file) - pitch = ap.compute_f0(wav) - if pitch_file: - np.save(pitch_file, pitch) - return pitch - - @staticmethod - def compute_pitch_stats(pitch_vecs): - nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs]) - mean, std = np.mean(nonzeros), np.std(nonzeros) - return mean, std - - def normalize_pitch(self, pitch): - zero_idxs = np.where(pitch == 0.0)[0] - pitch -= self.mean - pitch /= self.std - pitch[zero_idxs] = 0.0 - return pitch - - @staticmethod - def _load_or_compute_pitch(ap, wav_file, cache_path): - """ - compute pitch and return a numpy array of pitch values - """ - pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path) - if not os.path.exists(pitch_file): - pitch = TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file) - else: - pitch = np.load(pitch_file) - return pitch - - @staticmethod - def _pitch_worker(args): - item = args[0] - ap = args[1] - cache_path = args[2] - _, wav_file, *_ = item - pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path) - if not os.path.exists(pitch_file): - pitch = TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file) - return pitch - return None - - def compute_pitch(self, cache_path, num_workers=0): - """Compute the input sequences with multi-processing. - Call it before passing dataset to the data loader to cache the input sequences for faster data loading.""" - if not os.path.exists(cache_path): - os.makedirs(cache_path, exist_ok=True) - - if self.verbose: - print(" | > Computing pitch features ...") - if num_workers == 0: - pitch_vecs = [] - for _, item in enumerate(tqdm.tqdm(self.items)): - pitch_vecs += [self._pitch_worker([item, self.ap, cache_path])] - else: - with Pool(num_workers) as p: - pitch_vecs = list( - tqdm.tqdm( - p.imap(TTSDataset._pitch_worker, [[item, self.ap, cache_path] for item in self.items]), - total=len(self.items), - ) - ) - pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs) - pitch_stats = {"mean": pitch_mean, "std": pitch_std} - np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) - - def load_pitch_stats(self, cache_path): - stats_path = os.path.join(cache_path, "pitch_stats.npy") - stats = np.load(stats_path, allow_pickle=True).item() - self.mean = stats["mean"] - self.std = stats["std"] - - ################### - # End Pitch Methods - ################### - def sort_and_filter_items(self, by_audio_len=False): r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length range. @@ -588,22 +500,22 @@ class TTSDataset(Dataset): else: attns = None # TODO: return dictionary - return ( - text, - text_lenghts, - speaker_names, - linear, - mel, - mel_lengths, - stop_targets, - item_idxs, - d_vectors, - speaker_ids, - attns, - wav_padded, - raw_text, - pitch, - ) + return { + "text": text, + "text_lengths": text_lenghts, + "speaker_names": speaker_names, + "linear": linear, + "mel": mel, + "mel_lengths": mel_lengths, + "stop_targets": stop_targets, + "item_idxs": item_idxs, + "d_vectors": d_vectors, + "speaker_ids": speaker_ids, + "attns": attns, + "waveform": wav_padded, + "raw_text": raw_text, + "pitch": pitch, + } raise TypeError( ( @@ -613,3 +525,103 @@ class TTSDataset(Dataset): ) ) ) + + +class PitchExtractor: + """Pitch Extractor for computing F0 from wav files. + + Args: + items (List[List]): Dataset samples. + verbose (bool): Whether to print the progress. + """ + + def __init__( + self, + items: List[List], + verbose=False, + ): + self.items = items + self.verbose = verbose + self.mean = None + self.std = None + + @staticmethod + def create_pitch_file_path(wav_file, cache_path): + file_name = os.path.splitext(os.path.basename(wav_file))[0] + pitch_file = os.path.join(cache_path, file_name + "_pitch.npy") + return pitch_file + + @staticmethod + def _compute_and_save_pitch(ap, wav_file, pitch_file=None): + wav = ap.load_wav(wav_file) + pitch = ap.compute_f0(wav) + if pitch_file: + np.save(pitch_file, pitch) + return pitch + + @staticmethod + def compute_pitch_stats(pitch_vecs): + nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs]) + mean, std = np.mean(nonzeros), np.std(nonzeros) + return mean, std + + def normalize_pitch(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch -= self.mean + pitch /= self.std + pitch[zero_idxs] = 0.0 + return pitch + + @staticmethod + def _load_or_compute_pitch(ap, wav_file, cache_path): + """ + compute pitch and return a numpy array of pitch values + """ + pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path) + if not os.path.exists(pitch_file): + pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file) + else: + pitch = np.load(pitch_file) + return pitch + + @staticmethod + def _pitch_worker(args): + item = args[0] + ap = args[1] + cache_path = args[2] + _, wav_file, *_ = item + pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path) + if not os.path.exists(pitch_file): + pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file) + return pitch + return None + + def compute_pitch(self, cache_path, num_workers=0): + """Compute the input sequences with multi-processing. + Call it before passing dataset to the data loader to cache the input sequences for faster data loading.""" + if not os.path.exists(cache_path): + os.makedirs(cache_path, exist_ok=True) + + if self.verbose: + print(" | > Computing pitch features ...") + if num_workers == 0: + pitch_vecs = [] + for _, item in enumerate(tqdm.tqdm(self.items)): + pitch_vecs += [self._pitch_worker([item, self.ap, cache_path])] + else: + with Pool(num_workers) as p: + pitch_vecs = list( + tqdm.tqdm( + p.imap(PitchExtractor._pitch_worker, [[item, self.ap, cache_path] for item in self.items]), + total=len(self.items), + ) + ) + pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs) + pitch_stats = {"mean": pitch_mean, "std": pitch_std} + np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) + + def load_pitch_stats(self, cache_path): + stats_path = os.path.join(cache_path, "pitch_stats.npy") + stats = np.load(stats_path, allow_pickle=True).item() + self.mean = stats["mean"] + self.std = stats["std"] diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 9e0bf41e..653143cd 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -104,19 +104,19 @@ class BaseTTS(BaseModel): Dict: [description] """ # setup input batch - text_input = batch[0] - text_lengths = batch[1] - speaker_names = batch[2] - linear_input = batch[3] - mel_input = batch[4] - mel_lengths = batch[5] - stop_targets = batch[6] - item_idx = batch[7] - d_vectors = batch[8] - speaker_ids = batch[9] - attn_mask = batch[10] - waveform = batch[11] - pitch = batch[13] + text_input = batch["text"] + text_lengths = batch["text_lengths"] + speaker_names = batch["speaker_names"] + linear_input = batch["linear"] + mel_input = batch["mel"] + mel_lengths = batch["mel_lengths"] + stop_targets = batch["stop_targets"] + item_idx = batch["item_idxs"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + attn_mask = batch["attns"] + waveform = batch["waveform"] + pitch = batch["pitch"] max_text_length = torch.max(text_lengths.float()) max_spec_length = torch.max(mel_lengths.float()) @@ -201,7 +201,7 @@ class BaseTTS(BaseModel): outputs_per_step=config.r if "r" in config else 1, text_cleaner=config.text_cleaner, compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, - comnpute_f0=config.get("compute_f0", False), + compute_f0=config.get("compute_f0", False), f0_cache_path=config.get("f0_cache_path", None), meta_data=data_items, ap=ap, @@ -252,8 +252,8 @@ class BaseTTS(BaseModel): # compute pitch frames and write to files. if config.compute_f0 and rank in [None, 0]: if not os.path.exists(config.f0_cache_path): - dataset.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers) - dataset.load_pitch_stats(config.get("f0_cache_path", None)) + dataset.pitch_extractor.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers) + dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None)) # halt DDP processes for the main process to finish computing the F0 cache if num_gpus > 1: diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index e6541871..27012207 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -134,9 +134,9 @@ class GlowTTS(BaseTTS): """ Shapes: - x: :math:`[B, T]` - - x_lenghts::math:` B` + - x_lenghts::math:`B` - y: :math:`[B, T, C]` - - y_lengths::math:` B` + - y_lengths::math:`B` - g: :math:`[B, C] or B` """ y = y.transpose(1, 2) From 59b24e66cf13145819f0873be6f1314949f9e092 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 3 Sep 2021 13:26:24 +0000 Subject: [PATCH 35/52] Add `AlignerNetwork` --- TTS/tts/layers/generic/aligner.py | 81 +++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 TTS/tts/layers/generic/aligner.py diff --git a/TTS/tts/layers/generic/aligner.py b/TTS/tts/layers/generic/aligner.py new file mode 100644 index 00000000..eef4c4b6 --- /dev/null +++ b/TTS/tts/layers/generic/aligner.py @@ -0,0 +1,81 @@ +from typing import Tuple + +import torch +from torch import nn + + +class AlignmentNetwork(torch.nn.Module): + """Aligner Network for learning alignment between the input text and the model output with Gaussian Attention. + + :: + + query -> conv1d -> relu -> conv1d -> relu -> conv1d -> L2_dist -> softmax -> alignment + key -> conv1d -> relu -> conv1d -----------------------^ + + Args: + in_query_channels (int): Number of channels in the query network. Defaults to 80. + in_key_channels (int): Number of channels in the key network. Defaults to 512. + attn_channels (int): Number of inner channels in the attention layers. Defaults to 80. + temperature (float): Temperature for the softmax. Defaults to 0.0005. + """ + + def __init__( + self, + in_query_channels=80, + in_key_channels=512, + attn_channels=80, + temperature=0.0005, + ): + super().__init__() + self.temperature = temperature + self.softmax = torch.nn.Softmax(dim=3) + self.log_softmax = torch.nn.LogSoftmax(dim=3) + + self.key_layer = nn.Sequential( + nn.Conv1d( + in_key_channels, + in_key_channels * 2, + kernel_size=3, + padding=1, + bias=True, + ), + torch.nn.ReLU(), + nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True), + ) + + self.query_layer = nn.Sequential( + nn.Conv1d( + in_query_channels, + in_query_channels * 2, + kernel_size=3, + padding=1, + bias=True, + ), + torch.nn.ReLU(), + nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True), + torch.nn.ReLU(), + nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True), + ) + + def forward( + self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None + ) -> Tuple[torch.tensor, torch.tensor]: + """Forward pass of the aligner encoder. + Shapes: + - queries: :math:`[B, C, T_de]` + - keys: :math:`[B, C_emb, T_en]` + - mask: :math:`[B, T_de]` + Output: + attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask. + attn_logp (torch.tensor): :math:`[ßB, 1, T_en , T_de]` log probabilities. + """ + key_out = self.key_layer(keys) + query_out = self.query_layer(queries) + attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2 + attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True) + if attn_prior is not None: + attn_logp = self.log_softmax(attn_logp) + torch.log(attn_prior[:, None] + 1e-8) + if mask is not None: + attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf")) + attn = self.softmax(attn_logp) + return attn, attn_logp From 2bf9e83c498d4f2b74096bf0aaece60008044db7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 3 Sep 2021 13:26:49 +0000 Subject: [PATCH 36/52] FastPitch refactor and commenting --- TTS/tts/layers/losses.py | 2 +- TTS/tts/models/fast_pitch.py | 595 ++++++++++++------ docs/source/index.md | 1 + .../ljspeech/fast_pitch/train_fast_pitch.py | 9 +- 4 files changed, 403 insertions(+), 204 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 805f36d6..100b8fb3 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -756,7 +756,7 @@ class FastPitchLoss(nn.Module): loss = loss + self.aligner_loss_alpha * aligner_loss return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss - if self.binary_alignment_loss_alpha > 0: + if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None: binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft) loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py index b8f346c7..352aebfa 100644 --- a/TTS/tts/models/fast_pitch.py +++ b/TTS/tts/models/fast_pitch.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Tuple +from typing import Dict, Tuple import torch from coqpit import Coqpit @@ -8,6 +8,7 @@ from torch.cuda.amp.autocast_mode import autocast from TTS.tts.layers.feed_forward.decoder import Decoder from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.aligner import AlignmentNetwork from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path @@ -15,87 +16,101 @@ from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor - - -class AlignmentEncoder(torch.nn.Module): - def __init__( - self, - in_query_channels=80, - in_key_channels=512, - attn_channels=80, - temperature=0.0005, - ): - super().__init__() - self.temperature = temperature - self.softmax = torch.nn.Softmax(dim=3) - self.log_softmax = torch.nn.LogSoftmax(dim=3) - - self.key_layer = nn.Sequential( - nn.Conv1d( - in_key_channels, - in_key_channels * 2, - kernel_size=3, - padding=1, - bias=True, - ), - torch.nn.ReLU(), - nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True), - ) - - self.query_layer = nn.Sequential( - nn.Conv1d( - in_query_channels, - in_query_channels * 2, - kernel_size=3, - padding=1, - bias=True, - ), - torch.nn.ReLU(), - nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True), - torch.nn.ReLU(), - nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True), - ) - - def forward( - self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None - ) -> Tuple[torch.tensor, torch.tensor]: - """Forward pass of the aligner encoder. - Shapes: - - queries: :math:`[B, C, T_de]` - - keys: :math:`[B, C_emb, T_en]` - - mask: :math:`[B, T_de]` - Output: - attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask. - attn_logp (torch.tensor): :math:`[ßB, 1, T_en , T_de]` log probabilities. - """ - key_out = self.key_layer(keys) - query_out = self.query_layer(queries) - attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2 - attn_factor = -self.temperature * attn_factor.sum(1, keepdim=True) - if attn_prior is not None: - attn_logp = self.log_softmax(attn_factor) + torch.log(attn_prior[:, None] + 1e-8) - if mask is not None: - attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf")) - attn = self.softmax(attn_logp) - return attn, attn_logp +from TTS.utils.soft_dtw import SoftDTW @dataclass class FastPitchArgs(Coqpit): + """Fast Pitch Model arguments. + + Args: + + num_chars (int): + Number of characters in the vocabulary. Defaults to 100. + + out_channels (int): + Number of output channels. Defaults to 80. + + hidden_channels (int): + Number of base hidden channels of the model. Defaults to 512. + + num_speakers (int): + Number of speakers for the speaker embedding layer. Defaults to 0. + + duration_predictor_hidden_channels (int): + Number of hidden channels in the duration predictor. Defaults to 256. + + duration_predictor_dropout_p (float): + Dropout rate for the duration predictor. Defaults to 0.1. + + duration_predictor_kernel_size (int): + Kernel size of conv layers in the duration predictor. Defaults to 3. + + pitch_predictor_hidden_channels (int): + Number of hidden channels in the pitch predictor. Defaults to 256. + + pitch_predictor_dropout_p (float): + Dropout rate for the pitch predictor. Defaults to 0.1. + + pitch_predictor_kernel_size (int): + Kernel size of conv layers in the pitch predictor. Defaults to 3. + + pitch_embedding_kernel_size (int): + Kernel size of the projection layer in the pitch predictor. Defaults to 3. + + positional_encoding (bool): + Whether to use positional encoding. Defaults to True. + + positional_encoding_use_scale (bool): + Whether to use a learnable scale coeff in the positional encoding. Defaults to True. + + length_scale (int): + Length scale that multiplies the predicted durations. Larger values result slower speech. Defaults to 1.0. + + encoder_type (str): + Type of the encoder module. One of the encoders available in :class:`TTS.tts.layers.feed_forward.encoder`. + Defaults to `fftransformer` as in the paper. + + encoder_params (dict): + Parameters of the encoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + decoder_type (str): + Type of the decoder module. One of the decoders available in :class:`TTS.tts.layers.feed_forward.decoder`. + Defaults to `fftransformer` as in the paper. + + decoder_params (str): + Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + use_d_vetor (bool): + Whether to use precomputed d-vectors for multi-speaker training. Defaults to False. + + d_vector_dim (int): + Number of channels of the d-vectors. Defaults to 0. + + detach_duration_predictor (bool): + Detach the input to the duration predictor from the earlier computation graph so that the duraiton loss + does not pass to the earlier layers. Defaults to True. + + max_duration (int): + Maximum duration accepted by the model. Defaults to 75. + + use_aligner (bool): + Use aligner network to learn the text to speech alignment. Defaults to True. + """ + num_chars: int = None out_channels: int = 80 - hidden_channels: int = 256 + hidden_channels: int = 384 num_speakers: int = 0 duration_predictor_hidden_channels: int = 256 - duration_predictor_dropout: float = 0.1 duration_predictor_kernel_size: int = 3 duration_predictor_dropout_p: float = 0.1 pitch_predictor_hidden_channels: int = 256 - pitch_predictor_dropout: float = 0.1 pitch_predictor_kernel_size: int = 3 pitch_predictor_dropout_p: float = 0.1 pitch_embedding_kernel_size: int = 3 positional_encoding: bool = True + poisitonal_encoding_use_scale: bool = True length_scale: int = 1 encoder_type: str = "fftransformer" encoder_params: dict = field( @@ -109,14 +124,16 @@ class FastPitchArgs(Coqpit): d_vector_dim: int = 0 detach_duration_predictor: bool = False max_duration: int = 75 - use_gt_duration: bool = True use_aligner: bool = True class FastPitch(BaseTTS): """FastPitch model. Very similart to SpeedySpeech model but with pitch prediction. - Paper abstract: + Paper:: + https://arxiv.org/abs/2006.06873 + + Paper abstract:: We present FastPitch, a fully-parallel text-to-speech model based on FastSpeech, conditioned on fundamental frequency contours. The model predicts pitch contours during inference. By altering these predictions, the generated speech can be more expressive, better match the semantic of the utterance, and in the end @@ -126,9 +143,6 @@ class FastPitch(BaseTTS): and FastPitch retains the favorable, fully-parallel Transformer architecture, with over 900x real-time factor for mel-spectrogram synthesis of a typical utterance." - Notes: - TODO - Args: config (Coqpit): Model coqpit class. @@ -143,95 +157,138 @@ class FastPitch(BaseTTS): super().__init__() - if "characters" in config: - # loading from FasrPitchConfig - _, self.config, num_chars = self.get_characters(config) - config.model_args.num_chars = num_chars - args = self.config.model_args - else: - # loading from FastPitchArgs + # don't use isintance not to import recursively + if config.__class__.__name__ == "FastPitchConfig": + if "characters" in config: + # loading from FasrPitchConfig + _, self.config, num_chars = self.get_characters(config) + config.model_args.num_chars = num_chars + self.args = self.config.model_args + else: + # loading from FastPitchArgs + self.config = config + self.args = config.model_args + elif isinstance(config, FastPitchArgs): + self.args = config self.config = config - args = config + else: + raise ValueError("config must be either a VitsConfig or Vitsself.args") - self.max_duration = args.max_duration - self.use_gt_duration = args.use_gt_duration - self.use_aligner = args.use_aligner + self.max_duration = self.args.max_duration + self.use_aligner = self.args.use_aligner + self.use_binary_alignment_loss = False - self.length_scale = float(args.length_scale) if isinstance(args.length_scale, int) else args.length_scale - - self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels) - - self.encoder = Encoder( - config.model_args.hidden_channels, - config.model_args.hidden_channels, - config.model_args.encoder_type, - config.model_args.encoder_params, - config.model_args.d_vector_dim, + self.length_scale = ( + float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale ) - if config.model_args.positional_encoding: - self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels) + self.emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels) + + self.encoder = Encoder( + self.args.hidden_channels, + self.args.hidden_channels, + self.args.encoder_type, + self.args.encoder_params, + self.args.d_vector_dim, + ) + + if self.args.positional_encoding: + self.pos_encoder = PositionalEncoding(self.args.hidden_channels) self.decoder = Decoder( - config.model_args.out_channels, - config.model_args.hidden_channels, - config.model_args.decoder_type, - config.model_args.decoder_params, + self.args.out_channels, + self.args.hidden_channels, + self.args.decoder_type, + self.args.decoder_params, ) self.duration_predictor = DurationPredictor( - config.model_args.hidden_channels + config.model_args.d_vector_dim, - config.model_args.duration_predictor_hidden_channels, - config.model_args.duration_predictor_kernel_size, - config.model_args.duration_predictor_dropout_p, + self.args.hidden_channels + self.args.d_vector_dim, + self.args.duration_predictor_hidden_channels, + self.args.duration_predictor_kernel_size, + self.args.duration_predictor_dropout_p, ) self.pitch_predictor = DurationPredictor( - config.model_args.hidden_channels + config.model_args.d_vector_dim, - config.model_args.pitch_predictor_hidden_channels, - config.model_args.pitch_predictor_kernel_size, - config.model_args.pitch_predictor_dropout_p, + self.args.hidden_channels + self.args.d_vector_dim, + self.args.pitch_predictor_hidden_channels, + self.args.pitch_predictor_kernel_size, + self.args.pitch_predictor_dropout_p, ) self.pitch_emb = nn.Conv1d( 1, - config.model_args.hidden_channels, - kernel_size=config.model_args.pitch_embedding_kernel_size, - padding=int((config.model_args.pitch_embedding_kernel_size - 1) / 2), + self.args.hidden_channels, + kernel_size=self.args.pitch_embedding_kernel_size, + padding=int((self.args.pitch_embedding_kernel_size - 1) / 2), ) - if config.model_args.num_speakers > 1 and not config.model_args.use_d_vector: + if self.args.num_speakers > 1 and not self.args.use_d_vector: # speaker embedding layer - self.emb_g = nn.Embedding(config.model_args.num_speakers, config.model_args.d_vector_dim) + self.emb_g = nn.Embedding(self.args.num_speakers, self.args.d_vector_dim) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) - if config.model_args.d_vector_dim > 0 and config.model_args.d_vector_dim != config.model_args.hidden_channels: - self.proj_g = nn.Conv1d(config.model_args.d_vector_dim, config.model_args.hidden_channels, 1) + if self.args.d_vector_dim > 0 and self.args.d_vector_dim != self.args.hidden_channels: + self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) - if args.use_aligner: - self.aligner = AlignmentEncoder(args.out_channels, args.hidden_channels) + if self.args.use_aligner: + self.aligner = AlignmentNetwork(in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels) @staticmethod - def expand_encoder_outputs(en, dr, x_mask, y_mask): + def generate_attn(dr, x_mask, y_mask=None): + """Generate an attention mask from the durations. + + Shapes + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + """ + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) + return attn + + def expand_encoder_outputs(self, en, dr, x_mask, y_mask): """Generate attention alignment map from durations and expand encoder outputs - Example: - encoder output: [a,b,c,d] - durations: [1, 3, 2, 1] + Shapes + - en: :math:`(B, D_{en}, T_{en})` + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` - expanded: [a, b, b, b, c, c, d] - attention map: [[0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 1, 1, 0], - [0, 1, 1, 1, 0, 0, 0], - [1, 0, 0, 0, 0, 0, 0]] + Examples: + - encoder output: :math:`[a,b,c,d]` + - durations: :math:`[1, 3, 2, 1]` + + - expanded: :math:`[a, b, b, b, c, c, d]` + - attention map: :math:`[[0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0]]` """ - attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype) - o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) + attn = self.generate_attn(dr, x_mask, y_mask) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2) return o_en_ex, attn def format_durations(self, o_dr_log, x_mask): + """Format predicted durations. + 1. Convert to linear scale from log scale + 2. Apply the length scale for speed adjustment + 3. Apply masking. + 4. Cast 0 durations to 1. + 5. Round the duration values. + + Args: + o_dr_log: Log scale durations. + x_mask: Input text mask. + + Shapes: + - o_dr_log: :math:`(B, T_{de})` + - x_mask: :math:`(B, T_{en})` + """ o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale o_dr[o_dr < 1] = 1.0 o_dr = torch.round(o_dr) @@ -249,22 +306,39 @@ class FastPitch(BaseTTS): g = self.proj_g(g) return x + g - def _forward_encoder(self, x, x_lengths, g=None): + def _forward_encoder( + self, x: torch.LongTensor, x_mask:torch.FloatTensor, g: torch.FloatTensor = None + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Encoding forward pass. + + 1. Embed speaker IDs if multi-speaker mode. + 2. Embed character sequences. + 3. Run the encoder network. + 4. Concat speaker embedding to the encoder output for the duration predictor. + + Args: + x (torch.LongTensor): Input sequence IDs. + x_mask (torch.FloatTensor): Input squence mask. + g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None. + + Returns: + Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor, torch.tensor]: + encoder output, encoder output for the duration predictor, input sequence mask, speaker embeddings, + character embeddings + + Shapes: + - x: :math:`(B, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - g: :math:`(B, C)` + """ if hasattr(self, "emb_g"): g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1] - if g is not None: g = g.unsqueeze(-1) - # [B, T, C] x_emb = self.emb(x) - - # compute sequence masks - x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) - # encoder pass o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask) - # speaker conditioning for duration predictor if g is not None: o_en_dp = self._concat_speaker_embedding(o_en, g) @@ -272,8 +346,33 @@ class FastPitch(BaseTTS): o_en_dp = o_en return o_en, o_en_dp, x_mask, g, x_emb - def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) + def _forward_decoder( + self, + o_en: torch.FloatTensor, + dr: torch.IntTensor, + x_mask: torch.FloatTensor, + y_lengths: torch.IntTensor, + g: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Decoding forward pass. + + 1. Compute the decoder output mask + 2. Expand encoder output with the durations. + 3. Apply position encoding. + 4. Add speaker embeddings if multi-speaker mode. + 5. Run the decoder. + + Args: + o_en (torch.FloatTensor): Encoder output. + dr (torch.IntTensor): Ground truth durations or alignment network durations. + x_mask (torch.IntTensor): Input sequence mask. + y_lengths (torch.IntTensor): Output sequence lengths. + g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations. + """ + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) # expand o_en with durations o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) # positional encoding @@ -286,7 +385,34 @@ class FastPitch(BaseTTS): o_de = self.decoder(o_en_ex, y_mask, g=g) return o_de.transpose(1, 2), attn.transpose(1, 2) - def _forward_pitch_predictor(self, o_en, x_mask, pitch=None, dr=None): + def _forward_pitch_predictor( + self, + o_en: torch.FloatTensor, + x_mask: torch.IntTensor, + pitch: torch.FloatTensor = None, + dr: torch.IntTensor = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Pitch predictor forward pass. + + 1. Predict pitch from encoder outputs. + 2. In training - Compute average pitch values for each input character from the ground truth pitch values. + 3. Embed average pitch values. + + Args: + o_en (torch.FloatTensor): Encoder output. + x_mask (torch.IntTensor): Input sequence mask. + pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None. + dr (torch.IntTensor, optional): Ground truth durations. Defaults to None. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction. + + Shapes: + - o_en: :math:`(B, C, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - pitch: :math:`(B, 1, T_{de})` + - dr: :math:`(B, T_{en})` + """ o_pitch = self.pitch_predictor(o_en, x_mask) if pitch is not None: avg_pitch = average_pitch(pitch, dr) @@ -295,49 +421,111 @@ class FastPitch(BaseTTS): o_pitch_emb = self.pitch_emb(o_pitch) return o_pitch_emb, o_pitch - def _forward_aligner(self, y, embedding, x_mask, y_mask): + def _forward_aligner( + self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor + ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Aligner forward pass. + + 1. Compute a mask to apply to the attention map. + 2. Run the alignment network. + 3. Apply MAS to compute the hard alignment map. + 4. Compute the durations from the hard alignment map. + + Args: + x (torch.FloatTensor): Input sequence. + y (torch.FloatTensor): Output sequence. + x_mask (torch.IntTensor): Input sequence mask. + y_mask (torch.IntTensor): Output sequence mask. + + Returns: + Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials, + hard alignment map. + + Shapes: + - x: :math:`[B, T_en, C_en]` + - y: :math:`[B, T_de, C_de]` + - x_mask: :math:`[B, 1, T_en]` + - y_mask: :math:`[B, 1, T_de]` + + - o_alignment_dur: :math:`[B, T_en]` + - alignment_soft: :math:`[B, T_en, T_de]` + - alignment_logprob: :math:`[B, 1, T_de, T_en]` + - alignment_mas: :math:`[B, T_en, T_de]` + """ attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), embedding.transpose(1, 2), x_mask, None) + alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None) alignment_mas = maximum_path( alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() ) - o_alignment_dur = torch.sum(alignment_mas, -1) - return o_alignment_dur, alignment_logprob, alignment_mas + o_alignment_dur = torch.sum(alignment_mas, -1).int() + alignment_soft = alignment_soft.squeeze(1).transpose(1, 2) + return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas def forward( - self, x, x_lengths, y_lengths, y=None, dr=None, pitch=None, aux_input={"d_vectors": 0, "speaker_ids": None} - ): # pylint: disable=unused-argument - """ + self, + x: torch.LongTensor, + x_lengths: torch.LongTensor, + y_lengths: torch.LongTensor, + y: torch.FloatTensor = None, + dr: torch.IntTensor = None, + pitch: torch.FloatTensor = None, + aux_input: Dict = {"d_vectors": 0, "speaker_ids": None}, # pylint: disable=unused-argument + ) -> Dict: + """Model's forward pass. + + Args: + x (torch.LongTensor): Input character sequences. + x_lengths (torch.LongTensor): Input sequence lengths. + y_lengths (torch.LongTensor): Output sequnce lengths. Defaults to None. + y (torch.FloatTensor): Spectrogram frames. Defaults to None. + dr (torch.IntTensor): Character durations over the spectrogram frames. Defaults to None. + pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Defaults to None. + aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": 0, "speaker_ids": None}`. + Shapes: - x: :math:`[B, T_max]` - x_lengths: :math:`[B]` - y_lengths: :math:`[B]` - y: :math:`[B, T_max2]` - dr: :math:`[B, T_max]` - g: :math:`[B, C]` - pitch: :math:`[B, 1, T]` + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - y_lengths: :math:`[B]` + - y: :math:`[B, T_max2]` + - dr: :math:`[B, T_max]` + - g: :math:`[B, C]` + - pitch: :math:`[B, 1, T]` """ g = aux_input["d_vectors"] if "d_vectors" in aux_input else None - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype) - o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_lengths, g) - if self.config.model_args.detach_duration_predictor: + # compute sequence masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(y.dtype) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(y.dtype) + # encoder pass + o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g) + # duration predictor pass + if self.args.detach_duration_predictor: o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) else: o_dr_log = self.duration_predictor(o_en_dp, x_mask) o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) + # generate attn mask from predicted durations + o_attn = self.generate_attn(o_dr.squeeze(1), x_mask) + # aligner pass if self.use_aligner: - o_alignment_dur, alignment_logprob, alignment_mas = self._forward_aligner(y, x_emb, x_mask, y_mask) + o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner( + x_emb, y, x_mask, y_mask + ) dr = o_alignment_dur + # pitch predictor pass o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr) o_en = o_en + o_pitch_emb - o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g) + # decoder pass + o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g) outputs = { "model_outputs": o_de, "durations_log": o_dr_log.squeeze(1), "durations": o_dr.squeeze(1), + "attn_durations": o_attn, # for visualization "pitch": o_pitch, "pitch_gt": avg_pitch, "alignments": attn, + "alignment_soft": alignment_soft.transpose(1, 2), "alignment_mas": alignment_mas.transpose(1, 2), "o_alignment_dur": o_alignment_dur, "alignment_logprob": alignment_logprob, @@ -346,43 +534,33 @@ class FastPitch(BaseTTS): @torch.no_grad() def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument - """ + """Model's inference pass. + + Args: + x (torch.LongTensor): Input character sequence. + aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`. + Shapes: - x: [B, T_max] - x_lengths: [B] - g: [B, C] + - x: [B, T_max] + - x_lengths: [B] + - g: [B, C] """ g = aux_input["d_vectors"] if "d_vectors" in aux_input else None x_lengths = torch.tensor(x.shape[1:2]).to(x.device) - # input sequence should be greated than the max convolution size - inference_padding = 5 - if x.shape[1] < 13: - inference_padding += 13 - x.shape[1] - # pad input to prevent dropping the last word - x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0) - o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float() + # encoder pass + o_en, o_en_dp, x_mask, g, _ = self._forward_encoder(x, x_mask, g) # duration predictor pass o_dr_log = self.duration_predictor(o_en_dp, x_mask) o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + y_lengths = o_dr.sum(1) # pitch predictor pass o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask) - # if pitch_transform is not None: - # if self.pitch_std[0] == 0.0: - # # XXX LJSpeech-1.1 defaults - # mean, std = 218.14, 67.24 - # else: - # mean, std = self.pitch_mean[0], self.pitch_std[0] - # pitch_pred = pitch_transform(pitch_pred, enc_mask.sum(dim=(1,2)), mean, std) - - # if pitch_tgt is None: - # pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2) - # else: - # pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2) o_en = o_en + o_pitch_emb - y_lengths = o_dr.sum(1) - o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) + # decoder pass + o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=g) outputs = { - "model_outputs": o_de.transpose(1, 2), + "model_outputs": o_de, "alignments": attn, "pitch": o_pitch, "durations_log": o_dr_log, @@ -398,33 +576,35 @@ class FastPitch(BaseTTS): d_vectors = batch["d_vectors"] speaker_ids = batch["speaker_ids"] durations = batch["durations"] - aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} + # forward pass outputs = self.forward( text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input ) - + # use aligner's output as the duration target if self.use_aligner: durations = outputs["o_alignment_dur"] - - with autocast(enabled=False): # use float32 for the criterion + # use float32 in AMP + with autocast(enabled=False): # compute loss loss_dict = criterion( - outputs["model_outputs"], - mel_input, - mel_lengths, - outputs["durations_log"], - durations, - outputs["pitch"], - outputs["pitch_gt"], - text_lengths, - outputs["alignment_logprob"], + decoder_output=outputs["model_outputs"], + decoder_target=mel_input, + decoder_output_lens=mel_lengths, + dur_output=outputs["durations_log"], + dur_target=durations, + pitch_output=outputs["pitch"], + pitch_target=outputs["pitch_gt"], + input_lens=text_lengths, + alignment_logprob=outputs["alignment_logprob"], + alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None, + alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None ) - # compute duration error durations_pred = outputs["durations"] duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum() loss_dict["duration_error"] = duration_error + return outputs, loss_dict def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use @@ -442,9 +622,10 @@ class FastPitch(BaseTTS): "alignment": plot_alignment(align_img, output_fig=False), } - if self.config.model_args.use_aligner and self.training: - alignment_mas = outputs["alignment_mas"][0].data.cpu().numpy() - figures["alignment_mas"] = plot_alignment(alignment_mas, output_fig=False) + # plot the attention mask computed from the predicted durations + if "attn_durations" in outputs: + alignments_hat = outputs["attn_durations"][0].data.cpu().numpy() + figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) # Sample audio train_audio = ap.inv_melspectrogram(pred_spec.T) @@ -470,8 +651,20 @@ class FastPitch(BaseTTS): return FastPitchLoss(self.config) + def on_train_step_start(self, trainer): + """Enable binary alignment loss when needed""" + if trainer.total_steps_done > self.config.binary_align_loss_start_step: + self.use_binary_alignment_loss = True + def average_pitch(pitch, durs): + """Compute the average pitch value for each input character based on the durations. + + Shapes: + - pitch: :math:`[B, 1, T_de]` + - durs: :math:`[B, T_en]` + """ + durs_cums_ends = torch.cumsum(durs, dim=1).long() durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0)) pitch_nonzero_cums = torch.nn.functional.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0)) diff --git a/docs/source/index.md b/docs/source/index.md index 77d198c0..d842f894 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -45,6 +45,7 @@ models/glow_tts.md models/vits.md + models/fast_pitch.md .. toctree:: :maxdepth: 2 diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index 63f50dd9..5c9e67da 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -14,10 +14,11 @@ dataset_config = BaseDatasetConfig( # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), path=os.path.join(output_path, "../LJSpeech-1.1/"), ) + audio_config = BaseAudioConfig( sample_rate=22050, - do_trim_silence=False, - trim_db=0.0, + do_trim_silence=True, + trim_db=60.0, signal_norm=False, mel_fmin=0.0, mel_fmax=8000, @@ -26,6 +27,7 @@ audio_config = BaseAudioConfig( ref_level_db=20, preemphasis=0.0, ) + config = FastPitchConfig( run_name="fast_pitch_ljspeech", audio=audio_config, @@ -33,6 +35,7 @@ config = FastPitchConfig( eval_batch_size=16, num_loader_workers=8, num_eval_loader_workers=4, + compute_input_seq_cache=True, compute_f0=True, f0_cache_path=os.path.join(output_path, "f0_cache"), run_eval=True, @@ -45,6 +48,8 @@ config = FastPitchConfig( print_step=50, print_eval=False, mixed_precision=False, + sort_by_audio_len=True, + max_seq_len=500000, output_path=output_path, datasets=[dataset_config], ) From 46728895492309f2a819bdd27e7607dfa68ae6ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 3 Sep 2021 13:27:56 +0000 Subject: [PATCH 37/52] Update `generic.FFTransformer` --- TTS/tts/layers/generic/transformer.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/TTS/tts/layers/generic/transformer.py b/TTS/tts/layers/generic/transformer.py index 9e6b69ac..12f0bbb0 100644 --- a/TTS/tts/layers/generic/transformer.py +++ b/TTS/tts/layers/generic/transformer.py @@ -15,17 +15,19 @@ class FFTransformer(nn.Module): self.norm1 = nn.LayerNorm(in_out_channels) self.norm2 = nn.LayerNorm(in_out_channels) - self.dropout = nn.Dropout(dropout_p) + self.dropout1 = nn.Dropout(dropout_p) + self.dropout2 = nn.Dropout(dropout_p) def forward(self, src, src_mask=None, src_key_padding_mask=None): """😦 ugly looking with all the transposing""" src = src.permute(2, 0, 1) src2, enc_align = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask) + src = src + self.dropout1(src2) src = self.norm1(src + src2) # T x B x D -> B x D x T src = src.permute(1, 2, 0) src2 = self.conv2(F.relu(self.conv1(src))) - src2 = self.dropout(src2) + src2 = self.dropout2(src2) src = src + src2 src = src.transpose(1, 2) src = self.norm2(src) @@ -52,8 +54,8 @@ class FFTransformerBlock(nn.Module): """ TODO: handle multi-speaker Shapes: - x: [B, C, T] - mask: [B, 1, T] or [B, T] + - x: :math:`[B, C, T]` + - mask: :math:`[B, 1, T] or [B, T]` """ if mask is not None and mask.ndim == 3: mask = mask.squeeze(1) @@ -65,3 +67,19 @@ class FFTransformerBlock(nn.Module): alignments.append(align.unsqueeze(1)) alignments = torch.cat(alignments, 1) return x + + +class FFTDurationPredictor: + def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None): + self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p) + self.proj = nn.Linear(in_channels, 1) + + def forward(self, x, mask=None, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - mask: :math:`[B, 1, T]` + """ + x = self.fft(x, mask=mask) + x = self.proj(x) + return x From 076d0cb2583393646abe6de384a14ebf05d84d31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 3 Sep 2021 13:28:26 +0000 Subject: [PATCH 38/52] Add tests for certain FastPitch functions --- tests/tts_tests/test_fast_pitch.py | 51 ++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 tests/tts_tests/test_fast_pitch.py diff --git a/tests/tts_tests/test_fast_pitch.py b/tests/tts_tests/test_fast_pitch.py new file mode 100644 index 00000000..ba6b0ce6 --- /dev/null +++ b/tests/tts_tests/test_fast_pitch.py @@ -0,0 +1,51 @@ +import unittest + +import torch as T + +from TTS.tts.layers.losses import L1LossMasked, SSIMLoss +from TTS.tts.layers.tacotron.tacotron import CBHG, Decoder, Encoder, Prenet +from TTS.tts.models.fast_pitch import FastPitch, FastPitchArgs, average_pitch +from TTS.tts.utils.data import sequence_mask + +# pylint: disable=unused-variable + + +class AveragePitchTests(unittest.TestCase): + def test_in_out(self): # pylint: disable=no-self-use + pitch = T.rand(1, 1, 128) + + durations = T.randint(1, 5, (1, 21)) + coeff = 128.0 / durations.sum() + durations = T.round(durations * coeff) + diff = 128.0 - durations.sum() + durations[0, -1] += diff + durations = durations.long() + + pitch_avg = average_pitch(pitch, durations) + + index = 0 + for idx, dur in enumerate(durations[0]): + assert abs(pitch_avg[0, 0, idx] - pitch[0, 0, index : index + dur.item()].mean()) < 1e-5 + index += dur + + +def expand_encoder_outputs_test(): + model = FastPitch(FastPitchArgs(num_chars=10)) + + inputs = T.rand(2, 5, 57) + durations = T.randint(1, 4, (2, 57)) + + x_mask = T.ones(2, 1, 57) + y_mask = T.ones(2, 1, durations.sum(1).max()) + + expanded, attn = model.expand_encoder_outputs(inputs, durations, x_mask, y_mask) + + for b in range(durations.shape[0]): + index = 0 + for idx, dur in enumerate(durations[b]): + diff = ( + expanded[b, :, index : index + dur.item()] + - inputs[b, :, idx].repeat(dur.item()).view(expanded[b, :, index : index + dur.item()].shape) + ).sum() + assert abs(diff) < 1e-6, diff + index += dur From 29248536c90db097716e87a035b1c2dbfcbc5563 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 3 Sep 2021 13:28:46 +0000 Subject: [PATCH 39/52] Update `PositionalEncoding` --- TTS/tts/layers/generic/pos_encoding.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/TTS/tts/layers/generic/pos_encoding.py b/TTS/tts/layers/generic/pos_encoding.py index 46a0b516..913add0d 100644 --- a/TTS/tts/layers/generic/pos_encoding.py +++ b/TTS/tts/layers/generic/pos_encoding.py @@ -7,17 +7,23 @@ from torch import nn class PositionalEncoding(nn.Module): """Sinusoidal positional encoding for non-recurrent neural networks. Implementation based on "Attention Is All You Need" + Args: channels (int): embedding size - dropout (float): dropout parameter + dropout_p (float): dropout rate applied to the output. + max_len (int): maximum sequence length. + use_scale (bool): whether to use a learnable scaling coefficient. """ - def __init__(self, channels, dropout_p=0.0, max_len=5000): + def __init__(self, channels, dropout_p=0.0, max_len=5000, use_scale=False): super().__init__() if channels % 2 != 0: raise ValueError( "Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels) ) + self.use_scale = use_scale + if use_scale: + self.scale = torch.nn.Parameter(torch.ones(1)) pe = torch.zeros(max_len, channels) position = torch.arange(0, max_len).unsqueeze(1) div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels) @@ -49,9 +55,15 @@ class PositionalEncoding(nn.Module): pos_enc = self.pe[:, :, : x.size(2)] * mask else: pos_enc = self.pe[:, :, : x.size(2)] - x = x + pos_enc + if self.use_scale: + x = x + self.scale * pos_enc + else: + x = x + pos_enc else: - x = x + self.pe[:, :, first_idx:last_idx] + if self.use_scale: + x = x + self.scale * self.pe[:, :, first_idx:last_idx] + else: + x = x + self.pe[:, :, first_idx:last_idx] if hasattr(self, "dropout"): x = self.dropout(x) return x From 0b8bc71fc986957837ae6a9e522a0d631a96f13a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 4 Sep 2021 08:36:28 +0000 Subject: [PATCH 40/52] Integrate Scarf pixel --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9b448a75..72cd332a 100644 --- a/README.md +++ b/README.md @@ -150,4 +150,6 @@ If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](ht |- (same) |- vocoder/ (Vocoder models.) |- (same) -``` \ No newline at end of file +``` + + \ No newline at end of file From 6878ddfbea46a439ac025afcfc47d49c2c53c4d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 6 Sep 2021 14:21:47 +0000 Subject: [PATCH 41/52] Update README.md format --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 72cd332a..60e0393f 100644 --- a/README.md +++ b/README.md @@ -4,14 +4,14 @@ 🐸TTS comes with [pretrained models](https://github.com/coqui-ai/TTS/wiki/Released-Models), tools for measuring dataset quality and already used in **20+ languages** for products and research projects. [![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/main.yml/badge.svg)](https://github.com/coqui-ai/TTS/actions) -[![License]()](https://opensource.org/licenses/MPL-2.0) -[![Docs]()](https://tts.readthedocs.io/en/latest/) [![PyPI version](https://badge.fury.io/py/TTS.svg)](https://badge.fury.io/py/TTS) [![Covenant](https://camo.githubusercontent.com/7d620efaa3eac1c5b060ece5d6aacfcc8b81a74a04d05cd0398689c01c4463bb/68747470733a2f2f696d672e736869656c64732e696f2f62616467652f436f6e7472696275746f72253230436f76656e616e742d76322e3025323061646f707465642d6666363962342e737667)](https://github.com/coqui-ai/TTS/blob/master/CODE_OF_CONDUCT.md) [![Downloads](https://pepy.tech/badge/tts)](https://pepy.tech/project/tts) -[![Gitter](https://badges.gitter.im/coqui-ai/TTS.svg)](https://gitter.im/coqui-ai/TTS?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) [![DOI](https://zenodo.org/badge/265612440.svg)](https://zenodo.org/badge/latestdoi/265612440) +[![Docs]()](https://tts.readthedocs.io/en/latest/) +[![Gitter](https://badges.gitter.im/coqui-ai/TTS.svg)](https://gitter.im/coqui-ai/TTS?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) +[![License]()](https://opensource.org/licenses/MPL-2.0) 📰 [**Subscribe to 🐸Coqui.ai Newsletter**](https://coqui.ai/?subscription=true) From 91a70e80b20da7666e28ca14098e255afa32aa5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 6 Sep 2021 14:24:06 +0000 Subject: [PATCH 42/52] Refactor TTSDataset Return a dict by `collate` Refactor batch handling in `collate` A couple of bug fixes --- TTS/bin/extract_tts_spectrograms.py | 19 +++--- TTS/trainer.py | 9 ++- TTS/tts/configs/shared_configs.py | 4 +- TTS/tts/datasets/TTSDataset.py | 93 +++++++++++++++++------------ 4 files changed, 74 insertions(+), 51 deletions(-) diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 6ec99fac..9f54cb39 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -77,14 +77,14 @@ def set_filename(wav_path, out_path): def format_data(data): # setup input data - text_input = data[0] - text_lengths = data[1] - mel_input = data[4] - mel_lengths = data[5] - item_idx = data[7] - d_vectors = data[8] - speaker_ids = data[9] - attn_mask = data[10] + text_input = data['text'] + text_lengths = data['text_lengths'] + mel_input = data['mel'] + mel_lengths = data['mel_lengths'] + item_idx = data['item_idxs'] + d_vectors = data['d_vectors'] + speaker_ids = data['speaker_ids'] + attn_mask = data['attns'] avg_text_length = torch.mean(text_lengths.float()) avg_spec_length = torch.mean(mel_lengths.float()) @@ -132,9 +132,8 @@ def inference( speaker_c = speaker_ids elif d_vectors is not None: speaker_c = d_vectors - outputs = model.inference_with_MAS( - text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c} + text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids} ) model_output = outputs["model_outputs"] model_output = model_output.transpose(1, 2).detach().cpu().numpy() diff --git a/TTS/trainer.py b/TTS/trainer.py index 9bb5b096..bc9a49c6 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -271,8 +271,13 @@ class Trainer: # setup scheduler self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer) - if self.args.continue_path: - self.scheduler.last_epoch = self.restore_step + if self.scheduler is not None: + if self.args.continue_path: + if isinstance(self.scheduler, list): + for scheduler in self.scheduler: + scheduler.last_epoch = self.restore_step + else: + self.scheduler.last_epoch = self.restore_step # DISTRUBUTED if self.num_gpus > 1: diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index 3dc70786..e208c16c 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -142,7 +142,7 @@ class BaseTTSConfig(BaseTrainingConfig): enable / disable masking loss values against padded segments of samples in a batch. sort_by_audio_len (bool): - If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`. + If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`. min_seq_len (int): Minimum sequence length to be used at training. @@ -201,7 +201,7 @@ class BaseTTSConfig(BaseTrainingConfig): batch_group_size: int = 0 loss_masking: bool = None # dataloading - sort_by_audio_len: bool = True + sort_by_audio_len: bool = False min_seq_len: int = 1 max_seq_len: int = float("inf") compute_f0: bool = False diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index 74cb8de1..c81e0e6c 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -9,7 +9,7 @@ import torch import tqdm from torch.utils.data import Dataset -from TTS.tts.utils.data import _pad_data, prepare_data, prepare_stop_target, prepare_tensor +from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence from TTS.utils.audio import AudioProcessor @@ -249,8 +249,8 @@ class TTSDataset(Dataset): pitch = None if self.compute_f0: - pitch = self.pitch_extractor._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path) - pitch = self.pitch_extractor.normalize_pitch(pitch) + pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path) + pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32)) sample = { "raw_text": raw_text, @@ -356,6 +356,11 @@ class TTSDataset(Dataset): temp_items = new_items[offset:end_offset] random.shuffle(temp_items) new_items[offset:end_offset] = temp_items + + if len(new_items) == 0: + raise RuntimeError(" [!] No items left after filtering.") + + # update items to the new sorted items self.items = new_items # logging @@ -376,6 +381,18 @@ class TTSDataset(Dataset): def __getitem__(self, idx): return self.load_data(idx) + @staticmethod + def _sort_batch(batch, text_lengths): + """Sort the batch by the input text length for RNN efficiency. + + Args: + batch (Dict): Batch returned by `__getitem__`. + text_lengths (List[int]): Lengths of the input character sequences. + """ + text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True) + batch = [batch[idx] for idx in ids_sorted_decreasing] + return batch, text_lengths, ids_sorted_decreasing + def collate_fn(self, batch): r""" Perform preprocessing and create a final data batch: @@ -388,30 +405,27 @@ class TTSDataset(Dataset): # Puts each data field into a tensor with outer dimension batch size if isinstance(batch[0], collections.abc.Mapping): - text_lenghts = np.array([len(d["text"]) for d in batch]) + text_lengths = np.array([len(d["text"]) for d in batch]) # sort items with text input length for RNN efficiency - text_lenghts, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lenghts), dim=0, descending=True) + batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths) - wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing] - item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing] - text = [batch[idx]["text"] for idx in ids_sorted_decreasing] - raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing] + # convert list of dicts to dict of lists + batch = {k: [dic[k] for dic in batch] for k in batch[0]} - speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing] # get pre-computed d-vectors if self.d_vector_mapping is not None: - wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing] + wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing] d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names] else: d_vectors = None # get numerical speaker ids from speaker names if self.speaker_id_mapping: - speaker_ids = [self.speaker_id_mapping[sn] for sn in speaker_names] + speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]] else: speaker_ids = None # compute features - mel = [self.ap.melspectrogram(w).astype("float32") for w in wav] + mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]] mel_lengths = [m.shape[1] for m in mel] @@ -430,7 +444,7 @@ class TTSDataset(Dataset): stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step) # PAD sequences with longest instance in the batch - text = prepare_data(text).astype(np.int32) + text = prepare_data(batch["text"]).astype(np.int32) # PAD features with longest instance mel = prepare_tensor(mel, self.outputs_per_step) @@ -439,7 +453,7 @@ class TTSDataset(Dataset): mel = mel.transpose(0, 2, 1) # convert things to pytorch - text_lenghts = torch.LongTensor(text_lenghts) + text_lengths = torch.LongTensor(text_lengths) text = torch.LongTensor(text) mel = torch.FloatTensor(mel).contiguous() mel_lengths = torch.LongTensor(mel_lengths) @@ -453,7 +467,7 @@ class TTSDataset(Dataset): # compute linear spectrogram if self.compute_linear_spec: - linear = [self.ap.spectrogram(w).astype("float32") for w in wav] + linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]] linear = prepare_tensor(linear, self.outputs_per_step) linear = linear.transpose(0, 2, 1) assert mel.shape[1] == linear.shape[1] @@ -464,11 +478,11 @@ class TTSDataset(Dataset): # format waveforms wav_padded = None if self.return_wav: - wav_lengths = [w.shape[0] for w in wav] + wav_lengths = [w.shape[0] for w in batch["wav"]] max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length wav_lengths = torch.LongTensor(wav_lengths) - wav_padded = torch.zeros(len(batch), 1, max_wav_len) - for i, w in enumerate(wav): + wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len) + for i, w in enumerate(batch["wav"]): mel_length = mel_lengths_adjusted[i] w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge") w = w[: mel_length * self.ap.hop_length] @@ -477,18 +491,16 @@ class TTSDataset(Dataset): # compute f0 # TODO: compare perf in collate_fn vs in load_data - pitch = None if self.compute_f0: - pitch = [b["pitch"] for b in batch] - pitch = prepare_data(pitch) + pitch = prepare_data(batch["pitch"]) assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}" pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT else: pitch = None # collate attention alignments - if batch[0]["attn"] is not None: - attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing] + if batch["attn"][0] is not None: + attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing] for idx, attn in enumerate(attns): pad2 = mel.shape[1] - attn.shape[1] pad1 = text.shape[1] - attn.shape[0] @@ -502,18 +514,18 @@ class TTSDataset(Dataset): # TODO: return dictionary return { "text": text, - "text_lengths": text_lenghts, - "speaker_names": speaker_names, + "text_lengths": text_lengths, + "speaker_names": batch["speaker_name"], "linear": linear, "mel": mel, "mel_lengths": mel_lengths, "stop_targets": stop_targets, - "item_idxs": item_idxs, + "item_idxs": batch["item_idx"], "d_vectors": d_vectors, "speaker_ids": speaker_ids, "attns": attns, "waveform": wav_padded, - "raw_text": raw_text, + "raw_text": batch["raw_text"], "pitch": pitch, } @@ -567,13 +579,20 @@ class PitchExtractor: def normalize_pitch(self, pitch): zero_idxs = np.where(pitch == 0.0)[0] - pitch -= self.mean - pitch /= self.std + pitch = pitch - self.mean + pitch = pitch / self.std + pitch[zero_idxs] = 0.0 + return pitch + + def denormalize_pitch(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch *= self.std + pitch += self.mean pitch[zero_idxs] = 0.0 return pitch @staticmethod - def _load_or_compute_pitch(ap, wav_file, cache_path): + def load_or_compute_pitch(ap, wav_file, cache_path): """ compute pitch and return a numpy array of pitch values """ @@ -582,7 +601,7 @@ class PitchExtractor: pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file) else: pitch = np.load(pitch_file) - return pitch + return pitch.astype(np.float32) @staticmethod def _pitch_worker(args): @@ -596,7 +615,7 @@ class PitchExtractor: return pitch return None - def compute_pitch(self, cache_path, num_workers=0): + def compute_pitch(self, ap, cache_path, num_workers=0): """Compute the input sequences with multi-processing. Call it before passing dataset to the data loader to cache the input sequences for faster data loading.""" if not os.path.exists(cache_path): @@ -607,12 +626,12 @@ class PitchExtractor: if num_workers == 0: pitch_vecs = [] for _, item in enumerate(tqdm.tqdm(self.items)): - pitch_vecs += [self._pitch_worker([item, self.ap, cache_path])] + pitch_vecs += [self._pitch_worker([item, ap, cache_path])] else: with Pool(num_workers) as p: pitch_vecs = list( tqdm.tqdm( - p.imap(PitchExtractor._pitch_worker, [[item, self.ap, cache_path] for item in self.items]), + p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]), total=len(self.items), ) ) @@ -623,5 +642,5 @@ class PitchExtractor: def load_pitch_stats(self, cache_path): stats_path = os.path.join(cache_path, "pitch_stats.npy") stats = np.load(stats_path, allow_pickle=True).item() - self.mean = stats["mean"] - self.std = stats["std"] + self.mean = stats["mean"].astype(np.float32) + self.std = stats["std"].astype(np.float32) From 76c4929ab2b9ed009e7babf1f692a47adf95d32f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 6 Sep 2021 14:25:07 +0000 Subject: [PATCH 43/52] Fix attn mask reading bug --- TTS/tts/datasets/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index 2e315963..c2e55038 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -66,11 +66,11 @@ def load_meta_data(datasets: List[Dict], eval_split=True) -> Tuple[List[List], L if dataset.meta_file_attn_mask: meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) for idx, ins in enumerate(meta_data_train_all): - attn_file = meta_data[os.path.abspath(ins[1])].strip() + attn_file = meta_data[ins[1]].strip() meta_data_train_all[idx].append(attn_file) if meta_data_eval_all: for idx, ins in enumerate(meta_data_eval_all): - attn_file = meta_data[os.path.abspath(ins[1])].strip() + attn_file = meta_data[ins[1]].strip() meta_data_eval_all[idx].append(attn_file) return meta_data_train_all, meta_data_eval_all From 2b59da802c61777221f93a0c7be4b4dc0cb28b1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 6 Sep 2021 14:25:45 +0000 Subject: [PATCH 44/52] Fix loader setup in `base_tts` --- TTS/tts/layers/generic/transformer.py | 6 ++++-- TTS/tts/layers/losses.py | 3 ++- TTS/tts/models/base_tts.py | 9 +++++++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/TTS/tts/layers/generic/transformer.py b/TTS/tts/layers/generic/transformer.py index 12f0bbb0..2fe9bcc4 100644 --- a/TTS/tts/layers/generic/transformer.py +++ b/TTS/tts/layers/generic/transformer.py @@ -70,15 +70,17 @@ class FFTransformerBlock(nn.Module): class FFTDurationPredictor: - def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None): + def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None): # pylint: disable=unused-argument self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p) self.proj = nn.Linear(in_channels, 1) - def forward(self, x, mask=None, g=None): + def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument """ Shapes: - x: :math:`[B, C, T]` - mask: :math:`[B, 1, T]` + + TODO: Handle the cond input """ x = self.fft(x, mask=mask) x = self.proj(x) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 100b8fb3..a2fd7635 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -707,7 +707,8 @@ class FastPitchLoss(nn.Module): self.aligner_loss_alpha = c.aligner_loss_alpha self.binary_alignment_loss_alpha = c.binary_align_loss_alpha - def _binary_alignment_loss(self, alignment_hard, alignment_soft): + @staticmethod + def _binary_alignment_loss(alignment_hard, alignment_soft): """Binary loss that forces soft alignments to match the hard alignments as explained in `https://arxiv.org/pdf/2108.10447.pdf`. """ diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 653143cd..06c7cb2b 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -252,13 +252,18 @@ class BaseTTS(BaseModel): # compute pitch frames and write to files. if config.compute_f0 and rank in [None, 0]: if not os.path.exists(config.f0_cache_path): - dataset.pitch_extractor.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers) - dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None)) + dataset.pitch_extractor.compute_pitch( + ap, config.get("f0_cache_path", None), config.num_loader_workers + ) # halt DDP processes for the main process to finish computing the F0 cache if num_gpus > 1: dist.barrier() + # load pitch stats computed above by all the workers + if config.compute_f0: + dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None)) + # sampler for DDP sampler = DistributedSampler(dataset) if num_gpus > 1 else None From 8d41060d36d76f045568a2b5f0129d80dcd2a25d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 6 Sep 2021 14:26:36 +0000 Subject: [PATCH 45/52] Plot unnormalized pitch by `FastPitch` --- TTS/tts/models/fast_pitch.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py index 352aebfa..1dd0bd68 100644 --- a/TTS/tts/models/fast_pitch.py +++ b/TTS/tts/models/fast_pitch.py @@ -14,9 +14,8 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask -from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram from TTS.utils.audio import AudioProcessor -from TTS.utils.soft_dtw import SoftDTW @dataclass @@ -232,7 +231,9 @@ class FastPitch(BaseTTS): self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) if self.args.use_aligner: - self.aligner = AlignmentNetwork(in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels) + self.aligner = AlignmentNetwork( + in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels + ) @staticmethod def generate_attn(dr, x_mask, y_mask=None): @@ -307,7 +308,7 @@ class FastPitch(BaseTTS): return x + g def _forward_encoder( - self, x: torch.LongTensor, x_mask:torch.FloatTensor, g: torch.FloatTensor = None + self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: """Encoding forward pass. @@ -522,13 +523,15 @@ class FastPitch(BaseTTS): "durations_log": o_dr_log.squeeze(1), "durations": o_dr.squeeze(1), "attn_durations": o_attn, # for visualization - "pitch": o_pitch, - "pitch_gt": avg_pitch, + "pitch_avg": o_pitch, + "pitch_avg_gt": avg_pitch, "alignments": attn, "alignment_soft": alignment_soft.transpose(1, 2), "alignment_mas": alignment_mas.transpose(1, 2), "o_alignment_dur": o_alignment_dur, "alignment_logprob": alignment_logprob, + "x_mask": x_mask, + "y_mask": y_mask, } return outputs @@ -577,6 +580,7 @@ class FastPitch(BaseTTS): speaker_ids = batch["speaker_ids"] durations = batch["durations"] aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} + # forward pass outputs = self.forward( text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input @@ -593,12 +597,12 @@ class FastPitch(BaseTTS): decoder_output_lens=mel_lengths, dur_output=outputs["durations_log"], dur_target=durations, - pitch_output=outputs["pitch"], - pitch_target=outputs["pitch_gt"], + pitch_output=outputs["pitch_avg"], + pitch_target=outputs["pitch_avg_gt"], input_lens=text_lengths, alignment_logprob=outputs["alignment_logprob"], alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None, - alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None + alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None, ) # compute duration error durations_pred = outputs["durations"] @@ -611,15 +615,26 @@ class FastPitch(BaseTTS): model_outputs = outputs["model_outputs"] alignments = outputs["alignments"] mel_input = batch["mel_input"] + pitch = batch["pitch"] + pitch_avg_expanded, _ = self.expand_encoder_outputs( + outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"] + ) pred_spec = model_outputs[0].data.cpu().numpy() gt_spec = mel_input[0].data.cpu().numpy() align_img = alignments[0].data.cpu().numpy() + pitch = pitch[0, 0].data.cpu().numpy() + + # TODO: denormalize before plotting + pitch = abs(pitch) + pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy() figures = { "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), "alignment": plot_alignment(align_img, output_fig=False), + "pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False), + "pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False), } # plot the attention mask computed from the predicted durations From d847a68e42536805f3f301555201c001fcf8f055 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 6 Sep 2021 14:27:13 +0000 Subject: [PATCH 46/52] Reformat multi-speaker handling in GlowTTS --- TTS/tts/models/glow_tts.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 27012207..b063b6b4 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -109,6 +109,10 @@ class GlowTTS(BaseTTS): # init speaker manager self.speaker_manager = get_speaker_manager(config, data=data) self.num_speakers = self.speaker_manager.num_speakers + if config.use_d_vector_file: + self.external_d_vector_dim = config.d_vector_dim + else: + self.external_d_vector_dim = 0 # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: self.embedded_speaker_dim = self.c_in_channels @@ -129,7 +133,7 @@ class GlowTTS(BaseTTS): return y_mean, y_log_scale, o_attn_dur def forward( - self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None} + self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None} ): # pylint: disable=dangerous-default-value """ Shapes: @@ -143,8 +147,8 @@ class GlowTTS(BaseTTS): y_max_length = y.size(2) # norm speaker embeddings g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None - if g is not None: - if self.d_vector_dim: + if self.use_speaker_embedding or self.use_d_vector_file: + if not self.use_d_vector_file: g = F.normalize(g).unsqueeze(-1) else: g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] @@ -181,7 +185,7 @@ class GlowTTS(BaseTTS): @torch.no_grad() def inference_with_MAS( - self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None} + self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None} ): # pylint: disable=dangerous-default-value """ It's similar to the teacher forcing in Tacotron. @@ -198,12 +202,11 @@ class GlowTTS(BaseTTS): y_max_length = y.size(2) # norm speaker embeddings g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None - if g is not None: - if self.external_d_vector_dim: + if self.use_speaker_embedding or self.use_d_vector_file: + if not self.use_d_vector_file: g = F.normalize(g).unsqueeze(-1) else: g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] - # embedding pass o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) # drop redisual frames wrt num_squeeze and set y_lengths. @@ -243,7 +246,7 @@ class GlowTTS(BaseTTS): @torch.no_grad() def decoder_inference( - self, y, y_lengths=None, aux_input={"d_vectors": None} + self, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None} ): # pylint: disable=dangerous-default-value """ Shapes: @@ -275,7 +278,7 @@ class GlowTTS(BaseTTS): return outputs @torch.no_grad() - def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None}): # pylint: disable=dangerous-default-value + def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids":None}): # pylint: disable=dangerous-default-value x_lengths = aux_input["x_lengths"] g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None @@ -326,8 +329,9 @@ class GlowTTS(BaseTTS): mel_input = batch["mel_input"] mel_lengths = batch["mel_lengths"] d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] - outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors}) + outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors, "speaker_ids":speaker_ids}) loss_dict = criterion( outputs["model_outputs"], From c1513ec4cdb7dd6626c0a798b50dbf30cd85c9b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 6 Sep 2021 14:27:40 +0000 Subject: [PATCH 47/52] Plot pitch over spectrogram --- TTS/tts/models/vits.py | 6 +++--- TTS/tts/utils/visual.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 72c67df2..0da43f90 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -41,9 +41,9 @@ def rand_segment(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4 x_lengths = T max_idxs = x_lengths - segment_size + 1 assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size." - ids_str = (torch.rand([B]).type_as(x) * max_idxs).long() - ret = segment(x, ids_str, segment_size) - return ret, ids_str + segment_indices = (torch.rand([B]).type_as(x) * max_idxs).long() + ret = segment(x, segment_indices, segment_size) + return ret, segment_indices @dataclass diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py index 44732322..7101ed3d 100644 --- a/TTS/tts/utils/visual.py +++ b/TTS/tts/utils/visual.py @@ -49,6 +49,46 @@ def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False): return fig +def plot_pitch(pitch, spectrogram, ap=None, fig_size=(30, 10), output_fig=False): + """Plot pitch curves on top of the spectrogram. + + Args: + pitch (np.array): Pitch values. + spectrogram (np.array): Spectrogram values. + + Shapes: + pitch: :math:`(T,)` + spec: :math:`(C, T)` + """ + + if isinstance(spectrogram, torch.Tensor): + spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T + else: + spectrogram_ = spectrogram.T + spectrogram_ = spectrogram_.astype(np.float32) if spectrogram_.dtype == np.float16 else spectrogram_ + if ap is not None: + spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access + + old_fig_size = plt.rcParams["figure.figsize"] + if fig_size is not None: + plt.rcParams["figure.figsize"] = fig_size + + fig, ax = plt.subplots() + + ax.imshow(spectrogram_, aspect="auto", origin="lower") + ax.set_xlabel("time") + ax.set_ylabel("spec_freq") + + ax2 = ax.twinx() + ax2.plot(pitch, linewidth=5.0, color="red") + ax2.set_ylabel("F0") + + plt.rcParams["figure.figsize"] = old_fig_size + if not output_fig: + plt.close() + return fig + + def visualize( alignment, postnet_output, From 2c4bbbf9b9a25dc952eeebed9ebfced86429c992 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 6 Sep 2021 14:29:22 +0000 Subject: [PATCH 48/52] Use pyworld for pitch --- TTS/utils/audio.py | 55 ++------ TTS/utils/yin.py | 118 ------------------ TTS/vocoder/models/__init__.py | 2 +- .../ljspeech/fast_pitch/train_fast_pitch.py | 1 + requirements.txt | 1 + tests/__init__.py | 5 +- 6 files changed, 19 insertions(+), 163 deletions(-) delete mode 100644 TTS/utils/yin.py diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index 6a74b3c8..01d1f7d1 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -2,6 +2,7 @@ from typing import Dict, Tuple import librosa import numpy as np +import pyworld as pw import scipy.io.wavfile import scipy.signal import soundfile as sf @@ -9,7 +10,6 @@ import torch from torch import nn from TTS.tts.utils.data import StandardScaler -from TTS.utils.yin import compute_yin class TorchSTFT(nn.Module): # pylint: disable=abstract-method @@ -640,59 +640,28 @@ class AudioProcessor(object): >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050] >>> pitch = ap.compute_f0(wav) """ - # f0, t = pw.dio( - # x.astype(np.double), - # fs=self.sample_rate, - # f0_ceil=self.mel_fmax, - # frame_period=1000 * self.hop_length / self.sample_rate, - # ) - # f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate) - # f0, _, _, _ = compute_yin( - # x, - # self.sample_rate, - # self.win_length, - # self.hop_length, - # 65 if self.mel_fmin == 0 else self.mel_fmin, - # self.mel_fmax, - # ) - # # import pyworld as pw - # # f0, _ = pw.dio(x.astype(np.float64), self.sample_rate, - # # frame_period=self.hop_length / self.sample_rate * 1000) + f0, t = pw.dio( + x.astype(np.double), + fs=self.sample_rate, + f0_ceil=self.mel_fmax, + frame_period=1000 * self.hop_length / self.sample_rate, + ) + f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate) # pad = int((self.win_length / self.hop_length) / 2) # f0 = [0.0] * pad + f0 + [0.0] * pad + # f0 = np.pad(f0, (pad, pad), mode="constant", constant_values=0) # f0 = np.array(f0, dtype=np.float32) - f0, _, _ = librosa.pyin( - x, - fmin=65 if self.mel_fmin == 0 else self.mel_fmin, - fmax=self.mel_fmax, - frame_length=self.win_length, - sr=self.sample_rate, - fill_na=0.0, - ) - - # f02 = librosa.yin( + # f01, _, _ = librosa.pyin( # x, # fmin=65 if self.mel_fmin == 0 else self.mel_fmin, # fmax=self.mel_fmax, # frame_length=self.win_length, - # sr=self.sample_rate + # sr=self.sample_rate, + # fill_na=0.0, # ) # spec = self.melspectrogram(x) - - # from matplotlib import pyplot as plt - # plt.figure() - # plt.plot(f0, linewidth=2.5, color='red') - # plt.plot(f01, linewidth=2.5, linestyle='-.') - # plt.plot(f02, linewidth=2.5) - # plt.xlabel('time') - # plt.ylabel('F0') - # plt.savefig('save_img.png') - - # # plt.figure() - # plt.imshow(spec, aspect="auto", origin="lower") - # plt.savefig('save_img2.png') return f0 ### Audio Processing ### diff --git a/TTS/utils/yin.py b/TTS/utils/yin.py deleted file mode 100644 index 3d8bf64b..00000000 --- a/TTS/utils/yin.py +++ /dev/null @@ -1,118 +0,0 @@ -# adapted from https://github.com/patriceguyot/Yin - -import numpy as np - - -def differenceFunction(x, N, tau_max): - """ - Compute difference function of data x. This corresponds to equation (6) in [1] - This solution is implemented directly with Numpy fft. - - - :param x: audio data - :param N: length of data - :param tau_max: integration window size - :return: difference function - :rtype: list - """ - - x = np.array(x, np.float64) - w = x.size - tau_max = min(tau_max, w) - x_cumsum = np.concatenate((np.array([0.0]), (x * x).cumsum())) - size = w + tau_max - p2 = (size // 32).bit_length() - nice_numbers = (16, 18, 20, 24, 25, 27, 30, 32) - size_pad = min(x * 2 ** p2 for x in nice_numbers if x * 2 ** p2 >= size) - fc = np.fft.rfft(x, size_pad) - conv = np.fft.irfft(fc * fc.conjugate())[:tau_max] - return x_cumsum[w : w - tau_max : -1] + x_cumsum[w] - x_cumsum[:tau_max] - 2 * conv - - -def cumulativeMeanNormalizedDifferenceFunction(df, N): - """ - Compute cumulative mean normalized difference function (CMND). - - This corresponds to equation (8) in [1] - - :param df: Difference function - :param N: length of data - :return: cumulative mean normalized difference function - :rtype: list - """ - - cmndf = df[1:] * range(1, N) / np.cumsum(df[1:]).astype(float) # scipy method - return np.insert(cmndf, 0, 1) - - -def getPitch(cmdf, tau_min, tau_max, harmo_th=0.1): - """ - Return fundamental period of a frame based on CMND function. - - :param cmdf: Cumulative Mean Normalized Difference function - :param tau_min: minimum period for speech - :param tau_max: maximum period for speech - :param harmo_th: harmonicity threshold to determine if it is necessary to compute pitch frequency - :return: fundamental period if there is values under threshold, 0 otherwise - :rtype: float - """ - tau = tau_min - while tau < tau_max: - if cmdf[tau] < harmo_th: - while tau + 1 < tau_max and cmdf[tau + 1] < cmdf[tau]: - tau += 1 - return tau - tau += 1 - - return 0 # if unvoiced - - -def compute_yin(sig, sr, w_len=512, w_step=256, f0_min=100, f0_max=500, harmo_thresh=0.1): - """ - - Compute the Yin Algorithm. Return fundamental frequency and harmonic rate. - - :param sig: Audio signal (list of float) - :param sr: sampling rate (int) - :param w_len: size of the analysis window (samples) - :param w_step: size of the lag between two consecutives windows (samples) - :param f0_min: Minimum fundamental frequency that can be detected (hertz) - :param f0_max: Maximum fundamental frequency that can be detected (hertz) - :param harmo_tresh: Threshold of detection. The yalgorithmù return the first minimum of the CMND function below this treshold. - - :returns: - - * pitches: list of fundamental frequencies, - * harmonic_rates: list of harmonic rate values for each fundamental frequency value (= confidence value) - * argmins: minimums of the Cumulative Mean Normalized DifferenceFunction - * times: list of time of each estimation - :rtype: tuple - """ - - tau_min = int(sr / f0_max) - tau_max = int(sr / f0_min) - - timeScale = range(0, len(sig) - w_len, w_step) # time values for each analysis window - times = [t / float(sr) for t in timeScale] - frames = [sig[t : t + w_len] for t in timeScale] - - pitches = [0.0] * len(timeScale) - harmonic_rates = [0.0] * len(timeScale) - argmins = [0.0] * len(timeScale) - - for i, frame in enumerate(frames): - # Compute YIN - df = differenceFunction(frame, w_len, tau_max) - cmdf = cumulativeMeanNormalizedDifferenceFunction(df, tau_max) - p = getPitch(cmdf, tau_min, tau_max, harmo_thresh) - - # Get results - if np.argmin(cmdf) > tau_min: - argmins[i] = float(sr / np.argmin(cmdf)) - if p != 0: # A pitch was found - pitches[i] = float(sr / p) - harmonic_rates[i] = cmdf[p] - else: # No pitch, but we compute a value of the harmonic rate - harmonic_rates[i] = min(cmdf) - - return pitches, harmonic_rates, argmins, times diff --git a/TTS/vocoder/models/__init__.py b/TTS/vocoder/models/__init__.py index edc94d72..a70ebe40 100644 --- a/TTS/vocoder/models/__init__.py +++ b/TTS/vocoder/models/__init__.py @@ -11,7 +11,6 @@ def to_camel(text): def setup_model(config: Coqpit): """Load models directly from configuration.""" - print(" > Vocoder Model: {}".format(config.model)) if "discriminator_model" in config and "generator_model" in config: MyModel = importlib.import_module("TTS.vocoder.models.gan") MyModel = getattr(MyModel, "GAN") @@ -28,6 +27,7 @@ def setup_model(config: Coqpit): MyModel = getattr(MyModel, to_camel(config.model)) except ModuleNotFoundError as e: raise ValueError(f"Model {config.model} not exist!") from e + print(" > Vocoder Model: {}".format(config.model)) model = MyModel(config) return model diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index 5c9e67da..614e42e0 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -43,6 +43,7 @@ config = FastPitchConfig( epochs=1000, text_cleaner="english_cleaners", use_phonemes=True, + use_espeak_phonemes=False, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), print_step=50, diff --git a/requirements.txt b/requirements.txt index b92947a0..a87a3c6f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,3 +25,4 @@ unidic-lite==1.0.8 # gruut+supported langs gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0 fsspec>=2021.04.0 +pyworld \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py index a7878132..2b07004f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -7,7 +7,10 @@ from TTS.utils.generic_utils import get_cuda def get_device_id(): use_cuda, _ = get_cuda() if use_cuda: - GPU_ID = "0" + if 'CUDA_VISIBLE_DEVICES' in os.environ and os.environ['CUDA_VISIBLE_DEVICES'] != "": + GPU_ID = os.environ['CUDA_VISIBLE_DEVICES'].split(',')[0] + else: + GPU_ID = "0" else: GPU_ID = "" return GPU_ID From fd287aa4389d4a3e14f01e924914f2a3d4bc0208 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 6 Sep 2021 14:29:45 +0000 Subject: [PATCH 49/52] Update loader tests for dict return --- tests/data_tests/test_loader.py | 66 ++++++++++++++++----------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 717b2e0f..0fbb6bde 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -68,15 +68,15 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data[0] - text_lengths = data[1] - speaker_name = data[2] - linear_input = data[3] - mel_input = data[4] - mel_lengths = data[5] - stop_target = data[6] - item_idx = data[7] - wavs = data[11] + text_input = data['text'] + text_lengths = data['text_lengths'] + speaker_name = data['speaker_names'] + linear_input = data['linear'] + mel_input = data['mel'] + mel_lengths = data['mel_lengths'] + stop_target = data['stop_targets'] + item_idx = data['item_idxs'] + wavs = data['waveform'] neg_values = text_input[text_input < 0] check_count = len(neg_values) @@ -113,14 +113,14 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data[0] - text_lengths = data[1] - speaker_name = data[2] - linear_input = data[3] - mel_input = data[4] - mel_lengths = data[5] - stop_target = data[6] - item_idx = data[7] + text_input = data['text'] + text_lengths = data['text_lengths'] + speaker_name = data['speaker_names'] + linear_input = data['linear'] + mel_input = data['mel'] + mel_lengths = data['mel_lengths'] + stop_target = data['stop_targets'] + item_idx = data['item_idxs'] avg_length = mel_lengths.numpy().mean() assert avg_length >= last_length @@ -139,14 +139,14 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data[0] - text_lengths = data[1] - speaker_name = data[2] - linear_input = data[3] - mel_input = data[4] - mel_lengths = data[5] - stop_target = data[6] - item_idx = data[7] + text_input = data['text'] + text_lengths = data['text_lengths'] + speaker_name = data['speaker_names'] + linear_input = data['linear'] + mel_input = data['mel'] + mel_lengths = data['mel_lengths'] + stop_target = data['stop_targets'] + item_idx = data['item_idxs'] # check mel_spec consistency wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32) @@ -188,14 +188,14 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data[0] - text_lengths = data[1] - speaker_name = data[2] - linear_input = data[3] - mel_input = data[4] - mel_lengths = data[5] - stop_target = data[6] - item_idx = data[7] + text_input = data['text'] + text_lengths = data['text_lengths'] + speaker_name = data['speaker_names'] + linear_input = data['linear'] + mel_input = data['mel'] + mel_lengths = data['mel_lengths'] + stop_target = data['stop_targets'] + item_idx = data['item_idxs'] if mel_lengths[0] > mel_lengths[1]: idx = 0 From e72c265cd4233caf412e88b07c9cec109fb3c553 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 6 Sep 2021 14:30:15 +0000 Subject: [PATCH 50/52] Fix linter issues --- tests/test_audio_processor.py | 2 +- tests/tts_tests/test_fast_pitch.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_audio_processor.py b/tests/test_audio_processor.py index d3414286..56611692 100644 --- a/tests/test_audio_processor.py +++ b/tests/test_audio_processor.py @@ -182,7 +182,7 @@ class TestAudio(unittest.TestCase): mel_denorm = ap.denormalize(mel_norm) assert abs(mel_reference - mel_denorm).max() < 1e-4 - def test_compute_f0(self): + def test_compute_f0(self): # pylint: disable=no-self-use ap = AudioProcessor(**conf) wav = ap.load_wav(WAV_FILE) pitch = ap.compute_f0(wav) diff --git a/tests/tts_tests/test_fast_pitch.py b/tests/tts_tests/test_fast_pitch.py index ba6b0ce6..1975435e 100644 --- a/tests/tts_tests/test_fast_pitch.py +++ b/tests/tts_tests/test_fast_pitch.py @@ -2,11 +2,7 @@ import unittest import torch as T -from TTS.tts.layers.losses import L1LossMasked, SSIMLoss -from TTS.tts.layers.tacotron.tacotron import CBHG, Decoder, Encoder, Prenet from TTS.tts.models.fast_pitch import FastPitch, FastPitchArgs, average_pitch -from TTS.tts.utils.data import sequence_mask - # pylint: disable=unused-variable @@ -38,7 +34,7 @@ def expand_encoder_outputs_test(): x_mask = T.ones(2, 1, 57) y_mask = T.ones(2, 1, durations.sum(1).max()) - expanded, attn = model.expand_encoder_outputs(inputs, durations, x_mask, y_mask) + expanded, _ = model.expand_encoder_outputs(inputs, durations, x_mask, y_mask) for b in range(durations.shape[0]): index = 0 From 4cc544bc461344e36254c8aeae003bee195f6099 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 6 Sep 2021 16:59:22 +0000 Subject: [PATCH 51/52] Add FastPitch model to `.models.json` --- TTS/.models.json | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/TTS/.models.json b/TTS/.models.json index d3c56b94..6f763840 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -62,7 +62,16 @@ "default_vocoder": null, "commit": "3900448", "author": "Eren Gölge @erogol", - "license": "", + "license": "TBD", + "contact": "egolge@coqui.com" + }, + "fast_pitch": { + "description": "FastPitch model trained on LJSpeech using the Aligner Network", + "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.2/tts_models--en--ljspeech--fast_pitch.zip", + "default_vocoder": "vocoder_models/en/ljspeech/hifigan_v2", + "commit": "b27b3ba", + "author": "Eren Gölge @erogol", + "license": "TBD", "contact": "egolge@coqui.com" } }, From 82598f3fdbb73893e80904cbc23a0cef0937d4d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 6 Sep 2021 16:59:41 +0000 Subject: [PATCH 52/52] Bump up to v0.2.2 --- TTS/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/VERSION b/TTS/VERSION index 7dff5b89..f4778493 100644 --- a/TTS/VERSION +++ b/TTS/VERSION @@ -1 +1 @@ -0.2.1 \ No newline at end of file +0.2.2 \ No newline at end of file