diff --git a/README.md b/README.md index 9b448a75..60e0393f 100644 --- a/README.md +++ b/README.md @@ -4,14 +4,14 @@ 🐸TTS comes with [pretrained models](https://github.com/coqui-ai/TTS/wiki/Released-Models), tools for measuring dataset quality and already used in **20+ languages** for products and research projects. [![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/main.yml/badge.svg)](https://github.com/coqui-ai/TTS/actions) -[![License]()](https://opensource.org/licenses/MPL-2.0) -[![Docs]()](https://tts.readthedocs.io/en/latest/) [![PyPI version](https://badge.fury.io/py/TTS.svg)](https://badge.fury.io/py/TTS) [![Covenant](https://camo.githubusercontent.com/7d620efaa3eac1c5b060ece5d6aacfcc8b81a74a04d05cd0398689c01c4463bb/68747470733a2f2f696d672e736869656c64732e696f2f62616467652f436f6e7472696275746f72253230436f76656e616e742d76322e3025323061646f707465642d6666363962342e737667)](https://github.com/coqui-ai/TTS/blob/master/CODE_OF_CONDUCT.md) [![Downloads](https://pepy.tech/badge/tts)](https://pepy.tech/project/tts) -[![Gitter](https://badges.gitter.im/coqui-ai/TTS.svg)](https://gitter.im/coqui-ai/TTS?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) [![DOI](https://zenodo.org/badge/265612440.svg)](https://zenodo.org/badge/latestdoi/265612440) +[![Docs]()](https://tts.readthedocs.io/en/latest/) +[![Gitter](https://badges.gitter.im/coqui-ai/TTS.svg)](https://gitter.im/coqui-ai/TTS?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) +[![License]()](https://opensource.org/licenses/MPL-2.0) 📰 [**Subscribe to 🐸Coqui.ai Newsletter**](https://coqui.ai/?subscription=true) @@ -150,4 +150,6 @@ If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](ht |- (same) |- vocoder/ (Vocoder models.) |- (same) -``` \ No newline at end of file +``` + + \ No newline at end of file diff --git a/TTS/.models.json b/TTS/.models.json index d3c56b94..6f763840 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -62,7 +62,16 @@ "default_vocoder": null, "commit": "3900448", "author": "Eren Gölge @erogol", - "license": "", + "license": "TBD", + "contact": "egolge@coqui.com" + }, + "fast_pitch": { + "description": "FastPitch model trained on LJSpeech using the Aligner Network", + "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.2/tts_models--en--ljspeech--fast_pitch.zip", + "default_vocoder": "vocoder_models/en/ljspeech/hifigan_v2", + "commit": "b27b3ba", + "author": "Eren Gölge @erogol", + "license": "TBD", "contact": "egolge@coqui.com" } }, diff --git a/TTS/VERSION b/TTS/VERSION index 7dff5b89..f4778493 100644 --- a/TTS/VERSION +++ b/TTS/VERSION @@ -1 +1 @@ -0.2.1 \ No newline at end of file +0.2.2 \ No newline at end of file diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index 3a5c067e..fc8c6629 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -8,12 +8,12 @@ import torch from torch.utils.data import DataLoader from tqdm import tqdm +from TTS.config import load_config from TTS.tts.datasets.TTSDataset import TTSDataset from TTS.tts.models import setup_model -from TTS.tts.utils.io import load_checkpoint from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.utils.audio import AudioProcessor -from TTS.utils.io import load_config +from TTS.utils.io import load_checkpoint if __name__ == "__main__": # pylint: disable=bad-option-value @@ -27,7 +27,7 @@ Example run: CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json - --dataset_metafile /root/LJSpeech-1.1/metadata.csv + --dataset_metafile metadata.csv --data_path /root/LJSpeech-1.1/ --batch_size 32 --dataset ljspeech @@ -76,8 +76,7 @@ Example run: num_chars = len(phonemes) if C.use_phonemes else len(symbols) # TODO: handle multi-speaker model = setup_model(C) - model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda) - model.eval() + model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True) # data loader preprocessor = importlib.import_module("TTS.tts.datasets.formatters") @@ -127,9 +126,9 @@ Example run: mel_input = mel_input.cuda() mel_lengths = mel_lengths.cuda() - mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input) + model_outputs = model.forward(text_input, text_lengths, mel_input) - alignments = alignments.detach() + alignments = model_outputs["alignments"].detach() for idx, alignment in enumerate(alignments): item_idx = item_idxs[idx] # interpolate if r > 1 @@ -149,10 +148,12 @@ Example run: alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy() # set file paths wav_file_name = os.path.basename(item_idx) - align_file_name = os.path.splitext(wav_file_name)[0] + ".npy" + align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy" file_path = item_idx.replace(wav_file_name, align_file_name) # save output - file_paths.append([item_idx, file_path]) + wav_file_abs_path = os.path.abspath(item_idx) + file_abs_path = os.path.abspath(file_path) + file_paths.append([wav_file_abs_path, file_abs_path]) np.save(file_path, alignment) # ourput metafile diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 6ec99fac..9f54cb39 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -77,14 +77,14 @@ def set_filename(wav_path, out_path): def format_data(data): # setup input data - text_input = data[0] - text_lengths = data[1] - mel_input = data[4] - mel_lengths = data[5] - item_idx = data[7] - d_vectors = data[8] - speaker_ids = data[9] - attn_mask = data[10] + text_input = data['text'] + text_lengths = data['text_lengths'] + mel_input = data['mel'] + mel_lengths = data['mel_lengths'] + item_idx = data['item_idxs'] + d_vectors = data['d_vectors'] + speaker_ids = data['speaker_ids'] + attn_mask = data['attns'] avg_text_length = torch.mean(text_lengths.float()) avg_spec_length = torch.mean(mel_lengths.float()) @@ -132,9 +132,8 @@ def inference( speaker_c = speaker_ids elif d_vectors is not None: speaker_c = d_vectors - outputs = model.inference_with_MAS( - text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c} + text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids} ) model_output = outputs["model_outputs"] model_output = model_output.transpose(1, 2).detach().cpu().numpy() diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py index af054346..d91bf2b6 100644 --- a/TTS/config/shared_configs.py +++ b/TTS/config/shared_configs.py @@ -12,60 +12,89 @@ class BaseAudioConfig(Coqpit): Args: fft_size (int): Number of STFT frequency levels aka.size of the linear spectogram frame. Defaults to 1024. + win_length (int): Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match ```fft_size```. Defaults to 1024. + hop_length (int): Number of audio samples between adjacent STFT columns. Defaults to 1024. + frame_shift_ms (int): Set ```hop_length``` based on milliseconds and sampling rate. + frame_length_ms (int): Set ```win_length``` based on milliseconds and sampling rate. + stft_pad_mode (str): Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'. + sample_rate (int): Audio sampling rate. Defaults to 22050. + resample (bool): Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```. + preemphasis (float): Preemphasis coefficient. Defaults to 0.0. + ref_level_db (int): 20 Reference Db level to rebase the audio signal and ignore the level below. 20Db is assumed the sound of air. Defaults to 20. + do_sound_norm (bool): Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False. + + log_func (str): + Numpy log function used for amplitude to DB conversion. Defaults to 'np.log10'. + do_trim_silence (bool): Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```. + do_amp_to_db_linear (bool, optional): enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True. + do_amp_to_db_mel (bool, optional): enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + trim_db (int): Silence threshold used for silence trimming. Defaults to 45. + power (float): Exponent used for expanding spectrogra levels before running Griffin Lim. It helps to reduce the artifacts in the synthesized voice. Defaults to 1.5. + griffin_lim_iters (int): Number of Griffing Lim iterations. Defaults to 60. + num_mels (int): Number of mel-basis frames that defines the frame lengths of each mel-spectrogram frame. Defaults to 80. + mel_fmin (float): Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices. It needs to be adjusted for a dataset. Defaults to 0. + mel_fmax (float): Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset. + spec_gain (int): Gain applied when converting amplitude to DB. Defaults to 20. + signal_norm (bool): enable/disable signal normalization. Defaults to True. + min_level_db (int): minimum db threshold for the computed melspectrograms. Defaults to -100. + symmetric_norm (bool): enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to True. + max_norm (float): ```k``` defining the normalization range. Defaults to 4.0. + clip_norm (bool): enable/disable clipping the our of range values in the normalized audio signal. Defaults to True. + stats_path (str): Path to the computed stats file. Defaults to None. """ @@ -147,15 +176,20 @@ class BaseDatasetConfig(Coqpit): Args: name (str): Dataset name that defines the preprocessor in use. Defaults to None. + path (str): Root path to the dataset files. Defaults to None. + meta_file_train (str): Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets. Defaults to None. + unused_speakers (List): List of speakers IDs that are not used at the training. Default None. + meta_file_val (str): Name of the dataset meta file that defines the instances used at validation. + meta_file_attn_mask (str): Path to the file that lists the attention mask files used with models that require attention masks to train the duration predictor. @@ -298,7 +332,7 @@ class BaseTrainingConfig(Coqpit): keep_all_best: bool = False keep_after: int = 10000 # dataloading - num_loader_workers: int = None + num_loader_workers: int = 0 num_eval_loader_workers: int = 0 use_noise_augment: bool = False # paths diff --git a/TTS/trainer.py b/TTS/trainer.py index 68b45fe2..bc9a49c6 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -271,6 +271,14 @@ class Trainer: # setup scheduler self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer) + if self.scheduler is not None: + if self.args.continue_path: + if isinstance(self.scheduler, list): + for scheduler in self.scheduler: + scheduler.last_epoch = self.restore_step + else: + self.scheduler.last_epoch = self.restore_step + # DISTRUBUTED if self.num_gpus > 1: self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank) @@ -291,7 +299,6 @@ class Trainer: Returns: nn.Module: initialized model. """ - # TODO: better model setup try: model = setup_vocoder_model(config) except ModuleNotFoundError: diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py new file mode 100644 index 00000000..873f298e --- /dev/null +++ b/TTS/tts/configs/fast_pitch_config.py @@ -0,0 +1,122 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.fast_pitch import FastPitchArgs + + +@dataclass +class FastPitchConfig(BaseTTSConfig): + """Defines parameters for Speedy Speech (feed-forward encoder-decoder) based models. + + Example: + + >>> from TTS.tts.configs import FastPitchConfig + >>> config = FastPitchConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + + model_args (Coqpit): + Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + noam_schedule (bool): + enable / disable the use of Noam LR scheduler. Defaults to False. + + warmup_steps (int): + Number of warm-up steps for the Noam scheduler. Defaults 4000. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + huber_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + pitch_loss_alpha (float): + Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + + binary_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_align_loss_start_step (int): + Start binary alignment loss after this many steps. Defaults to 20000. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "fast_pitch" + # model specific params + model_args: FastPitchArgs = field(default_factory=FastPitchArgs) + + # multi-speaker settings + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + ssim_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + binary_align_loss_alpha: float = 1.0 + binary_align_loss_start_step: int = 20000 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = True + f0_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index 8511b1bc..e208c16c 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -103,6 +103,7 @@ class BaseTTSConfig(BaseTrainingConfig): """Shared parameters among all the tts models. Args: + audio (BaseAudioConfig): Audio processor config object instance. @@ -140,11 +141,14 @@ class BaseTTSConfig(BaseTrainingConfig): loss_masking (bool): enable / disable masking loss values against padded segments of samples in a batch. + sort_by_audio_len (bool): + If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`. + min_seq_len (int): - Minimum input sequence length to be used at training. + Minimum sequence length to be used at training. max_seq_len (int): - Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + Maximum sequence length to be used at training. Larger values result in more VRAM usage. compute_f0 (int): (Not in use yet). @@ -197,6 +201,7 @@ class BaseTTSConfig(BaseTrainingConfig): batch_group_size: int = 0 loss_masking: bool = None # dataloading + sort_by_audio_len: bool = False min_seq_len: int = 1 max_seq_len: int = float("inf") compute_f0: bool = False diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py index 58fc66ee..39479231 100644 --- a/TTS/tts/configs/vits_config.py +++ b/TTS/tts/configs/vits_config.py @@ -67,11 +67,14 @@ class VitsConfig(BaseTTSConfig): compute_linear_spec (bool): If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. + sort_by_audio_len (bool): + If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`. + min_seq_len (int): - Minimum text length to be considered for training. Defaults to `13`. + Minimum sequnce length to be considered for training. Defaults to `0`. max_seq_len (int): - Maximum text length to be considered for training. Defaults to `500`. + Maximum sequnce length to be considered for training. Defaults to `500000`. r (int): Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index 5d38243e..c81e0e6c 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -22,6 +22,8 @@ class TTSDataset(Dataset): compute_linear_spec: bool, ap: AudioProcessor, meta_data: List[List], + compute_f0: bool = False, + f0_cache_path: str = None, characters: Dict = None, custom_symbols: List = None, add_blank: bool = False, @@ -40,8 +42,7 @@ class TTSDataset(Dataset): ): """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. - If you need something different, you can either override or create a new class as the dataset is - initialized by the model. + If you need something different, you can inherit and override. Args: outputs_per_step (int): Number of time frames predicted per step. @@ -54,6 +55,10 @@ class TTSDataset(Dataset): meta_data (list): List of dataset instances. + compute_f0 (bool): compute f0 if True. Defaults to False. + + f0_cache_path (str): Path to store f0 cache. Defaults to None. + characters (dict): `dict` of custom text characters used for converting texts to sequences. custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own @@ -78,8 +83,8 @@ class TTSDataset(Dataset): use_phonemes (bool): If true, input text converted to phonemes. Defaults to false. - phoneme_cache_path (str): Path to cache phoneme features. It writes computed phonemes to files to use in - the coming iterations. Defaults to None. + phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a + separate file. Defaults to None. phoneme_language (str): One the languages from supported by the phonemizer interface. Defaults to `en-us`. @@ -103,6 +108,8 @@ class TTSDataset(Dataset): self.cleaners = text_cleaner self.compute_linear_spec = compute_linear_spec self.return_wav = return_wav + self.compute_f0 = compute_f0 + self.f0_cache_path = f0_cache_path self.min_seq_len = min_seq_len self.max_seq_len = max_seq_len self.ap = ap @@ -119,8 +126,12 @@ class TTSDataset(Dataset): self.verbose = verbose self.input_seq_computed = False self.rescue_item_idx = 1 + self.pitch_computed = False + if use_phonemes and not os.path.isdir(phoneme_cache_path): os.makedirs(phoneme_cache_path, exist_ok=True) + if compute_f0: + self.pitch_extractor = PitchExtractor(self.items, verbose=verbose) if self.verbose: print("\n > DataLoader initialization") print(" | > Use phonemes: {}".format(self.use_phonemes)) @@ -236,10 +247,16 @@ class TTSDataset(Dataset): # TODO: find a better fix return self.load_data(self.rescue_item_idx) + pitch = None + if self.compute_f0: + pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path) + pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32)) + sample = { "raw_text": raw_text, "text": text, "wav": wav, + "pitch": pitch, "attn": attn, "item_idx": self.items[idx][1], "speaker_name": speaker_name, @@ -256,8 +273,8 @@ class TTSDataset(Dataset): return phonemes def compute_input_seq(self, num_workers=0): - """compute input sequences separately. Call it before - passing dataset to data loader.""" + """Compute the input sequences with multi-processing. + Call it before passing dataset to the data loader to cache the input sequences for faster data loading.""" if not self.use_phonemes: if self.verbose: print(" | > Computing input sequences ...") @@ -339,6 +356,11 @@ class TTSDataset(Dataset): temp_items = new_items[offset:end_offset] random.shuffle(temp_items) new_items[offset:end_offset] = temp_items + + if len(new_items) == 0: + raise RuntimeError(" [!] No items left after filtering.") + + # update items to the new sorted items self.items = new_items # logging @@ -359,11 +381,23 @@ class TTSDataset(Dataset): def __getitem__(self, idx): return self.load_data(idx) + @staticmethod + def _sort_batch(batch, text_lengths): + """Sort the batch by the input text length for RNN efficiency. + + Args: + batch (Dict): Batch returned by `__getitem__`. + text_lengths (List[int]): Lengths of the input character sequences. + """ + text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True) + batch = [batch[idx] for idx in ids_sorted_decreasing] + return batch, text_lengths, ids_sorted_decreasing + def collate_fn(self, batch): r""" Perform preprocessing and create a final data batch: 1. Sort batch instances by text-length - 2. Convert Audio signal to Spectrograms. + 2. Convert Audio signal to features. 3. PAD sequences wrt r. 4. Load to Torch. """ @@ -371,30 +405,27 @@ class TTSDataset(Dataset): # Puts each data field into a tensor with outer dimension batch size if isinstance(batch[0], collections.abc.Mapping): - text_lenghts = np.array([len(d["text"]) for d in batch]) + text_lengths = np.array([len(d["text"]) for d in batch]) # sort items with text input length for RNN efficiency - text_lenghts, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lenghts), dim=0, descending=True) + batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths) - wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing] - item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing] - text = [batch[idx]["text"] for idx in ids_sorted_decreasing] - raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing] + # convert list of dicts to dict of lists + batch = {k: [dic[k] for dic in batch] for k in batch[0]} - speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing] # get pre-computed d-vectors if self.d_vector_mapping is not None: - wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing] + wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing] d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names] else: d_vectors = None # get numerical speaker ids from speaker names if self.speaker_id_mapping: - speaker_ids = [self.speaker_id_mapping[sn] for sn in speaker_names] + speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]] else: speaker_ids = None # compute features - mel = [self.ap.melspectrogram(w).astype("float32") for w in wav] + mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]] mel_lengths = [m.shape[1] for m in mel] @@ -413,7 +444,7 @@ class TTSDataset(Dataset): stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step) # PAD sequences with longest instance in the batch - text = prepare_data(text).astype(np.int32) + text = prepare_data(batch["text"]).astype(np.int32) # PAD features with longest instance mel = prepare_tensor(mel, self.outputs_per_step) @@ -422,7 +453,7 @@ class TTSDataset(Dataset): mel = mel.transpose(0, 2, 1) # convert things to pytorch - text_lenghts = torch.LongTensor(text_lenghts) + text_lengths = torch.LongTensor(text_lengths) text = torch.LongTensor(text) mel = torch.FloatTensor(mel).contiguous() mel_lengths = torch.LongTensor(mel_lengths) @@ -436,7 +467,7 @@ class TTSDataset(Dataset): # compute linear spectrogram if self.compute_linear_spec: - linear = [self.ap.spectrogram(w).astype("float32") for w in wav] + linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]] linear = prepare_tensor(linear, self.outputs_per_step) linear = linear.transpose(0, 2, 1) assert mel.shape[1] == linear.shape[1] @@ -447,23 +478,33 @@ class TTSDataset(Dataset): # format waveforms wav_padded = None if self.return_wav: - wav_lengths = [w.shape[0] for w in wav] + wav_lengths = [w.shape[0] for w in batch["wav"]] max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length wav_lengths = torch.LongTensor(wav_lengths) - wav_padded = torch.zeros(len(batch), 1, max_wav_len) - for i, w in enumerate(wav): + wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len) + for i, w in enumerate(batch["wav"]): mel_length = mel_lengths_adjusted[i] w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge") w = w[: mel_length * self.ap.hop_length] wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w) wav_padded.transpose_(1, 2) + # compute f0 + # TODO: compare perf in collate_fn vs in load_data + if self.compute_f0: + pitch = prepare_data(batch["pitch"]) + assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}" + pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT + else: + pitch = None + # collate attention alignments - if batch[0]["attn"] is not None: - attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing] + if batch["attn"][0] is not None: + attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing] for idx, attn in enumerate(attns): pad2 = mel.shape[1] - attn.shape[1] pad1 = text.shape[1] - attn.shape[0] + assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}" attn = np.pad(attn, [[0, pad1], [0, pad2]]) attns[idx] = attn attns = prepare_tensor(attns, self.outputs_per_step) @@ -471,21 +512,22 @@ class TTSDataset(Dataset): else: attns = None # TODO: return dictionary - return ( - text, - text_lenghts, - speaker_names, - linear, - mel, - mel_lengths, - stop_targets, - item_idxs, - d_vectors, - speaker_ids, - attns, - wav_padded, - raw_text, - ) + return { + "text": text, + "text_lengths": text_lengths, + "speaker_names": batch["speaker_name"], + "linear": linear, + "mel": mel, + "mel_lengths": mel_lengths, + "stop_targets": stop_targets, + "item_idxs": batch["item_idx"], + "d_vectors": d_vectors, + "speaker_ids": speaker_ids, + "attns": attns, + "waveform": wav_padded, + "raw_text": batch["raw_text"], + "pitch": pitch, + } raise TypeError( ( @@ -495,3 +537,110 @@ class TTSDataset(Dataset): ) ) ) + + +class PitchExtractor: + """Pitch Extractor for computing F0 from wav files. + + Args: + items (List[List]): Dataset samples. + verbose (bool): Whether to print the progress. + """ + + def __init__( + self, + items: List[List], + verbose=False, + ): + self.items = items + self.verbose = verbose + self.mean = None + self.std = None + + @staticmethod + def create_pitch_file_path(wav_file, cache_path): + file_name = os.path.splitext(os.path.basename(wav_file))[0] + pitch_file = os.path.join(cache_path, file_name + "_pitch.npy") + return pitch_file + + @staticmethod + def _compute_and_save_pitch(ap, wav_file, pitch_file=None): + wav = ap.load_wav(wav_file) + pitch = ap.compute_f0(wav) + if pitch_file: + np.save(pitch_file, pitch) + return pitch + + @staticmethod + def compute_pitch_stats(pitch_vecs): + nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs]) + mean, std = np.mean(nonzeros), np.std(nonzeros) + return mean, std + + def normalize_pitch(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch = pitch - self.mean + pitch = pitch / self.std + pitch[zero_idxs] = 0.0 + return pitch + + def denormalize_pitch(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch *= self.std + pitch += self.mean + pitch[zero_idxs] = 0.0 + return pitch + + @staticmethod + def load_or_compute_pitch(ap, wav_file, cache_path): + """ + compute pitch and return a numpy array of pitch values + """ + pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path) + if not os.path.exists(pitch_file): + pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file) + else: + pitch = np.load(pitch_file) + return pitch.astype(np.float32) + + @staticmethod + def _pitch_worker(args): + item = args[0] + ap = args[1] + cache_path = args[2] + _, wav_file, *_ = item + pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path) + if not os.path.exists(pitch_file): + pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file) + return pitch + return None + + def compute_pitch(self, ap, cache_path, num_workers=0): + """Compute the input sequences with multi-processing. + Call it before passing dataset to the data loader to cache the input sequences for faster data loading.""" + if not os.path.exists(cache_path): + os.makedirs(cache_path, exist_ok=True) + + if self.verbose: + print(" | > Computing pitch features ...") + if num_workers == 0: + pitch_vecs = [] + for _, item in enumerate(tqdm.tqdm(self.items)): + pitch_vecs += [self._pitch_worker([item, ap, cache_path])] + else: + with Pool(num_workers) as p: + pitch_vecs = list( + tqdm.tqdm( + p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]), + total=len(self.items), + ) + ) + pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs) + pitch_stats = {"mean": pitch_mean, "std": pitch_std} + np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) + + def load_pitch_stats(self, cache_path): + stats_path = os.path.join(cache_path, "pitch_stats.npy") + stats = np.load(stats_path, allow_pickle=True).item() + self.mean = stats["mean"].astype(np.float32) + self.std = stats["std"].astype(np.float32) diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index a2520751..c2e55038 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -1,6 +1,7 @@ import sys from collections import Counter from pathlib import Path +from typing import Dict, List, Tuple import numpy as np @@ -30,7 +31,17 @@ def split_dataset(items): return items[:eval_split_size], items[eval_split_size:] -def load_meta_data(datasets, eval_split=True): +def load_meta_data(datasets: List[Dict], eval_split=True) -> Tuple[List[List], List[List]]: + """Parse the dataset, load the samples as a list and load the attention alignments if provided. + + Args: + datasets (List[Dict]): A list of dataset dictionaries or dataset configs. + eval_split (bool, optional): If true, create a evaluation split. If an eval split provided explicitly, generate + an eval split automatically. Defaults to True. + + Returns: + Tuple[List[List], List[List]: training and evaluation splits of the dataset. + """ meta_data_train_all = [] meta_data_eval_all = [] if eval_split else None for dataset in datasets: @@ -51,7 +62,7 @@ def load_meta_data(datasets, eval_split=True): meta_data_eval, meta_data_train = split_dataset(meta_data_train) meta_data_eval_all += meta_data_eval meta_data_train_all += meta_data_train - # load attention masks for duration predictor training + # load attention masks for the duration predictor training if dataset.meta_file_attn_mask: meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) for idx, ins in enumerate(meta_data_train_all): diff --git a/TTS/tts/layers/generic/aligner.py b/TTS/tts/layers/generic/aligner.py new file mode 100644 index 00000000..eef4c4b6 --- /dev/null +++ b/TTS/tts/layers/generic/aligner.py @@ -0,0 +1,81 @@ +from typing import Tuple + +import torch +from torch import nn + + +class AlignmentNetwork(torch.nn.Module): + """Aligner Network for learning alignment between the input text and the model output with Gaussian Attention. + + :: + + query -> conv1d -> relu -> conv1d -> relu -> conv1d -> L2_dist -> softmax -> alignment + key -> conv1d -> relu -> conv1d -----------------------^ + + Args: + in_query_channels (int): Number of channels in the query network. Defaults to 80. + in_key_channels (int): Number of channels in the key network. Defaults to 512. + attn_channels (int): Number of inner channels in the attention layers. Defaults to 80. + temperature (float): Temperature for the softmax. Defaults to 0.0005. + """ + + def __init__( + self, + in_query_channels=80, + in_key_channels=512, + attn_channels=80, + temperature=0.0005, + ): + super().__init__() + self.temperature = temperature + self.softmax = torch.nn.Softmax(dim=3) + self.log_softmax = torch.nn.LogSoftmax(dim=3) + + self.key_layer = nn.Sequential( + nn.Conv1d( + in_key_channels, + in_key_channels * 2, + kernel_size=3, + padding=1, + bias=True, + ), + torch.nn.ReLU(), + nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True), + ) + + self.query_layer = nn.Sequential( + nn.Conv1d( + in_query_channels, + in_query_channels * 2, + kernel_size=3, + padding=1, + bias=True, + ), + torch.nn.ReLU(), + nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True), + torch.nn.ReLU(), + nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True), + ) + + def forward( + self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None + ) -> Tuple[torch.tensor, torch.tensor]: + """Forward pass of the aligner encoder. + Shapes: + - queries: :math:`[B, C, T_de]` + - keys: :math:`[B, C_emb, T_en]` + - mask: :math:`[B, T_de]` + Output: + attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask. + attn_logp (torch.tensor): :math:`[ßB, 1, T_en , T_de]` log probabilities. + """ + key_out = self.key_layer(keys) + query_out = self.query_layer(queries) + attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2 + attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True) + if attn_prior is not None: + attn_logp = self.log_softmax(attn_logp) + torch.log(attn_prior[:, None] + 1e-8) + if mask is not None: + attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf")) + attn = self.softmax(attn_logp) + return attn, attn_logp diff --git a/TTS/tts/layers/generic/pos_encoding.py b/TTS/tts/layers/generic/pos_encoding.py index 46a0b516..913add0d 100644 --- a/TTS/tts/layers/generic/pos_encoding.py +++ b/TTS/tts/layers/generic/pos_encoding.py @@ -7,17 +7,23 @@ from torch import nn class PositionalEncoding(nn.Module): """Sinusoidal positional encoding for non-recurrent neural networks. Implementation based on "Attention Is All You Need" + Args: channels (int): embedding size - dropout (float): dropout parameter + dropout_p (float): dropout rate applied to the output. + max_len (int): maximum sequence length. + use_scale (bool): whether to use a learnable scaling coefficient. """ - def __init__(self, channels, dropout_p=0.0, max_len=5000): + def __init__(self, channels, dropout_p=0.0, max_len=5000, use_scale=False): super().__init__() if channels % 2 != 0: raise ValueError( "Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels) ) + self.use_scale = use_scale + if use_scale: + self.scale = torch.nn.Parameter(torch.ones(1)) pe = torch.zeros(max_len, channels) position = torch.arange(0, max_len).unsqueeze(1) div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels) @@ -49,9 +55,15 @@ class PositionalEncoding(nn.Module): pos_enc = self.pe[:, :, : x.size(2)] * mask else: pos_enc = self.pe[:, :, : x.size(2)] - x = x + pos_enc + if self.use_scale: + x = x + self.scale * pos_enc + else: + x = x + pos_enc else: - x = x + self.pe[:, :, first_idx:last_idx] + if self.use_scale: + x = x + self.scale * self.pe[:, :, first_idx:last_idx] + else: + x = x + self.pe[:, :, first_idx:last_idx] if hasattr(self, "dropout"): x = self.dropout(x) return x diff --git a/TTS/tts/layers/generic/transformer.py b/TTS/tts/layers/generic/transformer.py index 9e6b69ac..2fe9bcc4 100644 --- a/TTS/tts/layers/generic/transformer.py +++ b/TTS/tts/layers/generic/transformer.py @@ -15,17 +15,19 @@ class FFTransformer(nn.Module): self.norm1 = nn.LayerNorm(in_out_channels) self.norm2 = nn.LayerNorm(in_out_channels) - self.dropout = nn.Dropout(dropout_p) + self.dropout1 = nn.Dropout(dropout_p) + self.dropout2 = nn.Dropout(dropout_p) def forward(self, src, src_mask=None, src_key_padding_mask=None): """😦 ugly looking with all the transposing""" src = src.permute(2, 0, 1) src2, enc_align = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask) + src = src + self.dropout1(src2) src = self.norm1(src + src2) # T x B x D -> B x D x T src = src.permute(1, 2, 0) src2 = self.conv2(F.relu(self.conv1(src))) - src2 = self.dropout(src2) + src2 = self.dropout2(src2) src = src + src2 src = src.transpose(1, 2) src = self.norm2(src) @@ -52,8 +54,8 @@ class FFTransformerBlock(nn.Module): """ TODO: handle multi-speaker Shapes: - x: [B, C, T] - mask: [B, 1, T] or [B, T] + - x: :math:`[B, C, T]` + - mask: :math:`[B, 1, T] or [B, T]` """ if mask is not None and mask.ndim == 3: mask = mask.squeeze(1) @@ -65,3 +67,21 @@ class FFTransformerBlock(nn.Module): alignments.append(align.unsqueeze(1)) alignments = torch.cat(alignments, 1) return x + + +class FFTDurationPredictor: + def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None): # pylint: disable=unused-argument + self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p) + self.proj = nn.Linear(in_channels, 1) + + def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, C, T]` + - mask: :math:`[B, 1, T]` + + TODO: Handle the cond input + """ + x = self.fft(x, mask=mask) + x = self.proj(x) + return x diff --git a/TTS/tts/layers/glow_tts/monotonic_align/__init__.py b/TTS/tts/layers/glow_tts/monotonic_align/__init__.py index 5cbfd8fc..ee058095 100644 --- a/TTS/tts/layers/glow_tts/monotonic_align/__init__.py +++ b/TTS/tts/layers/glow_tts/monotonic_align/__init__.py @@ -21,8 +21,10 @@ def convert_pad_shape(pad_shape): def generate_path(duration, mask): """ - duration: [b, t_x] - mask: [b, t_x, t_y] + Shapes: + - duration: :math:`[B, T_en]` + - mask: :math:'[B, T_en, T_de]` + - path: :math:`[B, T_en, T_de]` """ device = duration.device b, t_x, t_y = mask.shape @@ -45,8 +47,9 @@ def maximum_path(value, mask): def maximum_path_cython(value, mask): """Cython optimised version. - value: [b, t_x, t_y] - mask: [b, t_x, t_y] + Shapes: + - value: :math:`[B, T_en, T_de]` + - mask: :math:`[B, T_en, T_de]` """ value = value * mask device = value.device diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 0ce4ada9..a2fd7635 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -69,9 +69,9 @@ class MSELossMasked(nn.Module): length: A Variable containing a LongTensor of size (batch,) which contains the length of each data in a batch. Shapes: - x: B x T X D - target: B x T x D - length: B + - x: :math:`[B, T, D]` + - target: :math:`[B, T, D]` + - length: :math:`B` Returns: loss: An average loss value in range [0, 1] masked by the length. """ @@ -658,3 +658,109 @@ class VitsDiscriminatorLoss(nn.Module): loss = loss + return_dict["loss_disc"] return_dict["loss"] = loss return return_dict + + +class ForwardSumLoss(nn.Module): + def __init__(self, blank_logprob=-1): + super().__init__() + self.log_softmax = torch.nn.LogSoftmax(dim=3) + self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True) + self.blank_logprob = blank_logprob + + def forward(self, attn_logprob, in_lens, out_lens): + key_lens = in_lens + query_lens = out_lens + attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob) + + total_loss = 0.0 + for bid in range(attn_logprob.shape[0]): + target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0) + curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1] + + curr_logprob = self.log_softmax(curr_logprob[None])[0] + loss = self.ctc_loss( + curr_logprob, + target_seq, + input_lengths=query_lens[bid : bid + 1], + target_lengths=key_lens[bid : bid + 1], + ) + total_loss = total_loss + loss + + total_loss = total_loss / attn_logprob.shape[0] + return total_loss + + +class FastPitchLoss(nn.Module): + def __init__(self, c): + super().__init__() + self.spec_loss = MSELossMasked(False) + self.ssim = SSIMLoss() + self.dur_loss = MSELossMasked(False) + self.pitch_loss = MSELossMasked(False) + if c.model_args.use_aligner: + self.aligner_loss = ForwardSumLoss() + + self.spec_loss_alpha = c.spec_loss_alpha + self.ssim_loss_alpha = c.ssim_loss_alpha + self.dur_loss_alpha = c.dur_loss_alpha + self.pitch_loss_alpha = c.pitch_loss_alpha + self.aligner_loss_alpha = c.aligner_loss_alpha + self.binary_alignment_loss_alpha = c.binary_align_loss_alpha + + @staticmethod + def _binary_alignment_loss(alignment_hard, alignment_soft): + """Binary loss that forces soft alignments to match the hard alignments as + explained in `https://arxiv.org/pdf/2108.10447.pdf`. + """ + log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() + return -log_sum / alignment_hard.sum() + + def forward( + self, + decoder_output, + decoder_target, + decoder_output_lens, + dur_output, + dur_target, + pitch_output, + pitch_target, + input_lens, + alignment_logprob=None, + alignment_hard=None, + alignment_soft=None, + ): + loss = 0 + return_dict = {} + if self.ssim_loss_alpha > 0: + ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) + loss = loss + self.ssim_loss_alpha * ssim_loss + return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss + + if self.spec_loss_alpha > 0: + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) + loss = loss + self.spec_loss_alpha * spec_loss + return_dict["loss_spec"] = self.spec_loss_alpha * spec_loss + + if self.dur_loss_alpha > 0: + log_dur_tgt = torch.log(dur_target.float() + 1) + dur_loss = self.dur_loss(dur_output[:, :, None], log_dur_tgt[:, :, None], input_lens) + loss = loss + self.dur_loss_alpha * dur_loss + return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss + + if self.pitch_loss_alpha > 0: + pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens) + loss = loss + self.pitch_loss_alpha * pitch_loss + return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss + + if self.aligner_loss_alpha > 0: + aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens) + loss = loss + self.aligner_loss_alpha * aligner_loss + return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss + + if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None: + binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft) + loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss + return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss + + return_dict["loss"] = loss + return return_dict diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index 2aa84cb2..2c3bed3d 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -13,7 +13,6 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask -from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_fsspec @@ -355,9 +354,6 @@ class AlignTTS(BaseTTS): phase=self.phase, ) - # compute alignment error (the lower the better ) - align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True) - loss_dict["align_error"] = align_error return outputs, loss_dict def train_log( diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py index 66842305..01291775 100644 --- a/TTS/tts/models/base_tacotron.py +++ b/TTS/tts/models/base_tacotron.py @@ -78,7 +78,9 @@ class BaseTacotron(BaseTTS): @staticmethod def _format_aux_input(aux_input: Dict) -> Dict: - return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input) + if aux_input: + return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input) + return None ############################# # INIT FUNCTIONS diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 922761cb..06c7cb2b 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -104,18 +104,19 @@ class BaseTTS(BaseModel): Dict: [description] """ # setup input batch - text_input = batch[0] - text_lengths = batch[1] - speaker_names = batch[2] - linear_input = batch[3] - mel_input = batch[4] - mel_lengths = batch[5] - stop_targets = batch[6] - item_idx = batch[7] - d_vectors = batch[8] - speaker_ids = batch[9] - attn_mask = batch[10] - waveform = batch[11] + text_input = batch["text"] + text_lengths = batch["text_lengths"] + speaker_names = batch["speaker_names"] + linear_input = batch["linear"] + mel_input = batch["mel"] + mel_lengths = batch["mel_lengths"] + stop_targets = batch["stop_targets"] + item_idx = batch["item_idxs"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + attn_mask = batch["attns"] + waveform = batch["waveform"] + pitch = batch["pitch"] max_text_length = torch.max(text_lengths.float()) max_spec_length = torch.max(mel_lengths.float()) @@ -162,6 +163,7 @@ class BaseTTS(BaseModel): "max_spec_length": float(max_spec_length), "item_idx": item_idx, "waveform": waveform, + "pitch": pitch, } def get_data_loader( @@ -199,6 +201,8 @@ class BaseTTS(BaseModel): outputs_per_step=config.r if "r" in config else 1, text_cleaner=config.text_cleaner, compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, + compute_f0=config.get("compute_f0", False), + f0_cache_path=config.get("f0_cache_path", None), meta_data=data_items, ap=ap, characters=config.characters, @@ -245,6 +249,21 @@ class BaseTTS(BaseModel): # sort input sequences from short to long dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False)) + # compute pitch frames and write to files. + if config.compute_f0 and rank in [None, 0]: + if not os.path.exists(config.f0_cache_path): + dataset.pitch_extractor.compute_pitch( + ap, config.get("f0_cache_path", None), config.num_loader_workers + ) + + # halt DDP processes for the main process to finish computing the F0 cache + if num_gpus > 1: + dist.barrier() + + # load pitch stats computed above by all the workers + if config.compute_f0: + dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None)) + # sampler for DDP sampler = DistributedSampler(dataset) if num_gpus > 1 else None diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py new file mode 100644 index 00000000..1dd0bd68 --- /dev/null +++ b/TTS/tts/models/fast_pitch.py @@ -0,0 +1,697 @@ +from dataclasses import dataclass, field +from typing import Dict, Tuple + +import torch +from coqpit import Coqpit +from torch import nn +from torch.cuda.amp.autocast_mode import autocast + +from TTS.tts.layers.feed_forward.decoder import Decoder +from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.aligner import AlignmentNetwork +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.data import sequence_mask +from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram +from TTS.utils.audio import AudioProcessor + + +@dataclass +class FastPitchArgs(Coqpit): + """Fast Pitch Model arguments. + + Args: + + num_chars (int): + Number of characters in the vocabulary. Defaults to 100. + + out_channels (int): + Number of output channels. Defaults to 80. + + hidden_channels (int): + Number of base hidden channels of the model. Defaults to 512. + + num_speakers (int): + Number of speakers for the speaker embedding layer. Defaults to 0. + + duration_predictor_hidden_channels (int): + Number of hidden channels in the duration predictor. Defaults to 256. + + duration_predictor_dropout_p (float): + Dropout rate for the duration predictor. Defaults to 0.1. + + duration_predictor_kernel_size (int): + Kernel size of conv layers in the duration predictor. Defaults to 3. + + pitch_predictor_hidden_channels (int): + Number of hidden channels in the pitch predictor. Defaults to 256. + + pitch_predictor_dropout_p (float): + Dropout rate for the pitch predictor. Defaults to 0.1. + + pitch_predictor_kernel_size (int): + Kernel size of conv layers in the pitch predictor. Defaults to 3. + + pitch_embedding_kernel_size (int): + Kernel size of the projection layer in the pitch predictor. Defaults to 3. + + positional_encoding (bool): + Whether to use positional encoding. Defaults to True. + + positional_encoding_use_scale (bool): + Whether to use a learnable scale coeff in the positional encoding. Defaults to True. + + length_scale (int): + Length scale that multiplies the predicted durations. Larger values result slower speech. Defaults to 1.0. + + encoder_type (str): + Type of the encoder module. One of the encoders available in :class:`TTS.tts.layers.feed_forward.encoder`. + Defaults to `fftransformer` as in the paper. + + encoder_params (dict): + Parameters of the encoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + decoder_type (str): + Type of the decoder module. One of the decoders available in :class:`TTS.tts.layers.feed_forward.decoder`. + Defaults to `fftransformer` as in the paper. + + decoder_params (str): + Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + use_d_vetor (bool): + Whether to use precomputed d-vectors for multi-speaker training. Defaults to False. + + d_vector_dim (int): + Number of channels of the d-vectors. Defaults to 0. + + detach_duration_predictor (bool): + Detach the input to the duration predictor from the earlier computation graph so that the duraiton loss + does not pass to the earlier layers. Defaults to True. + + max_duration (int): + Maximum duration accepted by the model. Defaults to 75. + + use_aligner (bool): + Use aligner network to learn the text to speech alignment. Defaults to True. + """ + + num_chars: int = None + out_channels: int = 80 + hidden_channels: int = 384 + num_speakers: int = 0 + duration_predictor_hidden_channels: int = 256 + duration_predictor_kernel_size: int = 3 + duration_predictor_dropout_p: float = 0.1 + pitch_predictor_hidden_channels: int = 256 + pitch_predictor_kernel_size: int = 3 + pitch_predictor_dropout_p: float = 0.1 + pitch_embedding_kernel_size: int = 3 + positional_encoding: bool = True + poisitonal_encoding_use_scale: bool = True + length_scale: int = 1 + encoder_type: str = "fftransformer" + encoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + decoder_type: str = "fftransformer" + decoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + use_d_vector: bool = False + d_vector_dim: int = 0 + detach_duration_predictor: bool = False + max_duration: int = 75 + use_aligner: bool = True + + +class FastPitch(BaseTTS): + """FastPitch model. Very similart to SpeedySpeech model but with pitch prediction. + + Paper:: + https://arxiv.org/abs/2006.06873 + + Paper abstract:: + We present FastPitch, a fully-parallel text-to-speech model based on FastSpeech, conditioned on fundamental + frequency contours. The model predicts pitch contours during inference. By altering these predictions, + the generated speech can be more expressive, better match the semantic of the utterance, and in the end + more engaging to the listener. Uniformly increasing or decreasing pitch with FastPitch generates speech + that resembles the voluntary modulation of voice. Conditioning on frequency contours improves the overall + quality of synthesized speech, making it comparable to state-of-the-art. It does not introduce an overhead, + and FastPitch retains the favorable, fully-parallel Transformer architecture, with over 900x real-time + factor for mel-spectrogram synthesis of a typical utterance." + + Args: + config (Coqpit): Model coqpit class. + + Examples: + >>> from TTS.tts.models.fast_pitch import FastPitch, FastPitchArgs + >>> config = FastPitchArgs() + >>> model = FastPitch(config) + """ + + # pylint: disable=dangerous-default-value + def __init__(self, config: Coqpit): + + super().__init__() + + # don't use isintance not to import recursively + if config.__class__.__name__ == "FastPitchConfig": + if "characters" in config: + # loading from FasrPitchConfig + _, self.config, num_chars = self.get_characters(config) + config.model_args.num_chars = num_chars + self.args = self.config.model_args + else: + # loading from FastPitchArgs + self.config = config + self.args = config.model_args + elif isinstance(config, FastPitchArgs): + self.args = config + self.config = config + else: + raise ValueError("config must be either a VitsConfig or Vitsself.args") + + self.max_duration = self.args.max_duration + self.use_aligner = self.args.use_aligner + self.use_binary_alignment_loss = False + + self.length_scale = ( + float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale + ) + + self.emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels) + + self.encoder = Encoder( + self.args.hidden_channels, + self.args.hidden_channels, + self.args.encoder_type, + self.args.encoder_params, + self.args.d_vector_dim, + ) + + if self.args.positional_encoding: + self.pos_encoder = PositionalEncoding(self.args.hidden_channels) + + self.decoder = Decoder( + self.args.out_channels, + self.args.hidden_channels, + self.args.decoder_type, + self.args.decoder_params, + ) + + self.duration_predictor = DurationPredictor( + self.args.hidden_channels + self.args.d_vector_dim, + self.args.duration_predictor_hidden_channels, + self.args.duration_predictor_kernel_size, + self.args.duration_predictor_dropout_p, + ) + + self.pitch_predictor = DurationPredictor( + self.args.hidden_channels + self.args.d_vector_dim, + self.args.pitch_predictor_hidden_channels, + self.args.pitch_predictor_kernel_size, + self.args.pitch_predictor_dropout_p, + ) + + self.pitch_emb = nn.Conv1d( + 1, + self.args.hidden_channels, + kernel_size=self.args.pitch_embedding_kernel_size, + padding=int((self.args.pitch_embedding_kernel_size - 1) / 2), + ) + + if self.args.num_speakers > 1 and not self.args.use_d_vector: + # speaker embedding layer + self.emb_g = nn.Embedding(self.args.num_speakers, self.args.d_vector_dim) + nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + + if self.args.d_vector_dim > 0 and self.args.d_vector_dim != self.args.hidden_channels: + self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) + + if self.args.use_aligner: + self.aligner = AlignmentNetwork( + in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels + ) + + @staticmethod + def generate_attn(dr, x_mask, y_mask=None): + """Generate an attention mask from the durations. + + Shapes + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + """ + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) + return attn + + def expand_encoder_outputs(self, en, dr, x_mask, y_mask): + """Generate attention alignment map from durations and + expand encoder outputs + + Shapes + - en: :math:`(B, D_{en}, T_{en})` + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + + Examples: + - encoder output: :math:`[a,b,c,d]` + - durations: :math:`[1, 3, 2, 1]` + + - expanded: :math:`[a, b, b, b, c, c, d]` + - attention map: :math:`[[0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0]]` + """ + attn = self.generate_attn(dr, x_mask, y_mask) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2) + return o_en_ex, attn + + def format_durations(self, o_dr_log, x_mask): + """Format predicted durations. + 1. Convert to linear scale from log scale + 2. Apply the length scale for speed adjustment + 3. Apply masking. + 4. Cast 0 durations to 1. + 5. Round the duration values. + + Args: + o_dr_log: Log scale durations. + x_mask: Input text mask. + + Shapes: + - o_dr_log: :math:`(B, T_{de})` + - x_mask: :math:`(B, T_{en})` + """ + o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale + o_dr[o_dr < 1] = 1.0 + o_dr = torch.round(o_dr) + return o_dr + + @staticmethod + def _concat_speaker_embedding(o_en, g): + g_exp = g.expand(-1, -1, o_en.size(-1)) # [B, C, T_en] + o_en = torch.cat([o_en, g_exp], 1) + return o_en + + def _sum_speaker_embedding(self, x, g): + # project g to decoder dim. + if hasattr(self, "proj_g"): + g = self.proj_g(g) + return x + g + + def _forward_encoder( + self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Encoding forward pass. + + 1. Embed speaker IDs if multi-speaker mode. + 2. Embed character sequences. + 3. Run the encoder network. + 4. Concat speaker embedding to the encoder output for the duration predictor. + + Args: + x (torch.LongTensor): Input sequence IDs. + x_mask (torch.FloatTensor): Input squence mask. + g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None. + + Returns: + Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor, torch.tensor]: + encoder output, encoder output for the duration predictor, input sequence mask, speaker embeddings, + character embeddings + + Shapes: + - x: :math:`(B, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - g: :math:`(B, C)` + """ + if hasattr(self, "emb_g"): + g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1] + if g is not None: + g = g.unsqueeze(-1) + # [B, T, C] + x_emb = self.emb(x) + # encoder pass + o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask) + # speaker conditioning for duration predictor + if g is not None: + o_en_dp = self._concat_speaker_embedding(o_en, g) + else: + o_en_dp = o_en + return o_en, o_en_dp, x_mask, g, x_emb + + def _forward_decoder( + self, + o_en: torch.FloatTensor, + dr: torch.IntTensor, + x_mask: torch.FloatTensor, + y_lengths: torch.IntTensor, + g: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Decoding forward pass. + + 1. Compute the decoder output mask + 2. Expand encoder output with the durations. + 3. Apply position encoding. + 4. Add speaker embeddings if multi-speaker mode. + 5. Run the decoder. + + Args: + o_en (torch.FloatTensor): Encoder output. + dr (torch.IntTensor): Ground truth durations or alignment network durations. + x_mask (torch.IntTensor): Input sequence mask. + y_lengths (torch.IntTensor): Output sequence lengths. + g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations. + """ + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) + # expand o_en with durations + o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + # positional encoding + if hasattr(self, "pos_encoder"): + o_en_ex = self.pos_encoder(o_en_ex, y_mask) + # speaker embedding + if g is not None: + o_en_ex = self._sum_speaker_embedding(o_en_ex, g) + # decoder pass + o_de = self.decoder(o_en_ex, y_mask, g=g) + return o_de.transpose(1, 2), attn.transpose(1, 2) + + def _forward_pitch_predictor( + self, + o_en: torch.FloatTensor, + x_mask: torch.IntTensor, + pitch: torch.FloatTensor = None, + dr: torch.IntTensor = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Pitch predictor forward pass. + + 1. Predict pitch from encoder outputs. + 2. In training - Compute average pitch values for each input character from the ground truth pitch values. + 3. Embed average pitch values. + + Args: + o_en (torch.FloatTensor): Encoder output. + x_mask (torch.IntTensor): Input sequence mask. + pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None. + dr (torch.IntTensor, optional): Ground truth durations. Defaults to None. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction. + + Shapes: + - o_en: :math:`(B, C, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - pitch: :math:`(B, 1, T_{de})` + - dr: :math:`(B, T_{en})` + """ + o_pitch = self.pitch_predictor(o_en, x_mask) + if pitch is not None: + avg_pitch = average_pitch(pitch, dr) + o_pitch_emb = self.pitch_emb(avg_pitch) + return o_pitch_emb, o_pitch, avg_pitch + o_pitch_emb = self.pitch_emb(o_pitch) + return o_pitch_emb, o_pitch + + def _forward_aligner( + self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor + ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Aligner forward pass. + + 1. Compute a mask to apply to the attention map. + 2. Run the alignment network. + 3. Apply MAS to compute the hard alignment map. + 4. Compute the durations from the hard alignment map. + + Args: + x (torch.FloatTensor): Input sequence. + y (torch.FloatTensor): Output sequence. + x_mask (torch.IntTensor): Input sequence mask. + y_mask (torch.IntTensor): Output sequence mask. + + Returns: + Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials, + hard alignment map. + + Shapes: + - x: :math:`[B, T_en, C_en]` + - y: :math:`[B, T_de, C_de]` + - x_mask: :math:`[B, 1, T_en]` + - y_mask: :math:`[B, 1, T_de]` + + - o_alignment_dur: :math:`[B, T_en]` + - alignment_soft: :math:`[B, T_en, T_de]` + - alignment_logprob: :math:`[B, 1, T_de, T_en]` + - alignment_mas: :math:`[B, T_en, T_de]` + """ + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None) + alignment_mas = maximum_path( + alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() + ) + o_alignment_dur = torch.sum(alignment_mas, -1).int() + alignment_soft = alignment_soft.squeeze(1).transpose(1, 2) + return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas + + def forward( + self, + x: torch.LongTensor, + x_lengths: torch.LongTensor, + y_lengths: torch.LongTensor, + y: torch.FloatTensor = None, + dr: torch.IntTensor = None, + pitch: torch.FloatTensor = None, + aux_input: Dict = {"d_vectors": 0, "speaker_ids": None}, # pylint: disable=unused-argument + ) -> Dict: + """Model's forward pass. + + Args: + x (torch.LongTensor): Input character sequences. + x_lengths (torch.LongTensor): Input sequence lengths. + y_lengths (torch.LongTensor): Output sequnce lengths. Defaults to None. + y (torch.FloatTensor): Spectrogram frames. Defaults to None. + dr (torch.IntTensor): Character durations over the spectrogram frames. Defaults to None. + pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Defaults to None. + aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": 0, "speaker_ids": None}`. + + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - y_lengths: :math:`[B]` + - y: :math:`[B, T_max2]` + - dr: :math:`[B, T_max]` + - g: :math:`[B, C]` + - pitch: :math:`[B, 1, T]` + """ + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + # compute sequence masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(y.dtype) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(y.dtype) + # encoder pass + o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g) + # duration predictor pass + if self.args.detach_duration_predictor: + o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + else: + o_dr_log = self.duration_predictor(o_en_dp, x_mask) + o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) + # generate attn mask from predicted durations + o_attn = self.generate_attn(o_dr.squeeze(1), x_mask) + # aligner pass + if self.use_aligner: + o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner( + x_emb, y, x_mask, y_mask + ) + dr = o_alignment_dur + # pitch predictor pass + o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr) + o_en = o_en + o_pitch_emb + # decoder pass + o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g) + outputs = { + "model_outputs": o_de, + "durations_log": o_dr_log.squeeze(1), + "durations": o_dr.squeeze(1), + "attn_durations": o_attn, # for visualization + "pitch_avg": o_pitch, + "pitch_avg_gt": avg_pitch, + "alignments": attn, + "alignment_soft": alignment_soft.transpose(1, 2), + "alignment_mas": alignment_mas.transpose(1, 2), + "o_alignment_dur": o_alignment_dur, + "alignment_logprob": alignment_logprob, + "x_mask": x_mask, + "y_mask": y_mask, + } + return outputs + + @torch.no_grad() + def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument + """Model's inference pass. + + Args: + x (torch.LongTensor): Input character sequence. + aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`. + + Shapes: + - x: [B, T_max] + - x_lengths: [B] + - g: [B, C] + """ + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float() + # encoder pass + o_en, o_en_dp, x_mask, g, _ = self._forward_encoder(x, x_mask, g) + # duration predictor pass + o_dr_log = self.duration_predictor(o_en_dp, x_mask) + o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + y_lengths = o_dr.sum(1) + # pitch predictor pass + o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask) + o_en = o_en + o_pitch_emb + # decoder pass + o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=g) + outputs = { + "model_outputs": o_de, + "alignments": attn, + "pitch": o_pitch, + "durations_log": o_dr_log, + } + return outputs + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + pitch = batch["pitch"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + durations = batch["durations"] + aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} + + # forward pass + outputs = self.forward( + text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input + ) + # use aligner's output as the duration target + if self.use_aligner: + durations = outputs["o_alignment_dur"] + # use float32 in AMP + with autocast(enabled=False): + # compute loss + loss_dict = criterion( + decoder_output=outputs["model_outputs"], + decoder_target=mel_input, + decoder_output_lens=mel_lengths, + dur_output=outputs["durations_log"], + dur_target=durations, + pitch_output=outputs["pitch_avg"], + pitch_target=outputs["pitch_avg_gt"], + input_lens=text_lengths, + alignment_logprob=outputs["alignment_logprob"], + alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None, + alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None, + ) + # compute duration error + durations_pred = outputs["durations"] + duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum() + loss_dict["duration_error"] = duration_error + + return outputs, loss_dict + + def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use + model_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + mel_input = batch["mel_input"] + pitch = batch["pitch"] + pitch_avg_expanded, _ = self.expand_encoder_outputs( + outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"] + ) + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + pitch = pitch[0, 0].data.cpu().numpy() + + # TODO: denormalize before plotting + pitch = abs(pitch) + pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + "pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False), + "pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False), + } + + # plot the attention mask computed from the predicted durations + if "attn_durations" in outputs: + alignments_hat = outputs["attn_durations"][0].data.cpu().numpy() + figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict): + return self.train_log(ap, batch, outputs) + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_criterion(self): + from TTS.tts.layers.losses import FastPitchLoss # pylint: disable=import-outside-toplevel + + return FastPitchLoss(self.config) + + def on_train_step_start(self, trainer): + """Enable binary alignment loss when needed""" + if trainer.total_steps_done > self.config.binary_align_loss_start_step: + self.use_binary_alignment_loss = True + + +def average_pitch(pitch, durs): + """Compute the average pitch value for each input character based on the durations. + + Shapes: + - pitch: :math:`[B, 1, T_de]` + - durs: :math:`[B, T_en]` + """ + + durs_cums_ends = torch.cumsum(durs, dim=1).long() + durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0)) + pitch_nonzero_cums = torch.nn.functional.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0)) + pitch_cums = torch.nn.functional.pad(torch.cumsum(pitch, dim=2), (1, 0)) + + bs, l = durs_cums_ends.size() + n_formants = pitch.size(1) + dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, l) + dce = durs_cums_ends[:, None, :].expand(bs, n_formants, l) + + pitch_sums = (torch.gather(pitch_cums, 2, dce) - torch.gather(pitch_cums, 2, dcs)).float() + pitch_nelems = (torch.gather(pitch_nonzero_cums, 2, dce) - torch.gather(pitch_nonzero_cums, 2, dcs)).float() + + pitch_avg = torch.where(pitch_nelems == 0.0, pitch_nelems, pitch_sums / pitch_nelems) + return pitch_avg diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 92c42fa7..b063b6b4 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -10,7 +10,6 @@ from TTS.tts.layers.glow_tts.encoder import Encoder from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask -from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.speakers import get_speaker_manager from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -110,6 +109,10 @@ class GlowTTS(BaseTTS): # init speaker manager self.speaker_manager = get_speaker_manager(config, data=data) self.num_speakers = self.speaker_manager.num_speakers + if config.use_d_vector_file: + self.external_d_vector_dim = config.d_vector_dim + else: + self.external_d_vector_dim = 0 # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: self.embedded_speaker_dim = self.c_in_channels @@ -130,22 +133,22 @@ class GlowTTS(BaseTTS): return y_mean, y_log_scale, o_attn_dur def forward( - self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None} + self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None} ): # pylint: disable=dangerous-default-value """ Shapes: - x: :math:`[B, T]` - - x_lenghts::math:` B` + - x_lenghts::math:`B` - y: :math:`[B, T, C]` - - y_lengths::math:` B` + - y_lengths::math:`B` - g: :math:`[B, C] or B` """ y = y.transpose(1, 2) y_max_length = y.size(2) # norm speaker embeddings g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None - if g is not None: - if self.d_vector_dim: + if self.use_speaker_embedding or self.use_d_vector_file: + if not self.use_d_vector_file: g = F.normalize(g).unsqueeze(-1) else: g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] @@ -182,7 +185,7 @@ class GlowTTS(BaseTTS): @torch.no_grad() def inference_with_MAS( - self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None} + self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None} ): # pylint: disable=dangerous-default-value """ It's similar to the teacher forcing in Tacotron. @@ -199,12 +202,11 @@ class GlowTTS(BaseTTS): y_max_length = y.size(2) # norm speaker embeddings g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None - if g is not None: - if self.external_d_vector_dim: + if self.use_speaker_embedding or self.use_d_vector_file: + if not self.use_d_vector_file: g = F.normalize(g).unsqueeze(-1) else: g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] - # embedding pass o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) # drop redisual frames wrt num_squeeze and set y_lengths. @@ -244,7 +246,7 @@ class GlowTTS(BaseTTS): @torch.no_grad() def decoder_inference( - self, y, y_lengths=None, aux_input={"d_vectors": None} + self, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None} ): # pylint: disable=dangerous-default-value """ Shapes: @@ -276,7 +278,7 @@ class GlowTTS(BaseTTS): return outputs @torch.no_grad() - def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None}): # pylint: disable=dangerous-default-value + def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids":None}): # pylint: disable=dangerous-default-value x_lengths = aux_input["x_lengths"] g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None @@ -327,8 +329,9 @@ class GlowTTS(BaseTTS): mel_input = batch["mel_input"] mel_lengths = batch["mel_lengths"] d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] - outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors}) + outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors, "speaker_ids":speaker_ids}) loss_dict = criterion( outputs["model_outputs"], @@ -341,9 +344,6 @@ class GlowTTS(BaseTTS): text_lengths, ) - # compute alignment error (the lower the better ) - align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True) - loss_dict["align_error"] = align_error return outputs, loss_dict def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 72c67df2..0da43f90 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -41,9 +41,9 @@ def rand_segment(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4 x_lengths = T max_idxs = x_lengths - segment_size + 1 assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size." - ids_str = (torch.rand([B]).type_as(x) * max_idxs).long() - ret = segment(x, ids_str, segment_size) - return ret, ids_str + segment_indices = (torch.rand([B]).type_as(x) * max_idxs).long() + ret = segment(x, segment_indices, segment_size) + return ret, segment_indices @dataclass diff --git a/TTS/tts/utils/measures.py b/TTS/tts/utils/measures.py index fdd31242..90e862e1 100644 --- a/TTS/tts/utils/measures.py +++ b/TTS/tts/utils/measures.py @@ -7,7 +7,7 @@ def alignment_diagonal_score(alignments, binary=False): binary (bool): if True, ignore scores and consider attention as a binary mask. Shape: - alignments : batch x decoder_steps x encoder_steps + - alignments : :math:`[B, T_de, T_en]` """ maxs = alignments.max(dim=1)[0] if binary: diff --git a/TTS/tts/utils/text/__init__.py b/TTS/tts/utils/text/__init__.py index 20712f1d..66f518b4 100644 --- a/TTS/tts/utils/text/__init__.py +++ b/TTS/tts/utils/text/__init__.py @@ -45,12 +45,10 @@ def text2phone(text, language, use_espeak_phonemes=False): # TO REVIEW : How to have a good implementation for this? if language == "zh-CN": ph = chinese_text_to_phonemes(text) - print(" > Phonemes: {}".format(ph)) return ph if language == "ja-jp": ph = japanese_text_to_phonemes(text) - print(" > Phonemes: {}".format(ph)) return ph if gruut.is_language_supported(language): @@ -80,7 +78,6 @@ def text2phone(text, language, use_espeak_phonemes=False): # Fix a few phonemes ph = ph.translate(GRUUT_TRANS_TABLE) - return ph raise ValueError(f" [!] Language {language} is not supported for phonemization.") diff --git a/TTS/tts/utils/text/japanese/phonemizer.py b/TTS/tts/utils/text/japanese/phonemizer.py index a4629a30..969becfd 100644 --- a/TTS/tts/utils/text/japanese/phonemizer.py +++ b/TTS/tts/utils/text/japanese/phonemizer.py @@ -2,8 +2,10 @@ # compatible with Julius https://github.com/julius-speech/segmentation-kit import re +import unicodedata import MeCab +from num2words import num2words _CONVRULES = [ # Conversion of 2 letters @@ -373,8 +375,93 @@ def text2kata(text: str) -> str: return hira2kata("".join(res)) +_ALPHASYMBOL_YOMI = { + "#": "シャープ", + "%": "パーセント", + "&": "アンド", + "+": "プラス", + "-": "マイナス", + ":": "コロン", + ";": "セミコロン", + "<": "小なり", + "=": "イコール", + ">": "大なり", + "@": "アット", + "a": "エー", + "b": "ビー", + "c": "シー", + "d": "ディー", + "e": "イー", + "f": "エフ", + "g": "ジー", + "h": "エイチ", + "i": "アイ", + "j": "ジェー", + "k": "ケー", + "l": "エル", + "m": "エム", + "n": "エヌ", + "o": "オー", + "p": "ピー", + "q": "キュー", + "r": "アール", + "s": "エス", + "t": "ティー", + "u": "ユー", + "v": "ブイ", + "w": "ダブリュー", + "x": "エックス", + "y": "ワイ", + "z": "ゼット", + "α": "アルファ", + "β": "ベータ", + "γ": "ガンマ", + "δ": "デルタ", + "ε": "イプシロン", + "ζ": "ゼータ", + "η": "イータ", + "θ": "シータ", + "ι": "イオタ", + "κ": "カッパ", + "λ": "ラムダ", + "μ": "ミュー", + "ν": "ニュー", + "ξ": "クサイ", + "ο": "オミクロン", + "π": "パイ", + "ρ": "ロー", + "σ": "シグマ", + "τ": "タウ", + "υ": "ウプシロン", + "φ": "ファイ", + "χ": "カイ", + "ψ": "プサイ", + "ω": "オメガ", +} + + +_NUMBER_WITH_SEPARATOR_RX = re.compile("[0-9]{1,3}(,[0-9]{3})+") +_CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"} +_CURRENCY_RX = re.compile(r"([$¥£€])([0-9.]*[0-9])") +_NUMBER_RX = re.compile(r"[0-9]+(\.[0-9]+)?") + + +def japanese_convert_numbers_to_words(text: str) -> str: + res = _NUMBER_WITH_SEPARATOR_RX.sub(lambda m: m[0].replace(",", ""), text) + res = _CURRENCY_RX.sub(lambda m: m[2] + _CURRENCY_MAP.get(m[1], m[1]), res) + res = _NUMBER_RX.sub(lambda m: num2words(m[0], lang="ja"), res) + return res + + +def japanese_convert_alpha_symbols_to_words(text: str) -> str: + return "".join([_ALPHASYMBOL_YOMI.get(ch, ch) for ch in text.lower()]) + + def japanese_text_to_phonemes(text: str) -> str: """Convert Japanese text to phonemes.""" - res = text2kata(text) + res = unicodedata.normalize("NFKC", text) + res = japanese_convert_numbers_to_words(res) + res = japanese_convert_alpha_symbols_to_words(res) + res = text2kata(res) res = kata2phoneme(res) return res.replace(" ", "") diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py index 44732322..7101ed3d 100644 --- a/TTS/tts/utils/visual.py +++ b/TTS/tts/utils/visual.py @@ -49,6 +49,46 @@ def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False): return fig +def plot_pitch(pitch, spectrogram, ap=None, fig_size=(30, 10), output_fig=False): + """Plot pitch curves on top of the spectrogram. + + Args: + pitch (np.array): Pitch values. + spectrogram (np.array): Spectrogram values. + + Shapes: + pitch: :math:`(T,)` + spec: :math:`(C, T)` + """ + + if isinstance(spectrogram, torch.Tensor): + spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T + else: + spectrogram_ = spectrogram.T + spectrogram_ = spectrogram_.astype(np.float32) if spectrogram_.dtype == np.float16 else spectrogram_ + if ap is not None: + spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access + + old_fig_size = plt.rcParams["figure.figsize"] + if fig_size is not None: + plt.rcParams["figure.figsize"] = fig_size + + fig, ax = plt.subplots() + + ax.imshow(spectrogram_, aspect="auto", origin="lower") + ax.set_xlabel("time") + ax.set_ylabel("spec_freq") + + ax2 = ax.twinx() + ax2.plot(pitch, linewidth=5.0, color="red") + ax2.set_ylabel("F0") + + plt.rcParams["figure.figsize"] = old_fig_size + if not output_fig: + plt.close() + return fig + + def visualize( alignment, postnet_output, diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index 0a343fbf..01d1f7d1 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -2,6 +2,7 @@ from typing import Dict, Tuple import librosa import numpy as np +import pyworld as pw import scipy.io.wavfile import scipy.signal import soundfile as sf @@ -10,8 +11,6 @@ from torch import nn from TTS.tts.utils.data import StandardScaler -# import pyworld as pw - class TorchSTFT(nn.Module): # pylint: disable=abstract-method """Some of the audio processing funtions using Torch for faster batch processing. @@ -623,17 +622,47 @@ class AudioProcessor(object): return 0, pad return pad // 2, pad // 2 + pad % 2 - ### Compute F0 ### - # TODO: pw causes some dep issues - # def compute_f0(self, x): - # f0, t = pw.dio( - # x.astype(np.double), - # fs=self.sample_rate, - # f0_ceil=self.mel_fmax, - # frame_period=1000 * self.hop_length / self.sample_rate, - # ) - # f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate) - # return f0 + def compute_f0(self, x: np.ndarray) -> np.ndarray: + """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. + + Args: + x (np.ndarray): Waveform. + + Returns: + np.ndarray: Pitch. + + Examples: + >>> WAV_FILE = filename = librosa.util.example_audio_file() + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio import AudioProcessor + >>> conf = BaseAudioConfig(mel_fmax=8000) + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050] + >>> pitch = ap.compute_f0(wav) + """ + f0, t = pw.dio( + x.astype(np.double), + fs=self.sample_rate, + f0_ceil=self.mel_fmax, + frame_period=1000 * self.hop_length / self.sample_rate, + ) + f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate) + # pad = int((self.win_length / self.hop_length) / 2) + # f0 = [0.0] * pad + f0 + [0.0] * pad + # f0 = np.pad(f0, (pad, pad), mode="constant", constant_values=0) + # f0 = np.array(f0, dtype=np.float32) + + # f01, _, _ = librosa.pyin( + # x, + # fmin=65 if self.mel_fmin == 0 else self.mel_fmin, + # fmax=self.mel_fmax, + # frame_length=self.win_length, + # sr=self.sample_rate, + # fill_na=0.0, + # ) + + # spec = self.melspectrogram(x) + return f0 ### Audio Processing ### def find_endpoint(self, wav: np.ndarray, threshold_db=-40, min_silence_sec=0.8) -> int: diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 531523a4..236e78a9 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -232,7 +232,7 @@ class Synthesizer(object): # compute a new d_vector from the given clip. if speaker_wav is not None: - speaker_embedding = self.speaker_manager.compute_d_vector_from_clip(speaker_wav) + speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav) use_gl = self.vocoder_model is None diff --git a/TTS/vocoder/models/__init__.py b/TTS/vocoder/models/__init__.py index edc94d72..a70ebe40 100644 --- a/TTS/vocoder/models/__init__.py +++ b/TTS/vocoder/models/__init__.py @@ -11,7 +11,6 @@ def to_camel(text): def setup_model(config: Coqpit): """Load models directly from configuration.""" - print(" > Vocoder Model: {}".format(config.model)) if "discriminator_model" in config and "generator_model" in config: MyModel = importlib.import_module("TTS.vocoder.models.gan") MyModel = getattr(MyModel, "GAN") @@ -28,6 +27,7 @@ def setup_model(config: Coqpit): MyModel = getattr(MyModel, to_camel(config.model)) except ModuleNotFoundError as e: raise ValueError(f"Model {config.model} not exist!") from e + print(" > Vocoder Model: {}".format(config.model)) model = MyModel(config) return model diff --git a/docs/source/index.md b/docs/source/index.md index 77d198c0..d842f894 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -45,6 +45,7 @@ models/glow_tts.md models/vits.md + models/fast_pitch.md .. toctree:: :maxdepth: 2 diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py new file mode 100644 index 00000000..614e42e0 --- /dev/null +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -0,0 +1,70 @@ +import os + +from TTS.config import BaseAudioConfig, BaseDatasetConfig +from TTS.trainer import Trainer, TrainingArgs, init_training +from TTS.tts.configs import FastPitchConfig +from TTS.utils.manage import ModelManager + +output_path = os.path.dirname(os.path.abspath(__file__)) + +# init configs +dataset_config = BaseDatasetConfig( + name="ljspeech", + meta_file_train="metadata.csv", + # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), + path=os.path.join(output_path, "../LJSpeech-1.1/"), +) + +audio_config = BaseAudioConfig( + sample_rate=22050, + do_trim_silence=True, + trim_db=60.0, + signal_norm=False, + mel_fmin=0.0, + mel_fmax=8000, + spec_gain=1.0, + log_func="np.log", + ref_level_db=20, + preemphasis=0.0, +) + +config = FastPitchConfig( + run_name="fast_pitch_ljspeech", + audio=audio_config, + batch_size=32, + eval_batch_size=16, + num_loader_workers=8, + num_eval_loader_workers=4, + compute_input_seq_cache=True, + compute_f0=True, + f0_cache_path=os.path.join(output_path, "f0_cache"), + run_eval=True, + test_delay_epochs=-1, + epochs=1000, + text_cleaner="english_cleaners", + use_phonemes=True, + use_espeak_phonemes=False, + phoneme_language="en-us", + phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + print_step=50, + print_eval=False, + mixed_precision=False, + sort_by_audio_len=True, + max_seq_len=500000, + output_path=output_path, + datasets=[dataset_config], +) + +# compute alignments +if not config.model_args.use_aligner: + manager = ModelManager() + model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA") + # TODO: make compute_attention python callable + os.system( + f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true" + ) + +# train the model +args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) +trainer = Trainer(args, config, output_path, c_logger, tb_logger) +trainer.fit() diff --git a/requirements.txt b/requirements.txt index b92947a0..a87a3c6f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,3 +25,4 @@ unidic-lite==1.0.8 # gruut+supported langs gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0 fsspec>=2021.04.0 +pyworld \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py index c7930ef9..2b07004f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,12 +1,16 @@ import os +from TTS.config import BaseDatasetConfig from TTS.utils.generic_utils import get_cuda def get_device_id(): use_cuda, _ = get_cuda() if use_cuda: - GPU_ID = "0" + if 'CUDA_VISIBLE_DEVICES' in os.environ and os.environ['CUDA_VISIBLE_DEVICES'] != "": + GPU_ID = os.environ['CUDA_VISIBLE_DEVICES'].split(',')[0] + else: + GPU_ID = "0" else: GPU_ID = "" return GPU_ID @@ -30,3 +34,7 @@ def get_tests_output_path(): def run_cli(command): exit_status = os.system(command) assert exit_status == 0, f" [!] command `{command}` failed." + + +def get_test_data_config(): + return BaseDatasetConfig(name="ljspeech", path="tests/data/ljspeech/", meta_file_train="metadata.csv") diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 717b2e0f..0fbb6bde 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -68,15 +68,15 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data[0] - text_lengths = data[1] - speaker_name = data[2] - linear_input = data[3] - mel_input = data[4] - mel_lengths = data[5] - stop_target = data[6] - item_idx = data[7] - wavs = data[11] + text_input = data['text'] + text_lengths = data['text_lengths'] + speaker_name = data['speaker_names'] + linear_input = data['linear'] + mel_input = data['mel'] + mel_lengths = data['mel_lengths'] + stop_target = data['stop_targets'] + item_idx = data['item_idxs'] + wavs = data['waveform'] neg_values = text_input[text_input < 0] check_count = len(neg_values) @@ -113,14 +113,14 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data[0] - text_lengths = data[1] - speaker_name = data[2] - linear_input = data[3] - mel_input = data[4] - mel_lengths = data[5] - stop_target = data[6] - item_idx = data[7] + text_input = data['text'] + text_lengths = data['text_lengths'] + speaker_name = data['speaker_names'] + linear_input = data['linear'] + mel_input = data['mel'] + mel_lengths = data['mel_lengths'] + stop_target = data['stop_targets'] + item_idx = data['item_idxs'] avg_length = mel_lengths.numpy().mean() assert avg_length >= last_length @@ -139,14 +139,14 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data[0] - text_lengths = data[1] - speaker_name = data[2] - linear_input = data[3] - mel_input = data[4] - mel_lengths = data[5] - stop_target = data[6] - item_idx = data[7] + text_input = data['text'] + text_lengths = data['text_lengths'] + speaker_name = data['speaker_names'] + linear_input = data['linear'] + mel_input = data['mel'] + mel_lengths = data['mel_lengths'] + stop_target = data['stop_targets'] + item_idx = data['item_idxs'] # check mel_spec consistency wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32) @@ -188,14 +188,14 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data[0] - text_lengths = data[1] - speaker_name = data[2] - linear_input = data[3] - mel_input = data[4] - mel_lengths = data[5] - stop_target = data[6] - item_idx = data[7] + text_input = data['text'] + text_lengths = data['text_lengths'] + speaker_name = data['speaker_names'] + linear_input = data['linear'] + mel_input = data['mel'] + mel_lengths = data['mel_lengths'] + stop_target = data['stop_targets'] + item_idx = data['item_idxs'] if mel_lengths[0] > mel_lengths[1]: idx = 0 diff --git a/tests/test_audio_processor.py b/tests/test_audio_processor.py index 22e965f0..56611692 100644 --- a/tests/test_audio_processor.py +++ b/tests/test_audio_processor.py @@ -181,3 +181,10 @@ class TestAudio(unittest.TestCase): mel_norm = ap.melspectrogram(wav) mel_denorm = ap.denormalize(mel_norm) assert abs(mel_reference - mel_denorm).max() < 1e-4 + + def test_compute_f0(self): # pylint: disable=no-self-use + ap = AudioProcessor(**conf) + wav = ap.load_wav(WAV_FILE) + pitch = ap.compute_f0(wav) + mel = ap.melspectrogram(wav) + assert pitch.shape[0] == mel.shape[1] diff --git a/tests/text_tests/test_japanese_phonemizer.py b/tests/text_tests/test_japanese_phonemizer.py index b3b1ece3..423b79b9 100644 --- a/tests/text_tests/test_japanese_phonemizer.py +++ b/tests/text_tests/test_japanese_phonemizer.py @@ -5,11 +5,13 @@ from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes _TEST_CASES = """ どちらに行きますか?/dochiraniikimasuka? 今日は温泉に、行きます。/kyo:waoNseNni,ikimasu. -「A」から「Z」までです。/AkaraZmadedesu. +「A」から「Z」までです。/e:karazeqtomadedesu. そうですね!/so:desune! クジラは哺乳類です。/kujirawahonyu:ruidesu. ヴィディオを見ます。/bidioomimasu. -ky o: w a o N s e N n i , i k i m a s u ./kyo:waoNseNni,ikimasu. +今日は8月22日です/kyo:wahachigatsuniju:ninichidesu +xyzとαβγ/eqkusuwaizeqtotoarufabe:tagaNma +値段は$12.34です/nedaNwaju:niteNsaNyoNdorudesu """ diff --git a/tests/tts_tests/test_fast_pitch.py b/tests/tts_tests/test_fast_pitch.py new file mode 100644 index 00000000..1975435e --- /dev/null +++ b/tests/tts_tests/test_fast_pitch.py @@ -0,0 +1,47 @@ +import unittest + +import torch as T + +from TTS.tts.models.fast_pitch import FastPitch, FastPitchArgs, average_pitch +# pylint: disable=unused-variable + + +class AveragePitchTests(unittest.TestCase): + def test_in_out(self): # pylint: disable=no-self-use + pitch = T.rand(1, 1, 128) + + durations = T.randint(1, 5, (1, 21)) + coeff = 128.0 / durations.sum() + durations = T.round(durations * coeff) + diff = 128.0 - durations.sum() + durations[0, -1] += diff + durations = durations.long() + + pitch_avg = average_pitch(pitch, durations) + + index = 0 + for idx, dur in enumerate(durations[0]): + assert abs(pitch_avg[0, 0, idx] - pitch[0, 0, index : index + dur.item()].mean()) < 1e-5 + index += dur + + +def expand_encoder_outputs_test(): + model = FastPitch(FastPitchArgs(num_chars=10)) + + inputs = T.rand(2, 5, 57) + durations = T.randint(1, 4, (2, 57)) + + x_mask = T.ones(2, 1, 57) + y_mask = T.ones(2, 1, durations.sum(1).max()) + + expanded, _ = model.expand_encoder_outputs(inputs, durations, x_mask, y_mask) + + for b in range(durations.shape[0]): + index = 0 + for idx, dur in enumerate(durations[b]): + diff = ( + expanded[b, :, index : index + dur.item()] + - inputs[b, :, idx].repeat(dur.item()).view(expanded[b, :, index : index + dur.item()].shape) + ).sum() + assert abs(diff) < 1e-6, diff + index += dur