mirror of https://github.com/coqui-ai/TTS.git
commit
dc2ace3ca0
|
@ -4,14 +4,14 @@
|
||||||
🐸TTS comes with [pretrained models](https://github.com/coqui-ai/TTS/wiki/Released-Models), tools for measuring dataset quality and already used in **20+ languages** for products and research projects.
|
🐸TTS comes with [pretrained models](https://github.com/coqui-ai/TTS/wiki/Released-Models), tools for measuring dataset quality and already used in **20+ languages** for products and research projects.
|
||||||
|
|
||||||
[](https://github.com/coqui-ai/TTS/actions)
|
[](https://github.com/coqui-ai/TTS/actions)
|
||||||
[](https://opensource.org/licenses/MPL-2.0)
|
|
||||||
[](https://tts.readthedocs.io/en/latest/)
|
|
||||||
[](https://badge.fury.io/py/TTS)
|
[](https://badge.fury.io/py/TTS)
|
||||||
[](https://github.com/coqui-ai/TTS/blob/master/CODE_OF_CONDUCT.md)
|
[](https://github.com/coqui-ai/TTS/blob/master/CODE_OF_CONDUCT.md)
|
||||||
[](https://pepy.tech/project/tts)
|
[](https://pepy.tech/project/tts)
|
||||||
[](https://gitter.im/coqui-ai/TTS?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
|
|
||||||
[](https://zenodo.org/badge/latestdoi/265612440)
|
[](https://zenodo.org/badge/latestdoi/265612440)
|
||||||
|
|
||||||
|
[](https://tts.readthedocs.io/en/latest/)
|
||||||
|
[](https://gitter.im/coqui-ai/TTS?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
|
||||||
|
[](https://opensource.org/licenses/MPL-2.0)
|
||||||
|
|
||||||
📰 [**Subscribe to 🐸Coqui.ai Newsletter**](https://coqui.ai/?subscription=true)
|
📰 [**Subscribe to 🐸Coqui.ai Newsletter**](https://coqui.ai/?subscription=true)
|
||||||
|
|
||||||
|
@ -151,3 +151,5 @@ If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](ht
|
||||||
|- vocoder/ (Vocoder models.)
|
|- vocoder/ (Vocoder models.)
|
||||||
|- (same)
|
|- (same)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
<img src="https://static.scarf.sh/a.png?x-pxid=503c242f-a253-4fb8-8071-ce1dc1e89999" />
|
|
@ -62,7 +62,16 @@
|
||||||
"default_vocoder": null,
|
"default_vocoder": null,
|
||||||
"commit": "3900448",
|
"commit": "3900448",
|
||||||
"author": "Eren Gölge @erogol",
|
"author": "Eren Gölge @erogol",
|
||||||
"license": "",
|
"license": "TBD",
|
||||||
|
"contact": "egolge@coqui.com"
|
||||||
|
},
|
||||||
|
"fast_pitch": {
|
||||||
|
"description": "FastPitch model trained on LJSpeech using the Aligner Network",
|
||||||
|
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.2/tts_models--en--ljspeech--fast_pitch.zip",
|
||||||
|
"default_vocoder": "vocoder_models/en/ljspeech/hifigan_v2",
|
||||||
|
"commit": "b27b3ba",
|
||||||
|
"author": "Eren Gölge @erogol",
|
||||||
|
"license": "TBD",
|
||||||
"contact": "egolge@coqui.com"
|
"contact": "egolge@coqui.com"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
0.2.1
|
0.2.2
|
|
@ -8,12 +8,12 @@ import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from TTS.config import load_config
|
||||||
from TTS.tts.datasets.TTSDataset import TTSDataset
|
from TTS.tts.datasets.TTSDataset import TTSDataset
|
||||||
from TTS.tts.models import setup_model
|
from TTS.tts.models import setup_model
|
||||||
from TTS.tts.utils.io import load_checkpoint
|
|
||||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.io import load_config
|
from TTS.utils.io import load_checkpoint
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# pylint: disable=bad-option-value
|
# pylint: disable=bad-option-value
|
||||||
|
@ -27,7 +27,7 @@ Example run:
|
||||||
CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py
|
CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py
|
||||||
--model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
|
--model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
|
||||||
--config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
|
--config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
|
||||||
--dataset_metafile /root/LJSpeech-1.1/metadata.csv
|
--dataset_metafile metadata.csv
|
||||||
--data_path /root/LJSpeech-1.1/
|
--data_path /root/LJSpeech-1.1/
|
||||||
--batch_size 32
|
--batch_size 32
|
||||||
--dataset ljspeech
|
--dataset ljspeech
|
||||||
|
@ -76,8 +76,7 @@ Example run:
|
||||||
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
|
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
|
||||||
# TODO: handle multi-speaker
|
# TODO: handle multi-speaker
|
||||||
model = setup_model(C)
|
model = setup_model(C)
|
||||||
model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
|
model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True)
|
||||||
model.eval()
|
|
||||||
|
|
||||||
# data loader
|
# data loader
|
||||||
preprocessor = importlib.import_module("TTS.tts.datasets.formatters")
|
preprocessor = importlib.import_module("TTS.tts.datasets.formatters")
|
||||||
|
@ -127,9 +126,9 @@ Example run:
|
||||||
mel_input = mel_input.cuda()
|
mel_input = mel_input.cuda()
|
||||||
mel_lengths = mel_lengths.cuda()
|
mel_lengths = mel_lengths.cuda()
|
||||||
|
|
||||||
mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)
|
model_outputs = model.forward(text_input, text_lengths, mel_input)
|
||||||
|
|
||||||
alignments = alignments.detach()
|
alignments = model_outputs["alignments"].detach()
|
||||||
for idx, alignment in enumerate(alignments):
|
for idx, alignment in enumerate(alignments):
|
||||||
item_idx = item_idxs[idx]
|
item_idx = item_idxs[idx]
|
||||||
# interpolate if r > 1
|
# interpolate if r > 1
|
||||||
|
@ -149,10 +148,12 @@ Example run:
|
||||||
alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
|
alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
|
||||||
# set file paths
|
# set file paths
|
||||||
wav_file_name = os.path.basename(item_idx)
|
wav_file_name = os.path.basename(item_idx)
|
||||||
align_file_name = os.path.splitext(wav_file_name)[0] + ".npy"
|
align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy"
|
||||||
file_path = item_idx.replace(wav_file_name, align_file_name)
|
file_path = item_idx.replace(wav_file_name, align_file_name)
|
||||||
# save output
|
# save output
|
||||||
file_paths.append([item_idx, file_path])
|
wav_file_abs_path = os.path.abspath(item_idx)
|
||||||
|
file_abs_path = os.path.abspath(file_path)
|
||||||
|
file_paths.append([wav_file_abs_path, file_abs_path])
|
||||||
np.save(file_path, alignment)
|
np.save(file_path, alignment)
|
||||||
|
|
||||||
# ourput metafile
|
# ourput metafile
|
||||||
|
|
|
@ -77,14 +77,14 @@ def set_filename(wav_path, out_path):
|
||||||
|
|
||||||
def format_data(data):
|
def format_data(data):
|
||||||
# setup input data
|
# setup input data
|
||||||
text_input = data[0]
|
text_input = data['text']
|
||||||
text_lengths = data[1]
|
text_lengths = data['text_lengths']
|
||||||
mel_input = data[4]
|
mel_input = data['mel']
|
||||||
mel_lengths = data[5]
|
mel_lengths = data['mel_lengths']
|
||||||
item_idx = data[7]
|
item_idx = data['item_idxs']
|
||||||
d_vectors = data[8]
|
d_vectors = data['d_vectors']
|
||||||
speaker_ids = data[9]
|
speaker_ids = data['speaker_ids']
|
||||||
attn_mask = data[10]
|
attn_mask = data['attns']
|
||||||
avg_text_length = torch.mean(text_lengths.float())
|
avg_text_length = torch.mean(text_lengths.float())
|
||||||
avg_spec_length = torch.mean(mel_lengths.float())
|
avg_spec_length = torch.mean(mel_lengths.float())
|
||||||
|
|
||||||
|
@ -132,9 +132,8 @@ def inference(
|
||||||
speaker_c = speaker_ids
|
speaker_c = speaker_ids
|
||||||
elif d_vectors is not None:
|
elif d_vectors is not None:
|
||||||
speaker_c = d_vectors
|
speaker_c = d_vectors
|
||||||
|
|
||||||
outputs = model.inference_with_MAS(
|
outputs = model.inference_with_MAS(
|
||||||
text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c}
|
text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}
|
||||||
)
|
)
|
||||||
model_output = outputs["model_outputs"]
|
model_output = outputs["model_outputs"]
|
||||||
model_output = model_output.transpose(1, 2).detach().cpu().numpy()
|
model_output = model_output.transpose(1, 2).detach().cpu().numpy()
|
||||||
|
|
|
@ -12,60 +12,89 @@ class BaseAudioConfig(Coqpit):
|
||||||
Args:
|
Args:
|
||||||
fft_size (int):
|
fft_size (int):
|
||||||
Number of STFT frequency levels aka.size of the linear spectogram frame. Defaults to 1024.
|
Number of STFT frequency levels aka.size of the linear spectogram frame. Defaults to 1024.
|
||||||
|
|
||||||
win_length (int):
|
win_length (int):
|
||||||
Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match
|
Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match
|
||||||
```fft_size```. Defaults to 1024.
|
```fft_size```. Defaults to 1024.
|
||||||
|
|
||||||
hop_length (int):
|
hop_length (int):
|
||||||
Number of audio samples between adjacent STFT columns. Defaults to 1024.
|
Number of audio samples between adjacent STFT columns. Defaults to 1024.
|
||||||
|
|
||||||
frame_shift_ms (int):
|
frame_shift_ms (int):
|
||||||
Set ```hop_length``` based on milliseconds and sampling rate.
|
Set ```hop_length``` based on milliseconds and sampling rate.
|
||||||
|
|
||||||
frame_length_ms (int):
|
frame_length_ms (int):
|
||||||
Set ```win_length``` based on milliseconds and sampling rate.
|
Set ```win_length``` based on milliseconds and sampling rate.
|
||||||
|
|
||||||
stft_pad_mode (str):
|
stft_pad_mode (str):
|
||||||
Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'.
|
Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'.
|
||||||
|
|
||||||
sample_rate (int):
|
sample_rate (int):
|
||||||
Audio sampling rate. Defaults to 22050.
|
Audio sampling rate. Defaults to 22050.
|
||||||
|
|
||||||
resample (bool):
|
resample (bool):
|
||||||
Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```.
|
Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```.
|
||||||
|
|
||||||
preemphasis (float):
|
preemphasis (float):
|
||||||
Preemphasis coefficient. Defaults to 0.0.
|
Preemphasis coefficient. Defaults to 0.0.
|
||||||
|
|
||||||
ref_level_db (int): 20
|
ref_level_db (int): 20
|
||||||
Reference Db level to rebase the audio signal and ignore the level below. 20Db is assumed the sound of air.
|
Reference Db level to rebase the audio signal and ignore the level below. 20Db is assumed the sound of air.
|
||||||
Defaults to 20.
|
Defaults to 20.
|
||||||
|
|
||||||
do_sound_norm (bool):
|
do_sound_norm (bool):
|
||||||
Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.
|
Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.
|
||||||
|
|
||||||
|
log_func (str):
|
||||||
|
Numpy log function used for amplitude to DB conversion. Defaults to 'np.log10'.
|
||||||
|
|
||||||
do_trim_silence (bool):
|
do_trim_silence (bool):
|
||||||
Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
|
Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
|
||||||
|
|
||||||
do_amp_to_db_linear (bool, optional):
|
do_amp_to_db_linear (bool, optional):
|
||||||
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
|
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
|
||||||
|
|
||||||
do_amp_to_db_mel (bool, optional):
|
do_amp_to_db_mel (bool, optional):
|
||||||
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
|
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
|
||||||
|
|
||||||
trim_db (int):
|
trim_db (int):
|
||||||
Silence threshold used for silence trimming. Defaults to 45.
|
Silence threshold used for silence trimming. Defaults to 45.
|
||||||
|
|
||||||
power (float):
|
power (float):
|
||||||
Exponent used for expanding spectrogra levels before running Griffin Lim. It helps to reduce the
|
Exponent used for expanding spectrogra levels before running Griffin Lim. It helps to reduce the
|
||||||
artifacts in the synthesized voice. Defaults to 1.5.
|
artifacts in the synthesized voice. Defaults to 1.5.
|
||||||
|
|
||||||
griffin_lim_iters (int):
|
griffin_lim_iters (int):
|
||||||
Number of Griffing Lim iterations. Defaults to 60.
|
Number of Griffing Lim iterations. Defaults to 60.
|
||||||
|
|
||||||
num_mels (int):
|
num_mels (int):
|
||||||
Number of mel-basis frames that defines the frame lengths of each mel-spectrogram frame. Defaults to 80.
|
Number of mel-basis frames that defines the frame lengths of each mel-spectrogram frame. Defaults to 80.
|
||||||
|
|
||||||
mel_fmin (float): Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices.
|
mel_fmin (float): Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices.
|
||||||
It needs to be adjusted for a dataset. Defaults to 0.
|
It needs to be adjusted for a dataset. Defaults to 0.
|
||||||
|
|
||||||
mel_fmax (float):
|
mel_fmax (float):
|
||||||
Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset.
|
Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset.
|
||||||
|
|
||||||
spec_gain (int):
|
spec_gain (int):
|
||||||
Gain applied when converting amplitude to DB. Defaults to 20.
|
Gain applied when converting amplitude to DB. Defaults to 20.
|
||||||
|
|
||||||
signal_norm (bool):
|
signal_norm (bool):
|
||||||
enable/disable signal normalization. Defaults to True.
|
enable/disable signal normalization. Defaults to True.
|
||||||
|
|
||||||
min_level_db (int):
|
min_level_db (int):
|
||||||
minimum db threshold for the computed melspectrograms. Defaults to -100.
|
minimum db threshold for the computed melspectrograms. Defaults to -100.
|
||||||
|
|
||||||
symmetric_norm (bool):
|
symmetric_norm (bool):
|
||||||
enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else
|
enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else
|
||||||
[0, k], Defaults to True.
|
[0, k], Defaults to True.
|
||||||
|
|
||||||
max_norm (float):
|
max_norm (float):
|
||||||
```k``` defining the normalization range. Defaults to 4.0.
|
```k``` defining the normalization range. Defaults to 4.0.
|
||||||
|
|
||||||
clip_norm (bool):
|
clip_norm (bool):
|
||||||
enable/disable clipping the our of range values in the normalized audio signal. Defaults to True.
|
enable/disable clipping the our of range values in the normalized audio signal. Defaults to True.
|
||||||
|
|
||||||
stats_path (str):
|
stats_path (str):
|
||||||
Path to the computed stats file. Defaults to None.
|
Path to the computed stats file. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
@ -147,15 +176,20 @@ class BaseDatasetConfig(Coqpit):
|
||||||
Args:
|
Args:
|
||||||
name (str):
|
name (str):
|
||||||
Dataset name that defines the preprocessor in use. Defaults to None.
|
Dataset name that defines the preprocessor in use. Defaults to None.
|
||||||
|
|
||||||
path (str):
|
path (str):
|
||||||
Root path to the dataset files. Defaults to None.
|
Root path to the dataset files. Defaults to None.
|
||||||
|
|
||||||
meta_file_train (str):
|
meta_file_train (str):
|
||||||
Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets.
|
Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
|
|
||||||
unused_speakers (List):
|
unused_speakers (List):
|
||||||
List of speakers IDs that are not used at the training. Default None.
|
List of speakers IDs that are not used at the training. Default None.
|
||||||
|
|
||||||
meta_file_val (str):
|
meta_file_val (str):
|
||||||
Name of the dataset meta file that defines the instances used at validation.
|
Name of the dataset meta file that defines the instances used at validation.
|
||||||
|
|
||||||
meta_file_attn_mask (str):
|
meta_file_attn_mask (str):
|
||||||
Path to the file that lists the attention mask files used with models that require attention masks to
|
Path to the file that lists the attention mask files used with models that require attention masks to
|
||||||
train the duration predictor.
|
train the duration predictor.
|
||||||
|
@ -298,7 +332,7 @@ class BaseTrainingConfig(Coqpit):
|
||||||
keep_all_best: bool = False
|
keep_all_best: bool = False
|
||||||
keep_after: int = 10000
|
keep_after: int = 10000
|
||||||
# dataloading
|
# dataloading
|
||||||
num_loader_workers: int = None
|
num_loader_workers: int = 0
|
||||||
num_eval_loader_workers: int = 0
|
num_eval_loader_workers: int = 0
|
||||||
use_noise_augment: bool = False
|
use_noise_augment: bool = False
|
||||||
# paths
|
# paths
|
||||||
|
|
|
@ -271,6 +271,14 @@ class Trainer:
|
||||||
# setup scheduler
|
# setup scheduler
|
||||||
self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer)
|
self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer)
|
||||||
|
|
||||||
|
if self.scheduler is not None:
|
||||||
|
if self.args.continue_path:
|
||||||
|
if isinstance(self.scheduler, list):
|
||||||
|
for scheduler in self.scheduler:
|
||||||
|
scheduler.last_epoch = self.restore_step
|
||||||
|
else:
|
||||||
|
self.scheduler.last_epoch = self.restore_step
|
||||||
|
|
||||||
# DISTRUBUTED
|
# DISTRUBUTED
|
||||||
if self.num_gpus > 1:
|
if self.num_gpus > 1:
|
||||||
self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)
|
self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)
|
||||||
|
@ -291,7 +299,6 @@ class Trainer:
|
||||||
Returns:
|
Returns:
|
||||||
nn.Module: initialized model.
|
nn.Module: initialized model.
|
||||||
"""
|
"""
|
||||||
# TODO: better model setup
|
|
||||||
try:
|
try:
|
||||||
model = setup_vocoder_model(config)
|
model = setup_vocoder_model(config)
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
|
|
|
@ -0,0 +1,122 @@
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
||||||
|
from TTS.tts.models.fast_pitch import FastPitchArgs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FastPitchConfig(BaseTTSConfig):
|
||||||
|
"""Defines parameters for Speedy Speech (feed-forward encoder-decoder) based models.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
>>> from TTS.tts.configs import FastPitchConfig
|
||||||
|
>>> config = FastPitchConfig()
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (str):
|
||||||
|
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
|
||||||
|
|
||||||
|
model_args (Coqpit):
|
||||||
|
Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`.
|
||||||
|
|
||||||
|
data_dep_init_steps (int):
|
||||||
|
Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
|
||||||
|
Activation Normalization that pre-computes normalization stats at the beginning and use the same values
|
||||||
|
for the rest. Defaults to 10.
|
||||||
|
|
||||||
|
use_speaker_embedding (bool):
|
||||||
|
enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
|
||||||
|
in the multi-speaker mode. Defaults to False.
|
||||||
|
|
||||||
|
use_d_vector_file (bool):
|
||||||
|
enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
|
||||||
|
|
||||||
|
d_vector_file (str):
|
||||||
|
Path to the file including pre-computed speaker embeddings. Defaults to None.
|
||||||
|
|
||||||
|
noam_schedule (bool):
|
||||||
|
enable / disable the use of Noam LR scheduler. Defaults to False.
|
||||||
|
|
||||||
|
warmup_steps (int):
|
||||||
|
Number of warm-up steps for the Noam scheduler. Defaults 4000.
|
||||||
|
|
||||||
|
lr (float):
|
||||||
|
Initial learning rate. Defaults to `1e-3`.
|
||||||
|
|
||||||
|
wd (float):
|
||||||
|
Weight decay coefficient. Defaults to `1e-7`.
|
||||||
|
|
||||||
|
ssim_loss_alpha (float):
|
||||||
|
Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.
|
||||||
|
|
||||||
|
huber_loss_alpha (float):
|
||||||
|
Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0.
|
||||||
|
|
||||||
|
spec_loss_alpha (float):
|
||||||
|
Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0.
|
||||||
|
|
||||||
|
pitch_loss_alpha (float):
|
||||||
|
Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0.
|
||||||
|
|
||||||
|
binary_loss_alpha (float):
|
||||||
|
Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.
|
||||||
|
|
||||||
|
binary_align_loss_start_step (int):
|
||||||
|
Start binary alignment loss after this many steps. Defaults to 20000.
|
||||||
|
|
||||||
|
min_seq_len (int):
|
||||||
|
Minimum input sequence length to be used at training.
|
||||||
|
|
||||||
|
max_seq_len (int):
|
||||||
|
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model: str = "fast_pitch"
|
||||||
|
# model specific params
|
||||||
|
model_args: FastPitchArgs = field(default_factory=FastPitchArgs)
|
||||||
|
|
||||||
|
# multi-speaker settings
|
||||||
|
use_speaker_embedding: bool = False
|
||||||
|
use_d_vector_file: bool = False
|
||||||
|
d_vector_file: str = False
|
||||||
|
d_vector_dim: int = 0
|
||||||
|
|
||||||
|
# optimizer parameters
|
||||||
|
optimizer: str = "Adam"
|
||||||
|
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
|
||||||
|
lr_scheduler: str = "NoamLR"
|
||||||
|
lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
|
||||||
|
lr: float = 1e-4
|
||||||
|
grad_clip: float = 5.0
|
||||||
|
|
||||||
|
# loss params
|
||||||
|
ssim_loss_alpha: float = 1.0
|
||||||
|
dur_loss_alpha: float = 1.0
|
||||||
|
spec_loss_alpha: float = 1.0
|
||||||
|
pitch_loss_alpha: float = 1.0
|
||||||
|
dur_loss_alpha: float = 1.0
|
||||||
|
aligner_loss_alpha: float = 1.0
|
||||||
|
binary_align_loss_alpha: float = 1.0
|
||||||
|
binary_align_loss_start_step: int = 20000
|
||||||
|
|
||||||
|
# overrides
|
||||||
|
min_seq_len: int = 13
|
||||||
|
max_seq_len: int = 200
|
||||||
|
r: int = 1 # DO NOT CHANGE
|
||||||
|
|
||||||
|
# dataset configs
|
||||||
|
compute_f0: bool = True
|
||||||
|
f0_cache_path: str = None
|
||||||
|
|
||||||
|
# testing
|
||||||
|
test_sentences: List[str] = field(
|
||||||
|
default_factory=lambda: [
|
||||||
|
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||||
|
"Be a voice, not an echo.",
|
||||||
|
"I'm sorry Dave. I'm afraid I can't do that.",
|
||||||
|
"This cake is great. It's so delicious and moist.",
|
||||||
|
"Prior to November 22, 1963.",
|
||||||
|
]
|
||||||
|
)
|
|
@ -103,6 +103,7 @@ class BaseTTSConfig(BaseTrainingConfig):
|
||||||
"""Shared parameters among all the tts models.
|
"""Shared parameters among all the tts models.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
||||||
audio (BaseAudioConfig):
|
audio (BaseAudioConfig):
|
||||||
Audio processor config object instance.
|
Audio processor config object instance.
|
||||||
|
|
||||||
|
@ -140,11 +141,14 @@ class BaseTTSConfig(BaseTrainingConfig):
|
||||||
loss_masking (bool):
|
loss_masking (bool):
|
||||||
enable / disable masking loss values against padded segments of samples in a batch.
|
enable / disable masking loss values against padded segments of samples in a batch.
|
||||||
|
|
||||||
|
sort_by_audio_len (bool):
|
||||||
|
If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`.
|
||||||
|
|
||||||
min_seq_len (int):
|
min_seq_len (int):
|
||||||
Minimum input sequence length to be used at training.
|
Minimum sequence length to be used at training.
|
||||||
|
|
||||||
max_seq_len (int):
|
max_seq_len (int):
|
||||||
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
Maximum sequence length to be used at training. Larger values result in more VRAM usage.
|
||||||
|
|
||||||
compute_f0 (int):
|
compute_f0 (int):
|
||||||
(Not in use yet).
|
(Not in use yet).
|
||||||
|
@ -197,6 +201,7 @@ class BaseTTSConfig(BaseTrainingConfig):
|
||||||
batch_group_size: int = 0
|
batch_group_size: int = 0
|
||||||
loss_masking: bool = None
|
loss_masking: bool = None
|
||||||
# dataloading
|
# dataloading
|
||||||
|
sort_by_audio_len: bool = False
|
||||||
min_seq_len: int = 1
|
min_seq_len: int = 1
|
||||||
max_seq_len: int = float("inf")
|
max_seq_len: int = float("inf")
|
||||||
compute_f0: bool = False
|
compute_f0: bool = False
|
||||||
|
|
|
@ -67,11 +67,14 @@ class VitsConfig(BaseTTSConfig):
|
||||||
compute_linear_spec (bool):
|
compute_linear_spec (bool):
|
||||||
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
|
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
|
||||||
|
|
||||||
|
sort_by_audio_len (bool):
|
||||||
|
If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`.
|
||||||
|
|
||||||
min_seq_len (int):
|
min_seq_len (int):
|
||||||
Minimum text length to be considered for training. Defaults to `13`.
|
Minimum sequnce length to be considered for training. Defaults to `0`.
|
||||||
|
|
||||||
max_seq_len (int):
|
max_seq_len (int):
|
||||||
Maximum text length to be considered for training. Defaults to `500`.
|
Maximum sequnce length to be considered for training. Defaults to `500000`.
|
||||||
|
|
||||||
r (int):
|
r (int):
|
||||||
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
|
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
|
||||||
|
|
|
@ -22,6 +22,8 @@ class TTSDataset(Dataset):
|
||||||
compute_linear_spec: bool,
|
compute_linear_spec: bool,
|
||||||
ap: AudioProcessor,
|
ap: AudioProcessor,
|
||||||
meta_data: List[List],
|
meta_data: List[List],
|
||||||
|
compute_f0: bool = False,
|
||||||
|
f0_cache_path: str = None,
|
||||||
characters: Dict = None,
|
characters: Dict = None,
|
||||||
custom_symbols: List = None,
|
custom_symbols: List = None,
|
||||||
add_blank: bool = False,
|
add_blank: bool = False,
|
||||||
|
@ -40,8 +42,7 @@ class TTSDataset(Dataset):
|
||||||
):
|
):
|
||||||
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
|
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
|
||||||
|
|
||||||
If you need something different, you can either override or create a new class as the dataset is
|
If you need something different, you can inherit and override.
|
||||||
initialized by the model.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
outputs_per_step (int): Number of time frames predicted per step.
|
outputs_per_step (int): Number of time frames predicted per step.
|
||||||
|
@ -54,6 +55,10 @@ class TTSDataset(Dataset):
|
||||||
|
|
||||||
meta_data (list): List of dataset instances.
|
meta_data (list): List of dataset instances.
|
||||||
|
|
||||||
|
compute_f0 (bool): compute f0 if True. Defaults to False.
|
||||||
|
|
||||||
|
f0_cache_path (str): Path to store f0 cache. Defaults to None.
|
||||||
|
|
||||||
characters (dict): `dict` of custom text characters used for converting texts to sequences.
|
characters (dict): `dict` of custom text characters used for converting texts to sequences.
|
||||||
|
|
||||||
custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
|
custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
|
||||||
|
@ -78,8 +83,8 @@ class TTSDataset(Dataset):
|
||||||
|
|
||||||
use_phonemes (bool): If true, input text converted to phonemes. Defaults to false.
|
use_phonemes (bool): If true, input text converted to phonemes. Defaults to false.
|
||||||
|
|
||||||
phoneme_cache_path (str): Path to cache phoneme features. It writes computed phonemes to files to use in
|
phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
|
||||||
the coming iterations. Defaults to None.
|
separate file. Defaults to None.
|
||||||
|
|
||||||
phoneme_language (str): One the languages from supported by the phonemizer interface. Defaults to `en-us`.
|
phoneme_language (str): One the languages from supported by the phonemizer interface. Defaults to `en-us`.
|
||||||
|
|
||||||
|
@ -103,6 +108,8 @@ class TTSDataset(Dataset):
|
||||||
self.cleaners = text_cleaner
|
self.cleaners = text_cleaner
|
||||||
self.compute_linear_spec = compute_linear_spec
|
self.compute_linear_spec = compute_linear_spec
|
||||||
self.return_wav = return_wav
|
self.return_wav = return_wav
|
||||||
|
self.compute_f0 = compute_f0
|
||||||
|
self.f0_cache_path = f0_cache_path
|
||||||
self.min_seq_len = min_seq_len
|
self.min_seq_len = min_seq_len
|
||||||
self.max_seq_len = max_seq_len
|
self.max_seq_len = max_seq_len
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
|
@ -119,8 +126,12 @@ class TTSDataset(Dataset):
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.input_seq_computed = False
|
self.input_seq_computed = False
|
||||||
self.rescue_item_idx = 1
|
self.rescue_item_idx = 1
|
||||||
|
self.pitch_computed = False
|
||||||
|
|
||||||
if use_phonemes and not os.path.isdir(phoneme_cache_path):
|
if use_phonemes and not os.path.isdir(phoneme_cache_path):
|
||||||
os.makedirs(phoneme_cache_path, exist_ok=True)
|
os.makedirs(phoneme_cache_path, exist_ok=True)
|
||||||
|
if compute_f0:
|
||||||
|
self.pitch_extractor = PitchExtractor(self.items, verbose=verbose)
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print("\n > DataLoader initialization")
|
print("\n > DataLoader initialization")
|
||||||
print(" | > Use phonemes: {}".format(self.use_phonemes))
|
print(" | > Use phonemes: {}".format(self.use_phonemes))
|
||||||
|
@ -236,10 +247,16 @@ class TTSDataset(Dataset):
|
||||||
# TODO: find a better fix
|
# TODO: find a better fix
|
||||||
return self.load_data(self.rescue_item_idx)
|
return self.load_data(self.rescue_item_idx)
|
||||||
|
|
||||||
|
pitch = None
|
||||||
|
if self.compute_f0:
|
||||||
|
pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
|
||||||
|
pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))
|
||||||
|
|
||||||
sample = {
|
sample = {
|
||||||
"raw_text": raw_text,
|
"raw_text": raw_text,
|
||||||
"text": text,
|
"text": text,
|
||||||
"wav": wav,
|
"wav": wav,
|
||||||
|
"pitch": pitch,
|
||||||
"attn": attn,
|
"attn": attn,
|
||||||
"item_idx": self.items[idx][1],
|
"item_idx": self.items[idx][1],
|
||||||
"speaker_name": speaker_name,
|
"speaker_name": speaker_name,
|
||||||
|
@ -256,8 +273,8 @@ class TTSDataset(Dataset):
|
||||||
return phonemes
|
return phonemes
|
||||||
|
|
||||||
def compute_input_seq(self, num_workers=0):
|
def compute_input_seq(self, num_workers=0):
|
||||||
"""compute input sequences separately. Call it before
|
"""Compute the input sequences with multi-processing.
|
||||||
passing dataset to data loader."""
|
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
|
||||||
if not self.use_phonemes:
|
if not self.use_phonemes:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(" | > Computing input sequences ...")
|
print(" | > Computing input sequences ...")
|
||||||
|
@ -339,6 +356,11 @@ class TTSDataset(Dataset):
|
||||||
temp_items = new_items[offset:end_offset]
|
temp_items = new_items[offset:end_offset]
|
||||||
random.shuffle(temp_items)
|
random.shuffle(temp_items)
|
||||||
new_items[offset:end_offset] = temp_items
|
new_items[offset:end_offset] = temp_items
|
||||||
|
|
||||||
|
if len(new_items) == 0:
|
||||||
|
raise RuntimeError(" [!] No items left after filtering.")
|
||||||
|
|
||||||
|
# update items to the new sorted items
|
||||||
self.items = new_items
|
self.items = new_items
|
||||||
|
|
||||||
# logging
|
# logging
|
||||||
|
@ -359,11 +381,23 @@ class TTSDataset(Dataset):
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
return self.load_data(idx)
|
return self.load_data(idx)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sort_batch(batch, text_lengths):
|
||||||
|
"""Sort the batch by the input text length for RNN efficiency.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (Dict): Batch returned by `__getitem__`.
|
||||||
|
text_lengths (List[int]): Lengths of the input character sequences.
|
||||||
|
"""
|
||||||
|
text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
|
||||||
|
batch = [batch[idx] for idx in ids_sorted_decreasing]
|
||||||
|
return batch, text_lengths, ids_sorted_decreasing
|
||||||
|
|
||||||
def collate_fn(self, batch):
|
def collate_fn(self, batch):
|
||||||
r"""
|
r"""
|
||||||
Perform preprocessing and create a final data batch:
|
Perform preprocessing and create a final data batch:
|
||||||
1. Sort batch instances by text-length
|
1. Sort batch instances by text-length
|
||||||
2. Convert Audio signal to Spectrograms.
|
2. Convert Audio signal to features.
|
||||||
3. PAD sequences wrt r.
|
3. PAD sequences wrt r.
|
||||||
4. Load to Torch.
|
4. Load to Torch.
|
||||||
"""
|
"""
|
||||||
|
@ -371,30 +405,27 @@ class TTSDataset(Dataset):
|
||||||
# Puts each data field into a tensor with outer dimension batch size
|
# Puts each data field into a tensor with outer dimension batch size
|
||||||
if isinstance(batch[0], collections.abc.Mapping):
|
if isinstance(batch[0], collections.abc.Mapping):
|
||||||
|
|
||||||
text_lenghts = np.array([len(d["text"]) for d in batch])
|
text_lengths = np.array([len(d["text"]) for d in batch])
|
||||||
|
|
||||||
# sort items with text input length for RNN efficiency
|
# sort items with text input length for RNN efficiency
|
||||||
text_lenghts, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lenghts), dim=0, descending=True)
|
batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths)
|
||||||
|
|
||||||
wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
|
# convert list of dicts to dict of lists
|
||||||
item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
|
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
|
||||||
text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
|
|
||||||
raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing]
|
|
||||||
|
|
||||||
speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
|
|
||||||
# get pre-computed d-vectors
|
# get pre-computed d-vectors
|
||||||
if self.d_vector_mapping is not None:
|
if self.d_vector_mapping is not None:
|
||||||
wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing]
|
wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing]
|
||||||
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
|
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
|
||||||
else:
|
else:
|
||||||
d_vectors = None
|
d_vectors = None
|
||||||
# get numerical speaker ids from speaker names
|
# get numerical speaker ids from speaker names
|
||||||
if self.speaker_id_mapping:
|
if self.speaker_id_mapping:
|
||||||
speaker_ids = [self.speaker_id_mapping[sn] for sn in speaker_names]
|
speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
|
||||||
else:
|
else:
|
||||||
speaker_ids = None
|
speaker_ids = None
|
||||||
# compute features
|
# compute features
|
||||||
mel = [self.ap.melspectrogram(w).astype("float32") for w in wav]
|
mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]
|
||||||
|
|
||||||
mel_lengths = [m.shape[1] for m in mel]
|
mel_lengths = [m.shape[1] for m in mel]
|
||||||
|
|
||||||
|
@ -413,7 +444,7 @@ class TTSDataset(Dataset):
|
||||||
stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
|
stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
|
||||||
|
|
||||||
# PAD sequences with longest instance in the batch
|
# PAD sequences with longest instance in the batch
|
||||||
text = prepare_data(text).astype(np.int32)
|
text = prepare_data(batch["text"]).astype(np.int32)
|
||||||
|
|
||||||
# PAD features with longest instance
|
# PAD features with longest instance
|
||||||
mel = prepare_tensor(mel, self.outputs_per_step)
|
mel = prepare_tensor(mel, self.outputs_per_step)
|
||||||
|
@ -422,7 +453,7 @@ class TTSDataset(Dataset):
|
||||||
mel = mel.transpose(0, 2, 1)
|
mel = mel.transpose(0, 2, 1)
|
||||||
|
|
||||||
# convert things to pytorch
|
# convert things to pytorch
|
||||||
text_lenghts = torch.LongTensor(text_lenghts)
|
text_lengths = torch.LongTensor(text_lengths)
|
||||||
text = torch.LongTensor(text)
|
text = torch.LongTensor(text)
|
||||||
mel = torch.FloatTensor(mel).contiguous()
|
mel = torch.FloatTensor(mel).contiguous()
|
||||||
mel_lengths = torch.LongTensor(mel_lengths)
|
mel_lengths = torch.LongTensor(mel_lengths)
|
||||||
|
@ -436,7 +467,7 @@ class TTSDataset(Dataset):
|
||||||
|
|
||||||
# compute linear spectrogram
|
# compute linear spectrogram
|
||||||
if self.compute_linear_spec:
|
if self.compute_linear_spec:
|
||||||
linear = [self.ap.spectrogram(w).astype("float32") for w in wav]
|
linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
|
||||||
linear = prepare_tensor(linear, self.outputs_per_step)
|
linear = prepare_tensor(linear, self.outputs_per_step)
|
||||||
linear = linear.transpose(0, 2, 1)
|
linear = linear.transpose(0, 2, 1)
|
||||||
assert mel.shape[1] == linear.shape[1]
|
assert mel.shape[1] == linear.shape[1]
|
||||||
|
@ -447,23 +478,33 @@ class TTSDataset(Dataset):
|
||||||
# format waveforms
|
# format waveforms
|
||||||
wav_padded = None
|
wav_padded = None
|
||||||
if self.return_wav:
|
if self.return_wav:
|
||||||
wav_lengths = [w.shape[0] for w in wav]
|
wav_lengths = [w.shape[0] for w in batch["wav"]]
|
||||||
max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
|
max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
|
||||||
wav_lengths = torch.LongTensor(wav_lengths)
|
wav_lengths = torch.LongTensor(wav_lengths)
|
||||||
wav_padded = torch.zeros(len(batch), 1, max_wav_len)
|
wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
|
||||||
for i, w in enumerate(wav):
|
for i, w in enumerate(batch["wav"]):
|
||||||
mel_length = mel_lengths_adjusted[i]
|
mel_length = mel_lengths_adjusted[i]
|
||||||
w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
|
w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
|
||||||
w = w[: mel_length * self.ap.hop_length]
|
w = w[: mel_length * self.ap.hop_length]
|
||||||
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
|
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
|
||||||
wav_padded.transpose_(1, 2)
|
wav_padded.transpose_(1, 2)
|
||||||
|
|
||||||
|
# compute f0
|
||||||
|
# TODO: compare perf in collate_fn vs in load_data
|
||||||
|
if self.compute_f0:
|
||||||
|
pitch = prepare_data(batch["pitch"])
|
||||||
|
assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
|
||||||
|
pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT
|
||||||
|
else:
|
||||||
|
pitch = None
|
||||||
|
|
||||||
# collate attention alignments
|
# collate attention alignments
|
||||||
if batch[0]["attn"] is not None:
|
if batch["attn"][0] is not None:
|
||||||
attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
|
attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
|
||||||
for idx, attn in enumerate(attns):
|
for idx, attn in enumerate(attns):
|
||||||
pad2 = mel.shape[1] - attn.shape[1]
|
pad2 = mel.shape[1] - attn.shape[1]
|
||||||
pad1 = text.shape[1] - attn.shape[0]
|
pad1 = text.shape[1] - attn.shape[0]
|
||||||
|
assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
|
||||||
attn = np.pad(attn, [[0, pad1], [0, pad2]])
|
attn = np.pad(attn, [[0, pad1], [0, pad2]])
|
||||||
attns[idx] = attn
|
attns[idx] = attn
|
||||||
attns = prepare_tensor(attns, self.outputs_per_step)
|
attns = prepare_tensor(attns, self.outputs_per_step)
|
||||||
|
@ -471,21 +512,22 @@ class TTSDataset(Dataset):
|
||||||
else:
|
else:
|
||||||
attns = None
|
attns = None
|
||||||
# TODO: return dictionary
|
# TODO: return dictionary
|
||||||
return (
|
return {
|
||||||
text,
|
"text": text,
|
||||||
text_lenghts,
|
"text_lengths": text_lengths,
|
||||||
speaker_names,
|
"speaker_names": batch["speaker_name"],
|
||||||
linear,
|
"linear": linear,
|
||||||
mel,
|
"mel": mel,
|
||||||
mel_lengths,
|
"mel_lengths": mel_lengths,
|
||||||
stop_targets,
|
"stop_targets": stop_targets,
|
||||||
item_idxs,
|
"item_idxs": batch["item_idx"],
|
||||||
d_vectors,
|
"d_vectors": d_vectors,
|
||||||
speaker_ids,
|
"speaker_ids": speaker_ids,
|
||||||
attns,
|
"attns": attns,
|
||||||
wav_padded,
|
"waveform": wav_padded,
|
||||||
raw_text,
|
"raw_text": batch["raw_text"],
|
||||||
)
|
"pitch": pitch,
|
||||||
|
}
|
||||||
|
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
(
|
(
|
||||||
|
@ -495,3 +537,110 @@ class TTSDataset(Dataset):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PitchExtractor:
|
||||||
|
"""Pitch Extractor for computing F0 from wav files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
items (List[List]): Dataset samples.
|
||||||
|
verbose (bool): Whether to print the progress.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
items: List[List],
|
||||||
|
verbose=False,
|
||||||
|
):
|
||||||
|
self.items = items
|
||||||
|
self.verbose = verbose
|
||||||
|
self.mean = None
|
||||||
|
self.std = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_pitch_file_path(wav_file, cache_path):
|
||||||
|
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
||||||
|
pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
|
||||||
|
return pitch_file
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
|
||||||
|
wav = ap.load_wav(wav_file)
|
||||||
|
pitch = ap.compute_f0(wav)
|
||||||
|
if pitch_file:
|
||||||
|
np.save(pitch_file, pitch)
|
||||||
|
return pitch
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_pitch_stats(pitch_vecs):
|
||||||
|
nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs])
|
||||||
|
mean, std = np.mean(nonzeros), np.std(nonzeros)
|
||||||
|
return mean, std
|
||||||
|
|
||||||
|
def normalize_pitch(self, pitch):
|
||||||
|
zero_idxs = np.where(pitch == 0.0)[0]
|
||||||
|
pitch = pitch - self.mean
|
||||||
|
pitch = pitch / self.std
|
||||||
|
pitch[zero_idxs] = 0.0
|
||||||
|
return pitch
|
||||||
|
|
||||||
|
def denormalize_pitch(self, pitch):
|
||||||
|
zero_idxs = np.where(pitch == 0.0)[0]
|
||||||
|
pitch *= self.std
|
||||||
|
pitch += self.mean
|
||||||
|
pitch[zero_idxs] = 0.0
|
||||||
|
return pitch
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_or_compute_pitch(ap, wav_file, cache_path):
|
||||||
|
"""
|
||||||
|
compute pitch and return a numpy array of pitch values
|
||||||
|
"""
|
||||||
|
pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
|
||||||
|
if not os.path.exists(pitch_file):
|
||||||
|
pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
|
||||||
|
else:
|
||||||
|
pitch = np.load(pitch_file)
|
||||||
|
return pitch.astype(np.float32)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _pitch_worker(args):
|
||||||
|
item = args[0]
|
||||||
|
ap = args[1]
|
||||||
|
cache_path = args[2]
|
||||||
|
_, wav_file, *_ = item
|
||||||
|
pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
|
||||||
|
if not os.path.exists(pitch_file):
|
||||||
|
pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
|
||||||
|
return pitch
|
||||||
|
return None
|
||||||
|
|
||||||
|
def compute_pitch(self, ap, cache_path, num_workers=0):
|
||||||
|
"""Compute the input sequences with multi-processing.
|
||||||
|
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
|
||||||
|
if not os.path.exists(cache_path):
|
||||||
|
os.makedirs(cache_path, exist_ok=True)
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
print(" | > Computing pitch features ...")
|
||||||
|
if num_workers == 0:
|
||||||
|
pitch_vecs = []
|
||||||
|
for _, item in enumerate(tqdm.tqdm(self.items)):
|
||||||
|
pitch_vecs += [self._pitch_worker([item, ap, cache_path])]
|
||||||
|
else:
|
||||||
|
with Pool(num_workers) as p:
|
||||||
|
pitch_vecs = list(
|
||||||
|
tqdm.tqdm(
|
||||||
|
p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]),
|
||||||
|
total=len(self.items),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs)
|
||||||
|
pitch_stats = {"mean": pitch_mean, "std": pitch_std}
|
||||||
|
np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
|
||||||
|
|
||||||
|
def load_pitch_stats(self, cache_path):
|
||||||
|
stats_path = os.path.join(cache_path, "pitch_stats.npy")
|
||||||
|
stats = np.load(stats_path, allow_pickle=True).item()
|
||||||
|
self.mean = stats["mean"].astype(np.float32)
|
||||||
|
self.std = stats["std"].astype(np.float32)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import sys
|
import sys
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -30,7 +31,17 @@ def split_dataset(items):
|
||||||
return items[:eval_split_size], items[eval_split_size:]
|
return items[:eval_split_size], items[eval_split_size:]
|
||||||
|
|
||||||
|
|
||||||
def load_meta_data(datasets, eval_split=True):
|
def load_meta_data(datasets: List[Dict], eval_split=True) -> Tuple[List[List], List[List]]:
|
||||||
|
"""Parse the dataset, load the samples as a list and load the attention alignments if provided.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
datasets (List[Dict]): A list of dataset dictionaries or dataset configs.
|
||||||
|
eval_split (bool, optional): If true, create a evaluation split. If an eval split provided explicitly, generate
|
||||||
|
an eval split automatically. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[List[List], List[List]: training and evaluation splits of the dataset.
|
||||||
|
"""
|
||||||
meta_data_train_all = []
|
meta_data_train_all = []
|
||||||
meta_data_eval_all = [] if eval_split else None
|
meta_data_eval_all = [] if eval_split else None
|
||||||
for dataset in datasets:
|
for dataset in datasets:
|
||||||
|
@ -51,7 +62,7 @@ def load_meta_data(datasets, eval_split=True):
|
||||||
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
|
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
|
||||||
meta_data_eval_all += meta_data_eval
|
meta_data_eval_all += meta_data_eval
|
||||||
meta_data_train_all += meta_data_train
|
meta_data_train_all += meta_data_train
|
||||||
# load attention masks for duration predictor training
|
# load attention masks for the duration predictor training
|
||||||
if dataset.meta_file_attn_mask:
|
if dataset.meta_file_attn_mask:
|
||||||
meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
|
meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
|
||||||
for idx, ins in enumerate(meta_data_train_all):
|
for idx, ins in enumerate(meta_data_train_all):
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class AlignmentNetwork(torch.nn.Module):
|
||||||
|
"""Aligner Network for learning alignment between the input text and the model output with Gaussian Attention.
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
query -> conv1d -> relu -> conv1d -> relu -> conv1d -> L2_dist -> softmax -> alignment
|
||||||
|
key -> conv1d -> relu -> conv1d -----------------------^
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_query_channels (int): Number of channels in the query network. Defaults to 80.
|
||||||
|
in_key_channels (int): Number of channels in the key network. Defaults to 512.
|
||||||
|
attn_channels (int): Number of inner channels in the attention layers. Defaults to 80.
|
||||||
|
temperature (float): Temperature for the softmax. Defaults to 0.0005.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_query_channels=80,
|
||||||
|
in_key_channels=512,
|
||||||
|
attn_channels=80,
|
||||||
|
temperature=0.0005,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.temperature = temperature
|
||||||
|
self.softmax = torch.nn.Softmax(dim=3)
|
||||||
|
self.log_softmax = torch.nn.LogSoftmax(dim=3)
|
||||||
|
|
||||||
|
self.key_layer = nn.Sequential(
|
||||||
|
nn.Conv1d(
|
||||||
|
in_key_channels,
|
||||||
|
in_key_channels * 2,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
bias=True,
|
||||||
|
),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.query_layer = nn.Sequential(
|
||||||
|
nn.Conv1d(
|
||||||
|
in_query_channels,
|
||||||
|
in_query_channels * 2,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
bias=True,
|
||||||
|
),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None
|
||||||
|
) -> Tuple[torch.tensor, torch.tensor]:
|
||||||
|
"""Forward pass of the aligner encoder.
|
||||||
|
Shapes:
|
||||||
|
- queries: :math:`[B, C, T_de]`
|
||||||
|
- keys: :math:`[B, C_emb, T_en]`
|
||||||
|
- mask: :math:`[B, T_de]`
|
||||||
|
Output:
|
||||||
|
attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask.
|
||||||
|
attn_logp (torch.tensor): :math:`[ßB, 1, T_en , T_de]` log probabilities.
|
||||||
|
"""
|
||||||
|
key_out = self.key_layer(keys)
|
||||||
|
query_out = self.query_layer(queries)
|
||||||
|
attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2
|
||||||
|
attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True)
|
||||||
|
if attn_prior is not None:
|
||||||
|
attn_logp = self.log_softmax(attn_logp) + torch.log(attn_prior[:, None] + 1e-8)
|
||||||
|
if mask is not None:
|
||||||
|
attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
|
||||||
|
attn = self.softmax(attn_logp)
|
||||||
|
return attn, attn_logp
|
|
@ -7,17 +7,23 @@ from torch import nn
|
||||||
class PositionalEncoding(nn.Module):
|
class PositionalEncoding(nn.Module):
|
||||||
"""Sinusoidal positional encoding for non-recurrent neural networks.
|
"""Sinusoidal positional encoding for non-recurrent neural networks.
|
||||||
Implementation based on "Attention Is All You Need"
|
Implementation based on "Attention Is All You Need"
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
channels (int): embedding size
|
channels (int): embedding size
|
||||||
dropout (float): dropout parameter
|
dropout_p (float): dropout rate applied to the output.
|
||||||
|
max_len (int): maximum sequence length.
|
||||||
|
use_scale (bool): whether to use a learnable scaling coefficient.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, channels, dropout_p=0.0, max_len=5000):
|
def __init__(self, channels, dropout_p=0.0, max_len=5000, use_scale=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if channels % 2 != 0:
|
if channels % 2 != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels)
|
"Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels)
|
||||||
)
|
)
|
||||||
|
self.use_scale = use_scale
|
||||||
|
if use_scale:
|
||||||
|
self.scale = torch.nn.Parameter(torch.ones(1))
|
||||||
pe = torch.zeros(max_len, channels)
|
pe = torch.zeros(max_len, channels)
|
||||||
position = torch.arange(0, max_len).unsqueeze(1)
|
position = torch.arange(0, max_len).unsqueeze(1)
|
||||||
div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels)
|
div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels)
|
||||||
|
@ -49,9 +55,15 @@ class PositionalEncoding(nn.Module):
|
||||||
pos_enc = self.pe[:, :, : x.size(2)] * mask
|
pos_enc = self.pe[:, :, : x.size(2)] * mask
|
||||||
else:
|
else:
|
||||||
pos_enc = self.pe[:, :, : x.size(2)]
|
pos_enc = self.pe[:, :, : x.size(2)]
|
||||||
x = x + pos_enc
|
if self.use_scale:
|
||||||
|
x = x + self.scale * pos_enc
|
||||||
|
else:
|
||||||
|
x = x + pos_enc
|
||||||
else:
|
else:
|
||||||
x = x + self.pe[:, :, first_idx:last_idx]
|
if self.use_scale:
|
||||||
|
x = x + self.scale * self.pe[:, :, first_idx:last_idx]
|
||||||
|
else:
|
||||||
|
x = x + self.pe[:, :, first_idx:last_idx]
|
||||||
if hasattr(self, "dropout"):
|
if hasattr(self, "dropout"):
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
return x
|
return x
|
||||||
|
|
|
@ -15,17 +15,19 @@ class FFTransformer(nn.Module):
|
||||||
self.norm1 = nn.LayerNorm(in_out_channels)
|
self.norm1 = nn.LayerNorm(in_out_channels)
|
||||||
self.norm2 = nn.LayerNorm(in_out_channels)
|
self.norm2 = nn.LayerNorm(in_out_channels)
|
||||||
|
|
||||||
self.dropout = nn.Dropout(dropout_p)
|
self.dropout1 = nn.Dropout(dropout_p)
|
||||||
|
self.dropout2 = nn.Dropout(dropout_p)
|
||||||
|
|
||||||
def forward(self, src, src_mask=None, src_key_padding_mask=None):
|
def forward(self, src, src_mask=None, src_key_padding_mask=None):
|
||||||
"""😦 ugly looking with all the transposing"""
|
"""😦 ugly looking with all the transposing"""
|
||||||
src = src.permute(2, 0, 1)
|
src = src.permute(2, 0, 1)
|
||||||
src2, enc_align = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
|
src2, enc_align = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
|
||||||
|
src = src + self.dropout1(src2)
|
||||||
src = self.norm1(src + src2)
|
src = self.norm1(src + src2)
|
||||||
# T x B x D -> B x D x T
|
# T x B x D -> B x D x T
|
||||||
src = src.permute(1, 2, 0)
|
src = src.permute(1, 2, 0)
|
||||||
src2 = self.conv2(F.relu(self.conv1(src)))
|
src2 = self.conv2(F.relu(self.conv1(src)))
|
||||||
src2 = self.dropout(src2)
|
src2 = self.dropout2(src2)
|
||||||
src = src + src2
|
src = src + src2
|
||||||
src = src.transpose(1, 2)
|
src = src.transpose(1, 2)
|
||||||
src = self.norm2(src)
|
src = self.norm2(src)
|
||||||
|
@ -52,8 +54,8 @@ class FFTransformerBlock(nn.Module):
|
||||||
"""
|
"""
|
||||||
TODO: handle multi-speaker
|
TODO: handle multi-speaker
|
||||||
Shapes:
|
Shapes:
|
||||||
x: [B, C, T]
|
- x: :math:`[B, C, T]`
|
||||||
mask: [B, 1, T] or [B, T]
|
- mask: :math:`[B, 1, T] or [B, T]`
|
||||||
"""
|
"""
|
||||||
if mask is not None and mask.ndim == 3:
|
if mask is not None and mask.ndim == 3:
|
||||||
mask = mask.squeeze(1)
|
mask = mask.squeeze(1)
|
||||||
|
@ -65,3 +67,21 @@ class FFTransformerBlock(nn.Module):
|
||||||
alignments.append(align.unsqueeze(1))
|
alignments.append(align.unsqueeze(1))
|
||||||
alignments = torch.cat(alignments, 1)
|
alignments = torch.cat(alignments, 1)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FFTDurationPredictor:
|
||||||
|
def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None): # pylint: disable=unused-argument
|
||||||
|
self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p)
|
||||||
|
self.proj = nn.Linear(in_channels, 1)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument
|
||||||
|
"""
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`[B, C, T]`
|
||||||
|
- mask: :math:`[B, 1, T]`
|
||||||
|
|
||||||
|
TODO: Handle the cond input
|
||||||
|
"""
|
||||||
|
x = self.fft(x, mask=mask)
|
||||||
|
x = self.proj(x)
|
||||||
|
return x
|
||||||
|
|
|
@ -21,8 +21,10 @@ def convert_pad_shape(pad_shape):
|
||||||
|
|
||||||
def generate_path(duration, mask):
|
def generate_path(duration, mask):
|
||||||
"""
|
"""
|
||||||
duration: [b, t_x]
|
Shapes:
|
||||||
mask: [b, t_x, t_y]
|
- duration: :math:`[B, T_en]`
|
||||||
|
- mask: :math:'[B, T_en, T_de]`
|
||||||
|
- path: :math:`[B, T_en, T_de]`
|
||||||
"""
|
"""
|
||||||
device = duration.device
|
device = duration.device
|
||||||
b, t_x, t_y = mask.shape
|
b, t_x, t_y = mask.shape
|
||||||
|
@ -45,8 +47,9 @@ def maximum_path(value, mask):
|
||||||
|
|
||||||
def maximum_path_cython(value, mask):
|
def maximum_path_cython(value, mask):
|
||||||
"""Cython optimised version.
|
"""Cython optimised version.
|
||||||
value: [b, t_x, t_y]
|
Shapes:
|
||||||
mask: [b, t_x, t_y]
|
- value: :math:`[B, T_en, T_de]`
|
||||||
|
- mask: :math:`[B, T_en, T_de]`
|
||||||
"""
|
"""
|
||||||
value = value * mask
|
value = value * mask
|
||||||
device = value.device
|
device = value.device
|
||||||
|
|
|
@ -69,9 +69,9 @@ class MSELossMasked(nn.Module):
|
||||||
length: A Variable containing a LongTensor of size (batch,)
|
length: A Variable containing a LongTensor of size (batch,)
|
||||||
which contains the length of each data in a batch.
|
which contains the length of each data in a batch.
|
||||||
Shapes:
|
Shapes:
|
||||||
x: B x T X D
|
- x: :math:`[B, T, D]`
|
||||||
target: B x T x D
|
- target: :math:`[B, T, D]`
|
||||||
length: B
|
- length: :math:`B`
|
||||||
Returns:
|
Returns:
|
||||||
loss: An average loss value in range [0, 1] masked by the length.
|
loss: An average loss value in range [0, 1] masked by the length.
|
||||||
"""
|
"""
|
||||||
|
@ -658,3 +658,109 @@ class VitsDiscriminatorLoss(nn.Module):
|
||||||
loss = loss + return_dict["loss_disc"]
|
loss = loss + return_dict["loss_disc"]
|
||||||
return_dict["loss"] = loss
|
return_dict["loss"] = loss
|
||||||
return return_dict
|
return return_dict
|
||||||
|
|
||||||
|
|
||||||
|
class ForwardSumLoss(nn.Module):
|
||||||
|
def __init__(self, blank_logprob=-1):
|
||||||
|
super().__init__()
|
||||||
|
self.log_softmax = torch.nn.LogSoftmax(dim=3)
|
||||||
|
self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
|
||||||
|
self.blank_logprob = blank_logprob
|
||||||
|
|
||||||
|
def forward(self, attn_logprob, in_lens, out_lens):
|
||||||
|
key_lens = in_lens
|
||||||
|
query_lens = out_lens
|
||||||
|
attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob)
|
||||||
|
|
||||||
|
total_loss = 0.0
|
||||||
|
for bid in range(attn_logprob.shape[0]):
|
||||||
|
target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
|
||||||
|
curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1]
|
||||||
|
|
||||||
|
curr_logprob = self.log_softmax(curr_logprob[None])[0]
|
||||||
|
loss = self.ctc_loss(
|
||||||
|
curr_logprob,
|
||||||
|
target_seq,
|
||||||
|
input_lengths=query_lens[bid : bid + 1],
|
||||||
|
target_lengths=key_lens[bid : bid + 1],
|
||||||
|
)
|
||||||
|
total_loss = total_loss + loss
|
||||||
|
|
||||||
|
total_loss = total_loss / attn_logprob.shape[0]
|
||||||
|
return total_loss
|
||||||
|
|
||||||
|
|
||||||
|
class FastPitchLoss(nn.Module):
|
||||||
|
def __init__(self, c):
|
||||||
|
super().__init__()
|
||||||
|
self.spec_loss = MSELossMasked(False)
|
||||||
|
self.ssim = SSIMLoss()
|
||||||
|
self.dur_loss = MSELossMasked(False)
|
||||||
|
self.pitch_loss = MSELossMasked(False)
|
||||||
|
if c.model_args.use_aligner:
|
||||||
|
self.aligner_loss = ForwardSumLoss()
|
||||||
|
|
||||||
|
self.spec_loss_alpha = c.spec_loss_alpha
|
||||||
|
self.ssim_loss_alpha = c.ssim_loss_alpha
|
||||||
|
self.dur_loss_alpha = c.dur_loss_alpha
|
||||||
|
self.pitch_loss_alpha = c.pitch_loss_alpha
|
||||||
|
self.aligner_loss_alpha = c.aligner_loss_alpha
|
||||||
|
self.binary_alignment_loss_alpha = c.binary_align_loss_alpha
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _binary_alignment_loss(alignment_hard, alignment_soft):
|
||||||
|
"""Binary loss that forces soft alignments to match the hard alignments as
|
||||||
|
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
|
||||||
|
"""
|
||||||
|
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
|
||||||
|
return -log_sum / alignment_hard.sum()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
decoder_output,
|
||||||
|
decoder_target,
|
||||||
|
decoder_output_lens,
|
||||||
|
dur_output,
|
||||||
|
dur_target,
|
||||||
|
pitch_output,
|
||||||
|
pitch_target,
|
||||||
|
input_lens,
|
||||||
|
alignment_logprob=None,
|
||||||
|
alignment_hard=None,
|
||||||
|
alignment_soft=None,
|
||||||
|
):
|
||||||
|
loss = 0
|
||||||
|
return_dict = {}
|
||||||
|
if self.ssim_loss_alpha > 0:
|
||||||
|
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
|
||||||
|
loss = loss + self.ssim_loss_alpha * ssim_loss
|
||||||
|
return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss
|
||||||
|
|
||||||
|
if self.spec_loss_alpha > 0:
|
||||||
|
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
|
||||||
|
loss = loss + self.spec_loss_alpha * spec_loss
|
||||||
|
return_dict["loss_spec"] = self.spec_loss_alpha * spec_loss
|
||||||
|
|
||||||
|
if self.dur_loss_alpha > 0:
|
||||||
|
log_dur_tgt = torch.log(dur_target.float() + 1)
|
||||||
|
dur_loss = self.dur_loss(dur_output[:, :, None], log_dur_tgt[:, :, None], input_lens)
|
||||||
|
loss = loss + self.dur_loss_alpha * dur_loss
|
||||||
|
return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss
|
||||||
|
|
||||||
|
if self.pitch_loss_alpha > 0:
|
||||||
|
pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens)
|
||||||
|
loss = loss + self.pitch_loss_alpha * pitch_loss
|
||||||
|
return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss
|
||||||
|
|
||||||
|
if self.aligner_loss_alpha > 0:
|
||||||
|
aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
|
||||||
|
loss = loss + self.aligner_loss_alpha * aligner_loss
|
||||||
|
return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
|
||||||
|
|
||||||
|
if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
|
||||||
|
binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
|
||||||
|
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
|
||||||
|
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
|
||||||
|
|
||||||
|
return_dict["loss"] = loss
|
||||||
|
return return_dict
|
||||||
|
|
|
@ -13,7 +13,6 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
|
||||||
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
|
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
|
||||||
from TTS.tts.models.base_tts import BaseTTS
|
from TTS.tts.models.base_tts import BaseTTS
|
||||||
from TTS.tts.utils.data import sequence_mask
|
from TTS.tts.utils.data import sequence_mask
|
||||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
|
||||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec
|
||||||
|
@ -355,9 +354,6 @@ class AlignTTS(BaseTTS):
|
||||||
phase=self.phase,
|
phase=self.phase,
|
||||||
)
|
)
|
||||||
|
|
||||||
# compute alignment error (the lower the better )
|
|
||||||
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
|
|
||||||
loss_dict["align_error"] = align_error
|
|
||||||
return outputs, loss_dict
|
return outputs, loss_dict
|
||||||
|
|
||||||
def train_log(
|
def train_log(
|
||||||
|
|
|
@ -78,7 +78,9 @@ class BaseTacotron(BaseTTS):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _format_aux_input(aux_input: Dict) -> Dict:
|
def _format_aux_input(aux_input: Dict) -> Dict:
|
||||||
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
|
if aux_input:
|
||||||
|
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
|
||||||
|
return None
|
||||||
|
|
||||||
#############################
|
#############################
|
||||||
# INIT FUNCTIONS
|
# INIT FUNCTIONS
|
||||||
|
|
|
@ -104,18 +104,19 @@ class BaseTTS(BaseModel):
|
||||||
Dict: [description]
|
Dict: [description]
|
||||||
"""
|
"""
|
||||||
# setup input batch
|
# setup input batch
|
||||||
text_input = batch[0]
|
text_input = batch["text"]
|
||||||
text_lengths = batch[1]
|
text_lengths = batch["text_lengths"]
|
||||||
speaker_names = batch[2]
|
speaker_names = batch["speaker_names"]
|
||||||
linear_input = batch[3]
|
linear_input = batch["linear"]
|
||||||
mel_input = batch[4]
|
mel_input = batch["mel"]
|
||||||
mel_lengths = batch[5]
|
mel_lengths = batch["mel_lengths"]
|
||||||
stop_targets = batch[6]
|
stop_targets = batch["stop_targets"]
|
||||||
item_idx = batch[7]
|
item_idx = batch["item_idxs"]
|
||||||
d_vectors = batch[8]
|
d_vectors = batch["d_vectors"]
|
||||||
speaker_ids = batch[9]
|
speaker_ids = batch["speaker_ids"]
|
||||||
attn_mask = batch[10]
|
attn_mask = batch["attns"]
|
||||||
waveform = batch[11]
|
waveform = batch["waveform"]
|
||||||
|
pitch = batch["pitch"]
|
||||||
max_text_length = torch.max(text_lengths.float())
|
max_text_length = torch.max(text_lengths.float())
|
||||||
max_spec_length = torch.max(mel_lengths.float())
|
max_spec_length = torch.max(mel_lengths.float())
|
||||||
|
|
||||||
|
@ -162,6 +163,7 @@ class BaseTTS(BaseModel):
|
||||||
"max_spec_length": float(max_spec_length),
|
"max_spec_length": float(max_spec_length),
|
||||||
"item_idx": item_idx,
|
"item_idx": item_idx,
|
||||||
"waveform": waveform,
|
"waveform": waveform,
|
||||||
|
"pitch": pitch,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_data_loader(
|
def get_data_loader(
|
||||||
|
@ -199,6 +201,8 @@ class BaseTTS(BaseModel):
|
||||||
outputs_per_step=config.r if "r" in config else 1,
|
outputs_per_step=config.r if "r" in config else 1,
|
||||||
text_cleaner=config.text_cleaner,
|
text_cleaner=config.text_cleaner,
|
||||||
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
|
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
|
||||||
|
compute_f0=config.get("compute_f0", False),
|
||||||
|
f0_cache_path=config.get("f0_cache_path", None),
|
||||||
meta_data=data_items,
|
meta_data=data_items,
|
||||||
ap=ap,
|
ap=ap,
|
||||||
characters=config.characters,
|
characters=config.characters,
|
||||||
|
@ -245,6 +249,21 @@ class BaseTTS(BaseModel):
|
||||||
# sort input sequences from short to long
|
# sort input sequences from short to long
|
||||||
dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
|
dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
|
||||||
|
|
||||||
|
# compute pitch frames and write to files.
|
||||||
|
if config.compute_f0 and rank in [None, 0]:
|
||||||
|
if not os.path.exists(config.f0_cache_path):
|
||||||
|
dataset.pitch_extractor.compute_pitch(
|
||||||
|
ap, config.get("f0_cache_path", None), config.num_loader_workers
|
||||||
|
)
|
||||||
|
|
||||||
|
# halt DDP processes for the main process to finish computing the F0 cache
|
||||||
|
if num_gpus > 1:
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
# load pitch stats computed above by all the workers
|
||||||
|
if config.compute_f0:
|
||||||
|
dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
|
||||||
|
|
||||||
# sampler for DDP
|
# sampler for DDP
|
||||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,697 @@
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from coqpit import Coqpit
|
||||||
|
from torch import nn
|
||||||
|
from torch.cuda.amp.autocast_mode import autocast
|
||||||
|
|
||||||
|
from TTS.tts.layers.feed_forward.decoder import Decoder
|
||||||
|
from TTS.tts.layers.feed_forward.encoder import Encoder
|
||||||
|
from TTS.tts.layers.generic.aligner import AlignmentNetwork
|
||||||
|
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
|
||||||
|
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||||
|
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
|
||||||
|
from TTS.tts.models.base_tts import BaseTTS
|
||||||
|
from TTS.tts.utils.data import sequence_mask
|
||||||
|
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
|
||||||
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FastPitchArgs(Coqpit):
|
||||||
|
"""Fast Pitch Model arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
|
||||||
|
num_chars (int):
|
||||||
|
Number of characters in the vocabulary. Defaults to 100.
|
||||||
|
|
||||||
|
out_channels (int):
|
||||||
|
Number of output channels. Defaults to 80.
|
||||||
|
|
||||||
|
hidden_channels (int):
|
||||||
|
Number of base hidden channels of the model. Defaults to 512.
|
||||||
|
|
||||||
|
num_speakers (int):
|
||||||
|
Number of speakers for the speaker embedding layer. Defaults to 0.
|
||||||
|
|
||||||
|
duration_predictor_hidden_channels (int):
|
||||||
|
Number of hidden channels in the duration predictor. Defaults to 256.
|
||||||
|
|
||||||
|
duration_predictor_dropout_p (float):
|
||||||
|
Dropout rate for the duration predictor. Defaults to 0.1.
|
||||||
|
|
||||||
|
duration_predictor_kernel_size (int):
|
||||||
|
Kernel size of conv layers in the duration predictor. Defaults to 3.
|
||||||
|
|
||||||
|
pitch_predictor_hidden_channels (int):
|
||||||
|
Number of hidden channels in the pitch predictor. Defaults to 256.
|
||||||
|
|
||||||
|
pitch_predictor_dropout_p (float):
|
||||||
|
Dropout rate for the pitch predictor. Defaults to 0.1.
|
||||||
|
|
||||||
|
pitch_predictor_kernel_size (int):
|
||||||
|
Kernel size of conv layers in the pitch predictor. Defaults to 3.
|
||||||
|
|
||||||
|
pitch_embedding_kernel_size (int):
|
||||||
|
Kernel size of the projection layer in the pitch predictor. Defaults to 3.
|
||||||
|
|
||||||
|
positional_encoding (bool):
|
||||||
|
Whether to use positional encoding. Defaults to True.
|
||||||
|
|
||||||
|
positional_encoding_use_scale (bool):
|
||||||
|
Whether to use a learnable scale coeff in the positional encoding. Defaults to True.
|
||||||
|
|
||||||
|
length_scale (int):
|
||||||
|
Length scale that multiplies the predicted durations. Larger values result slower speech. Defaults to 1.0.
|
||||||
|
|
||||||
|
encoder_type (str):
|
||||||
|
Type of the encoder module. One of the encoders available in :class:`TTS.tts.layers.feed_forward.encoder`.
|
||||||
|
Defaults to `fftransformer` as in the paper.
|
||||||
|
|
||||||
|
encoder_params (dict):
|
||||||
|
Parameters of the encoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}```
|
||||||
|
|
||||||
|
decoder_type (str):
|
||||||
|
Type of the decoder module. One of the decoders available in :class:`TTS.tts.layers.feed_forward.decoder`.
|
||||||
|
Defaults to `fftransformer` as in the paper.
|
||||||
|
|
||||||
|
decoder_params (str):
|
||||||
|
Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}```
|
||||||
|
|
||||||
|
use_d_vetor (bool):
|
||||||
|
Whether to use precomputed d-vectors for multi-speaker training. Defaults to False.
|
||||||
|
|
||||||
|
d_vector_dim (int):
|
||||||
|
Number of channels of the d-vectors. Defaults to 0.
|
||||||
|
|
||||||
|
detach_duration_predictor (bool):
|
||||||
|
Detach the input to the duration predictor from the earlier computation graph so that the duraiton loss
|
||||||
|
does not pass to the earlier layers. Defaults to True.
|
||||||
|
|
||||||
|
max_duration (int):
|
||||||
|
Maximum duration accepted by the model. Defaults to 75.
|
||||||
|
|
||||||
|
use_aligner (bool):
|
||||||
|
Use aligner network to learn the text to speech alignment. Defaults to True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_chars: int = None
|
||||||
|
out_channels: int = 80
|
||||||
|
hidden_channels: int = 384
|
||||||
|
num_speakers: int = 0
|
||||||
|
duration_predictor_hidden_channels: int = 256
|
||||||
|
duration_predictor_kernel_size: int = 3
|
||||||
|
duration_predictor_dropout_p: float = 0.1
|
||||||
|
pitch_predictor_hidden_channels: int = 256
|
||||||
|
pitch_predictor_kernel_size: int = 3
|
||||||
|
pitch_predictor_dropout_p: float = 0.1
|
||||||
|
pitch_embedding_kernel_size: int = 3
|
||||||
|
positional_encoding: bool = True
|
||||||
|
poisitonal_encoding_use_scale: bool = True
|
||||||
|
length_scale: int = 1
|
||||||
|
encoder_type: str = "fftransformer"
|
||||||
|
encoder_params: dict = field(
|
||||||
|
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
|
||||||
|
)
|
||||||
|
decoder_type: str = "fftransformer"
|
||||||
|
decoder_params: dict = field(
|
||||||
|
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
|
||||||
|
)
|
||||||
|
use_d_vector: bool = False
|
||||||
|
d_vector_dim: int = 0
|
||||||
|
detach_duration_predictor: bool = False
|
||||||
|
max_duration: int = 75
|
||||||
|
use_aligner: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class FastPitch(BaseTTS):
|
||||||
|
"""FastPitch model. Very similart to SpeedySpeech model but with pitch prediction.
|
||||||
|
|
||||||
|
Paper::
|
||||||
|
https://arxiv.org/abs/2006.06873
|
||||||
|
|
||||||
|
Paper abstract::
|
||||||
|
We present FastPitch, a fully-parallel text-to-speech model based on FastSpeech, conditioned on fundamental
|
||||||
|
frequency contours. The model predicts pitch contours during inference. By altering these predictions,
|
||||||
|
the generated speech can be more expressive, better match the semantic of the utterance, and in the end
|
||||||
|
more engaging to the listener. Uniformly increasing or decreasing pitch with FastPitch generates speech
|
||||||
|
that resembles the voluntary modulation of voice. Conditioning on frequency contours improves the overall
|
||||||
|
quality of synthesized speech, making it comparable to state-of-the-art. It does not introduce an overhead,
|
||||||
|
and FastPitch retains the favorable, fully-parallel Transformer architecture, with over 900x real-time
|
||||||
|
factor for mel-spectrogram synthesis of a typical utterance."
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (Coqpit): Model coqpit class.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from TTS.tts.models.fast_pitch import FastPitch, FastPitchArgs
|
||||||
|
>>> config = FastPitchArgs()
|
||||||
|
>>> model = FastPitch(config)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=dangerous-default-value
|
||||||
|
def __init__(self, config: Coqpit):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# don't use isintance not to import recursively
|
||||||
|
if config.__class__.__name__ == "FastPitchConfig":
|
||||||
|
if "characters" in config:
|
||||||
|
# loading from FasrPitchConfig
|
||||||
|
_, self.config, num_chars = self.get_characters(config)
|
||||||
|
config.model_args.num_chars = num_chars
|
||||||
|
self.args = self.config.model_args
|
||||||
|
else:
|
||||||
|
# loading from FastPitchArgs
|
||||||
|
self.config = config
|
||||||
|
self.args = config.model_args
|
||||||
|
elif isinstance(config, FastPitchArgs):
|
||||||
|
self.args = config
|
||||||
|
self.config = config
|
||||||
|
else:
|
||||||
|
raise ValueError("config must be either a VitsConfig or Vitsself.args")
|
||||||
|
|
||||||
|
self.max_duration = self.args.max_duration
|
||||||
|
self.use_aligner = self.args.use_aligner
|
||||||
|
self.use_binary_alignment_loss = False
|
||||||
|
|
||||||
|
self.length_scale = (
|
||||||
|
float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
self.emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
|
||||||
|
|
||||||
|
self.encoder = Encoder(
|
||||||
|
self.args.hidden_channels,
|
||||||
|
self.args.hidden_channels,
|
||||||
|
self.args.encoder_type,
|
||||||
|
self.args.encoder_params,
|
||||||
|
self.args.d_vector_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.positional_encoding:
|
||||||
|
self.pos_encoder = PositionalEncoding(self.args.hidden_channels)
|
||||||
|
|
||||||
|
self.decoder = Decoder(
|
||||||
|
self.args.out_channels,
|
||||||
|
self.args.hidden_channels,
|
||||||
|
self.args.decoder_type,
|
||||||
|
self.args.decoder_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.duration_predictor = DurationPredictor(
|
||||||
|
self.args.hidden_channels + self.args.d_vector_dim,
|
||||||
|
self.args.duration_predictor_hidden_channels,
|
||||||
|
self.args.duration_predictor_kernel_size,
|
||||||
|
self.args.duration_predictor_dropout_p,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pitch_predictor = DurationPredictor(
|
||||||
|
self.args.hidden_channels + self.args.d_vector_dim,
|
||||||
|
self.args.pitch_predictor_hidden_channels,
|
||||||
|
self.args.pitch_predictor_kernel_size,
|
||||||
|
self.args.pitch_predictor_dropout_p,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pitch_emb = nn.Conv1d(
|
||||||
|
1,
|
||||||
|
self.args.hidden_channels,
|
||||||
|
kernel_size=self.args.pitch_embedding_kernel_size,
|
||||||
|
padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.num_speakers > 1 and not self.args.use_d_vector:
|
||||||
|
# speaker embedding layer
|
||||||
|
self.emb_g = nn.Embedding(self.args.num_speakers, self.args.d_vector_dim)
|
||||||
|
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
||||||
|
|
||||||
|
if self.args.d_vector_dim > 0 and self.args.d_vector_dim != self.args.hidden_channels:
|
||||||
|
self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
|
||||||
|
|
||||||
|
if self.args.use_aligner:
|
||||||
|
self.aligner = AlignmentNetwork(
|
||||||
|
in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_attn(dr, x_mask, y_mask=None):
|
||||||
|
"""Generate an attention mask from the durations.
|
||||||
|
|
||||||
|
Shapes
|
||||||
|
- dr: :math:`(B, T_{en})`
|
||||||
|
- x_mask: :math:`(B, T_{en})`
|
||||||
|
- y_mask: :math:`(B, T_{de})`
|
||||||
|
"""
|
||||||
|
# compute decode mask from the durations
|
||||||
|
if y_mask is None:
|
||||||
|
y_lengths = dr.sum(1).long()
|
||||||
|
y_lengths[y_lengths < 1] = 1
|
||||||
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
|
||||||
|
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||||
|
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
|
||||||
|
return attn
|
||||||
|
|
||||||
|
def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
|
||||||
|
"""Generate attention alignment map from durations and
|
||||||
|
expand encoder outputs
|
||||||
|
|
||||||
|
Shapes
|
||||||
|
- en: :math:`(B, D_{en}, T_{en})`
|
||||||
|
- dr: :math:`(B, T_{en})`
|
||||||
|
- x_mask: :math:`(B, T_{en})`
|
||||||
|
- y_mask: :math:`(B, T_{de})`
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- encoder output: :math:`[a,b,c,d]`
|
||||||
|
- durations: :math:`[1, 3, 2, 1]`
|
||||||
|
|
||||||
|
- expanded: :math:`[a, b, b, b, c, c, d]`
|
||||||
|
- attention map: :math:`[[0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0]]`
|
||||||
|
"""
|
||||||
|
attn = self.generate_attn(dr, x_mask, y_mask)
|
||||||
|
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2)
|
||||||
|
return o_en_ex, attn
|
||||||
|
|
||||||
|
def format_durations(self, o_dr_log, x_mask):
|
||||||
|
"""Format predicted durations.
|
||||||
|
1. Convert to linear scale from log scale
|
||||||
|
2. Apply the length scale for speed adjustment
|
||||||
|
3. Apply masking.
|
||||||
|
4. Cast 0 durations to 1.
|
||||||
|
5. Round the duration values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
o_dr_log: Log scale durations.
|
||||||
|
x_mask: Input text mask.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- o_dr_log: :math:`(B, T_{de})`
|
||||||
|
- x_mask: :math:`(B, T_{en})`
|
||||||
|
"""
|
||||||
|
o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale
|
||||||
|
o_dr[o_dr < 1] = 1.0
|
||||||
|
o_dr = torch.round(o_dr)
|
||||||
|
return o_dr
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _concat_speaker_embedding(o_en, g):
|
||||||
|
g_exp = g.expand(-1, -1, o_en.size(-1)) # [B, C, T_en]
|
||||||
|
o_en = torch.cat([o_en, g_exp], 1)
|
||||||
|
return o_en
|
||||||
|
|
||||||
|
def _sum_speaker_embedding(self, x, g):
|
||||||
|
# project g to decoder dim.
|
||||||
|
if hasattr(self, "proj_g"):
|
||||||
|
g = self.proj_g(g)
|
||||||
|
return x + g
|
||||||
|
|
||||||
|
def _forward_encoder(
|
||||||
|
self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None
|
||||||
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||||
|
"""Encoding forward pass.
|
||||||
|
|
||||||
|
1. Embed speaker IDs if multi-speaker mode.
|
||||||
|
2. Embed character sequences.
|
||||||
|
3. Run the encoder network.
|
||||||
|
4. Concat speaker embedding to the encoder output for the duration predictor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.LongTensor): Input sequence IDs.
|
||||||
|
x_mask (torch.FloatTensor): Input squence mask.
|
||||||
|
g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor, torch.tensor]:
|
||||||
|
encoder output, encoder output for the duration predictor, input sequence mask, speaker embeddings,
|
||||||
|
character embeddings
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`(B, T_{en})`
|
||||||
|
- x_mask: :math:`(B, 1, T_{en})`
|
||||||
|
- g: :math:`(B, C)`
|
||||||
|
"""
|
||||||
|
if hasattr(self, "emb_g"):
|
||||||
|
g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1]
|
||||||
|
if g is not None:
|
||||||
|
g = g.unsqueeze(-1)
|
||||||
|
# [B, T, C]
|
||||||
|
x_emb = self.emb(x)
|
||||||
|
# encoder pass
|
||||||
|
o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
|
||||||
|
# speaker conditioning for duration predictor
|
||||||
|
if g is not None:
|
||||||
|
o_en_dp = self._concat_speaker_embedding(o_en, g)
|
||||||
|
else:
|
||||||
|
o_en_dp = o_en
|
||||||
|
return o_en, o_en_dp, x_mask, g, x_emb
|
||||||
|
|
||||||
|
def _forward_decoder(
|
||||||
|
self,
|
||||||
|
o_en: torch.FloatTensor,
|
||||||
|
dr: torch.IntTensor,
|
||||||
|
x_mask: torch.FloatTensor,
|
||||||
|
y_lengths: torch.IntTensor,
|
||||||
|
g: torch.FloatTensor,
|
||||||
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||||
|
"""Decoding forward pass.
|
||||||
|
|
||||||
|
1. Compute the decoder output mask
|
||||||
|
2. Expand encoder output with the durations.
|
||||||
|
3. Apply position encoding.
|
||||||
|
4. Add speaker embeddings if multi-speaker mode.
|
||||||
|
5. Run the decoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
o_en (torch.FloatTensor): Encoder output.
|
||||||
|
dr (torch.IntTensor): Ground truth durations or alignment network durations.
|
||||||
|
x_mask (torch.IntTensor): Input sequence mask.
|
||||||
|
y_lengths (torch.IntTensor): Output sequence lengths.
|
||||||
|
g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations.
|
||||||
|
"""
|
||||||
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
|
||||||
|
# expand o_en with durations
|
||||||
|
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
|
||||||
|
# positional encoding
|
||||||
|
if hasattr(self, "pos_encoder"):
|
||||||
|
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
|
||||||
|
# speaker embedding
|
||||||
|
if g is not None:
|
||||||
|
o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
|
||||||
|
# decoder pass
|
||||||
|
o_de = self.decoder(o_en_ex, y_mask, g=g)
|
||||||
|
return o_de.transpose(1, 2), attn.transpose(1, 2)
|
||||||
|
|
||||||
|
def _forward_pitch_predictor(
|
||||||
|
self,
|
||||||
|
o_en: torch.FloatTensor,
|
||||||
|
x_mask: torch.IntTensor,
|
||||||
|
pitch: torch.FloatTensor = None,
|
||||||
|
dr: torch.IntTensor = None,
|
||||||
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||||
|
"""Pitch predictor forward pass.
|
||||||
|
|
||||||
|
1. Predict pitch from encoder outputs.
|
||||||
|
2. In training - Compute average pitch values for each input character from the ground truth pitch values.
|
||||||
|
3. Embed average pitch values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
o_en (torch.FloatTensor): Encoder output.
|
||||||
|
x_mask (torch.IntTensor): Input sequence mask.
|
||||||
|
pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
|
||||||
|
dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- o_en: :math:`(B, C, T_{en})`
|
||||||
|
- x_mask: :math:`(B, 1, T_{en})`
|
||||||
|
- pitch: :math:`(B, 1, T_{de})`
|
||||||
|
- dr: :math:`(B, T_{en})`
|
||||||
|
"""
|
||||||
|
o_pitch = self.pitch_predictor(o_en, x_mask)
|
||||||
|
if pitch is not None:
|
||||||
|
avg_pitch = average_pitch(pitch, dr)
|
||||||
|
o_pitch_emb = self.pitch_emb(avg_pitch)
|
||||||
|
return o_pitch_emb, o_pitch, avg_pitch
|
||||||
|
o_pitch_emb = self.pitch_emb(o_pitch)
|
||||||
|
return o_pitch_emb, o_pitch
|
||||||
|
|
||||||
|
def _forward_aligner(
|
||||||
|
self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor
|
||||||
|
) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||||
|
"""Aligner forward pass.
|
||||||
|
|
||||||
|
1. Compute a mask to apply to the attention map.
|
||||||
|
2. Run the alignment network.
|
||||||
|
3. Apply MAS to compute the hard alignment map.
|
||||||
|
4. Compute the durations from the hard alignment map.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.FloatTensor): Input sequence.
|
||||||
|
y (torch.FloatTensor): Output sequence.
|
||||||
|
x_mask (torch.IntTensor): Input sequence mask.
|
||||||
|
y_mask (torch.IntTensor): Output sequence mask.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||||
|
Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials,
|
||||||
|
hard alignment map.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`[B, T_en, C_en]`
|
||||||
|
- y: :math:`[B, T_de, C_de]`
|
||||||
|
- x_mask: :math:`[B, 1, T_en]`
|
||||||
|
- y_mask: :math:`[B, 1, T_de]`
|
||||||
|
|
||||||
|
- o_alignment_dur: :math:`[B, T_en]`
|
||||||
|
- alignment_soft: :math:`[B, T_en, T_de]`
|
||||||
|
- alignment_logprob: :math:`[B, 1, T_de, T_en]`
|
||||||
|
- alignment_mas: :math:`[B, T_en, T_de]`
|
||||||
|
"""
|
||||||
|
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||||
|
alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None)
|
||||||
|
alignment_mas = maximum_path(
|
||||||
|
alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
|
||||||
|
)
|
||||||
|
o_alignment_dur = torch.sum(alignment_mas, -1).int()
|
||||||
|
alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
|
||||||
|
return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.LongTensor,
|
||||||
|
x_lengths: torch.LongTensor,
|
||||||
|
y_lengths: torch.LongTensor,
|
||||||
|
y: torch.FloatTensor = None,
|
||||||
|
dr: torch.IntTensor = None,
|
||||||
|
pitch: torch.FloatTensor = None,
|
||||||
|
aux_input: Dict = {"d_vectors": 0, "speaker_ids": None}, # pylint: disable=unused-argument
|
||||||
|
) -> Dict:
|
||||||
|
"""Model's forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.LongTensor): Input character sequences.
|
||||||
|
x_lengths (torch.LongTensor): Input sequence lengths.
|
||||||
|
y_lengths (torch.LongTensor): Output sequnce lengths. Defaults to None.
|
||||||
|
y (torch.FloatTensor): Spectrogram frames. Defaults to None.
|
||||||
|
dr (torch.IntTensor): Character durations over the spectrogram frames. Defaults to None.
|
||||||
|
pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Defaults to None.
|
||||||
|
aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": 0, "speaker_ids": None}`.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`[B, T_max]`
|
||||||
|
- x_lengths: :math:`[B]`
|
||||||
|
- y_lengths: :math:`[B]`
|
||||||
|
- y: :math:`[B, T_max2]`
|
||||||
|
- dr: :math:`[B, T_max]`
|
||||||
|
- g: :math:`[B, C]`
|
||||||
|
- pitch: :math:`[B, 1, T]`
|
||||||
|
"""
|
||||||
|
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||||
|
# compute sequence masks
|
||||||
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(y.dtype)
|
||||||
|
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(y.dtype)
|
||||||
|
# encoder pass
|
||||||
|
o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
|
||||||
|
# duration predictor pass
|
||||||
|
if self.args.detach_duration_predictor:
|
||||||
|
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
||||||
|
else:
|
||||||
|
o_dr_log = self.duration_predictor(o_en_dp, x_mask)
|
||||||
|
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
|
||||||
|
# generate attn mask from predicted durations
|
||||||
|
o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
|
||||||
|
# aligner pass
|
||||||
|
if self.use_aligner:
|
||||||
|
o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner(
|
||||||
|
x_emb, y, x_mask, y_mask
|
||||||
|
)
|
||||||
|
dr = o_alignment_dur
|
||||||
|
# pitch predictor pass
|
||||||
|
o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr)
|
||||||
|
o_en = o_en + o_pitch_emb
|
||||||
|
# decoder pass
|
||||||
|
o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g)
|
||||||
|
outputs = {
|
||||||
|
"model_outputs": o_de,
|
||||||
|
"durations_log": o_dr_log.squeeze(1),
|
||||||
|
"durations": o_dr.squeeze(1),
|
||||||
|
"attn_durations": o_attn, # for visualization
|
||||||
|
"pitch_avg": o_pitch,
|
||||||
|
"pitch_avg_gt": avg_pitch,
|
||||||
|
"alignments": attn,
|
||||||
|
"alignment_soft": alignment_soft.transpose(1, 2),
|
||||||
|
"alignment_mas": alignment_mas.transpose(1, 2),
|
||||||
|
"o_alignment_dur": o_alignment_dur,
|
||||||
|
"alignment_logprob": alignment_logprob,
|
||||||
|
"x_mask": x_mask,
|
||||||
|
"y_mask": y_mask,
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
|
||||||
|
"""Model's inference pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.LongTensor): Input character sequence.
|
||||||
|
aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- x: [B, T_max]
|
||||||
|
- x_lengths: [B]
|
||||||
|
- g: [B, C]
|
||||||
|
"""
|
||||||
|
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||||
|
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
||||||
|
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float()
|
||||||
|
# encoder pass
|
||||||
|
o_en, o_en_dp, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
|
||||||
|
# duration predictor pass
|
||||||
|
o_dr_log = self.duration_predictor(o_en_dp, x_mask)
|
||||||
|
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
||||||
|
y_lengths = o_dr.sum(1)
|
||||||
|
# pitch predictor pass
|
||||||
|
o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)
|
||||||
|
o_en = o_en + o_pitch_emb
|
||||||
|
# decoder pass
|
||||||
|
o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=g)
|
||||||
|
outputs = {
|
||||||
|
"model_outputs": o_de,
|
||||||
|
"alignments": attn,
|
||||||
|
"pitch": o_pitch,
|
||||||
|
"durations_log": o_dr_log,
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def train_step(self, batch: dict, criterion: nn.Module):
|
||||||
|
text_input = batch["text_input"]
|
||||||
|
text_lengths = batch["text_lengths"]
|
||||||
|
mel_input = batch["mel_input"]
|
||||||
|
mel_lengths = batch["mel_lengths"]
|
||||||
|
pitch = batch["pitch"]
|
||||||
|
d_vectors = batch["d_vectors"]
|
||||||
|
speaker_ids = batch["speaker_ids"]
|
||||||
|
durations = batch["durations"]
|
||||||
|
aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
outputs = self.forward(
|
||||||
|
text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
|
||||||
|
)
|
||||||
|
# use aligner's output as the duration target
|
||||||
|
if self.use_aligner:
|
||||||
|
durations = outputs["o_alignment_dur"]
|
||||||
|
# use float32 in AMP
|
||||||
|
with autocast(enabled=False):
|
||||||
|
# compute loss
|
||||||
|
loss_dict = criterion(
|
||||||
|
decoder_output=outputs["model_outputs"],
|
||||||
|
decoder_target=mel_input,
|
||||||
|
decoder_output_lens=mel_lengths,
|
||||||
|
dur_output=outputs["durations_log"],
|
||||||
|
dur_target=durations,
|
||||||
|
pitch_output=outputs["pitch_avg"],
|
||||||
|
pitch_target=outputs["pitch_avg_gt"],
|
||||||
|
input_lens=text_lengths,
|
||||||
|
alignment_logprob=outputs["alignment_logprob"],
|
||||||
|
alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
|
||||||
|
alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
|
||||||
|
)
|
||||||
|
# compute duration error
|
||||||
|
durations_pred = outputs["durations"]
|
||||||
|
duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum()
|
||||||
|
loss_dict["duration_error"] = duration_error
|
||||||
|
|
||||||
|
return outputs, loss_dict
|
||||||
|
|
||||||
|
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use
|
||||||
|
model_outputs = outputs["model_outputs"]
|
||||||
|
alignments = outputs["alignments"]
|
||||||
|
mel_input = batch["mel_input"]
|
||||||
|
pitch = batch["pitch"]
|
||||||
|
pitch_avg_expanded, _ = self.expand_encoder_outputs(
|
||||||
|
outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
|
||||||
|
)
|
||||||
|
|
||||||
|
pred_spec = model_outputs[0].data.cpu().numpy()
|
||||||
|
gt_spec = mel_input[0].data.cpu().numpy()
|
||||||
|
align_img = alignments[0].data.cpu().numpy()
|
||||||
|
pitch = pitch[0, 0].data.cpu().numpy()
|
||||||
|
|
||||||
|
# TODO: denormalize before plotting
|
||||||
|
pitch = abs(pitch)
|
||||||
|
pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
|
||||||
|
|
||||||
|
figures = {
|
||||||
|
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
|
||||||
|
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
||||||
|
"alignment": plot_alignment(align_img, output_fig=False),
|
||||||
|
"pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
|
||||||
|
"pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
|
||||||
|
}
|
||||||
|
|
||||||
|
# plot the attention mask computed from the predicted durations
|
||||||
|
if "attn_durations" in outputs:
|
||||||
|
alignments_hat = outputs["attn_durations"][0].data.cpu().numpy()
|
||||||
|
figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False)
|
||||||
|
|
||||||
|
# Sample audio
|
||||||
|
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
||||||
|
return figures, {"audio": train_audio}
|
||||||
|
|
||||||
|
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||||
|
return self.train_step(batch, criterion)
|
||||||
|
|
||||||
|
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||||
|
return self.train_log(ap, batch, outputs)
|
||||||
|
|
||||||
|
def load_checkpoint(
|
||||||
|
self, config, checkpoint_path, eval=False
|
||||||
|
): # pylint: disable=unused-argument, redefined-builtin
|
||||||
|
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||||
|
self.load_state_dict(state["model"])
|
||||||
|
if eval:
|
||||||
|
self.eval()
|
||||||
|
assert not self.training
|
||||||
|
|
||||||
|
def get_criterion(self):
|
||||||
|
from TTS.tts.layers.losses import FastPitchLoss # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
return FastPitchLoss(self.config)
|
||||||
|
|
||||||
|
def on_train_step_start(self, trainer):
|
||||||
|
"""Enable binary alignment loss when needed"""
|
||||||
|
if trainer.total_steps_done > self.config.binary_align_loss_start_step:
|
||||||
|
self.use_binary_alignment_loss = True
|
||||||
|
|
||||||
|
|
||||||
|
def average_pitch(pitch, durs):
|
||||||
|
"""Compute the average pitch value for each input character based on the durations.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- pitch: :math:`[B, 1, T_de]`
|
||||||
|
- durs: :math:`[B, T_en]`
|
||||||
|
"""
|
||||||
|
|
||||||
|
durs_cums_ends = torch.cumsum(durs, dim=1).long()
|
||||||
|
durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0))
|
||||||
|
pitch_nonzero_cums = torch.nn.functional.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0))
|
||||||
|
pitch_cums = torch.nn.functional.pad(torch.cumsum(pitch, dim=2), (1, 0))
|
||||||
|
|
||||||
|
bs, l = durs_cums_ends.size()
|
||||||
|
n_formants = pitch.size(1)
|
||||||
|
dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, l)
|
||||||
|
dce = durs_cums_ends[:, None, :].expand(bs, n_formants, l)
|
||||||
|
|
||||||
|
pitch_sums = (torch.gather(pitch_cums, 2, dce) - torch.gather(pitch_cums, 2, dcs)).float()
|
||||||
|
pitch_nelems = (torch.gather(pitch_nonzero_cums, 2, dce) - torch.gather(pitch_nonzero_cums, 2, dcs)).float()
|
||||||
|
|
||||||
|
pitch_avg = torch.where(pitch_nelems == 0.0, pitch_nelems, pitch_sums / pitch_nelems)
|
||||||
|
return pitch_avg
|
|
@ -10,7 +10,6 @@ from TTS.tts.layers.glow_tts.encoder import Encoder
|
||||||
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
|
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
|
||||||
from TTS.tts.models.base_tts import BaseTTS
|
from TTS.tts.models.base_tts import BaseTTS
|
||||||
from TTS.tts.utils.data import sequence_mask
|
from TTS.tts.utils.data import sequence_mask
|
||||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
|
||||||
from TTS.tts.utils.speakers import get_speaker_manager
|
from TTS.tts.utils.speakers import get_speaker_manager
|
||||||
from TTS.tts.utils.synthesis import synthesis
|
from TTS.tts.utils.synthesis import synthesis
|
||||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||||
|
@ -110,6 +109,10 @@ class GlowTTS(BaseTTS):
|
||||||
# init speaker manager
|
# init speaker manager
|
||||||
self.speaker_manager = get_speaker_manager(config, data=data)
|
self.speaker_manager = get_speaker_manager(config, data=data)
|
||||||
self.num_speakers = self.speaker_manager.num_speakers
|
self.num_speakers = self.speaker_manager.num_speakers
|
||||||
|
if config.use_d_vector_file:
|
||||||
|
self.external_d_vector_dim = config.d_vector_dim
|
||||||
|
else:
|
||||||
|
self.external_d_vector_dim = 0
|
||||||
# init speaker embedding layer
|
# init speaker embedding layer
|
||||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||||
self.embedded_speaker_dim = self.c_in_channels
|
self.embedded_speaker_dim = self.c_in_channels
|
||||||
|
@ -130,22 +133,22 @@ class GlowTTS(BaseTTS):
|
||||||
return y_mean, y_log_scale, o_attn_dur
|
return y_mean, y_log_scale, o_attn_dur
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None}
|
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
|
||||||
): # pylint: disable=dangerous-default-value
|
): # pylint: disable=dangerous-default-value
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
- x: :math:`[B, T]`
|
- x: :math:`[B, T]`
|
||||||
- x_lenghts::math:` B`
|
- x_lenghts::math:`B`
|
||||||
- y: :math:`[B, T, C]`
|
- y: :math:`[B, T, C]`
|
||||||
- y_lengths::math:` B`
|
- y_lengths::math:`B`
|
||||||
- g: :math:`[B, C] or B`
|
- g: :math:`[B, C] or B`
|
||||||
"""
|
"""
|
||||||
y = y.transpose(1, 2)
|
y = y.transpose(1, 2)
|
||||||
y_max_length = y.size(2)
|
y_max_length = y.size(2)
|
||||||
# norm speaker embeddings
|
# norm speaker embeddings
|
||||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||||
if g is not None:
|
if self.use_speaker_embedding or self.use_d_vector_file:
|
||||||
if self.d_vector_dim:
|
if not self.use_d_vector_file:
|
||||||
g = F.normalize(g).unsqueeze(-1)
|
g = F.normalize(g).unsqueeze(-1)
|
||||||
else:
|
else:
|
||||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||||
|
@ -182,7 +185,7 @@ class GlowTTS(BaseTTS):
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference_with_MAS(
|
def inference_with_MAS(
|
||||||
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None}
|
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
|
||||||
): # pylint: disable=dangerous-default-value
|
): # pylint: disable=dangerous-default-value
|
||||||
"""
|
"""
|
||||||
It's similar to the teacher forcing in Tacotron.
|
It's similar to the teacher forcing in Tacotron.
|
||||||
|
@ -199,12 +202,11 @@ class GlowTTS(BaseTTS):
|
||||||
y_max_length = y.size(2)
|
y_max_length = y.size(2)
|
||||||
# norm speaker embeddings
|
# norm speaker embeddings
|
||||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||||
if g is not None:
|
if self.use_speaker_embedding or self.use_d_vector_file:
|
||||||
if self.external_d_vector_dim:
|
if not self.use_d_vector_file:
|
||||||
g = F.normalize(g).unsqueeze(-1)
|
g = F.normalize(g).unsqueeze(-1)
|
||||||
else:
|
else:
|
||||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||||
|
|
||||||
# embedding pass
|
# embedding pass
|
||||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
|
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
|
||||||
# drop redisual frames wrt num_squeeze and set y_lengths.
|
# drop redisual frames wrt num_squeeze and set y_lengths.
|
||||||
|
@ -244,7 +246,7 @@ class GlowTTS(BaseTTS):
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def decoder_inference(
|
def decoder_inference(
|
||||||
self, y, y_lengths=None, aux_input={"d_vectors": None}
|
self, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
|
||||||
): # pylint: disable=dangerous-default-value
|
): # pylint: disable=dangerous-default-value
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
|
@ -276,7 +278,7 @@ class GlowTTS(BaseTTS):
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None}): # pylint: disable=dangerous-default-value
|
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids":None}): # pylint: disable=dangerous-default-value
|
||||||
x_lengths = aux_input["x_lengths"]
|
x_lengths = aux_input["x_lengths"]
|
||||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||||
|
|
||||||
|
@ -327,8 +329,9 @@ class GlowTTS(BaseTTS):
|
||||||
mel_input = batch["mel_input"]
|
mel_input = batch["mel_input"]
|
||||||
mel_lengths = batch["mel_lengths"]
|
mel_lengths = batch["mel_lengths"]
|
||||||
d_vectors = batch["d_vectors"]
|
d_vectors = batch["d_vectors"]
|
||||||
|
speaker_ids = batch["speaker_ids"]
|
||||||
|
|
||||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors})
|
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors, "speaker_ids":speaker_ids})
|
||||||
|
|
||||||
loss_dict = criterion(
|
loss_dict = criterion(
|
||||||
outputs["model_outputs"],
|
outputs["model_outputs"],
|
||||||
|
@ -341,9 +344,6 @@ class GlowTTS(BaseTTS):
|
||||||
text_lengths,
|
text_lengths,
|
||||||
)
|
)
|
||||||
|
|
||||||
# compute alignment error (the lower the better )
|
|
||||||
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
|
|
||||||
loss_dict["align_error"] = align_error
|
|
||||||
return outputs, loss_dict
|
return outputs, loss_dict
|
||||||
|
|
||||||
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use
|
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use
|
||||||
|
|
|
@ -41,9 +41,9 @@ def rand_segment(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4
|
||||||
x_lengths = T
|
x_lengths = T
|
||||||
max_idxs = x_lengths - segment_size + 1
|
max_idxs = x_lengths - segment_size + 1
|
||||||
assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
|
assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
|
||||||
ids_str = (torch.rand([B]).type_as(x) * max_idxs).long()
|
segment_indices = (torch.rand([B]).type_as(x) * max_idxs).long()
|
||||||
ret = segment(x, ids_str, segment_size)
|
ret = segment(x, segment_indices, segment_size)
|
||||||
return ret, ids_str
|
return ret, segment_indices
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -7,7 +7,7 @@ def alignment_diagonal_score(alignments, binary=False):
|
||||||
binary (bool): if True, ignore scores and consider attention
|
binary (bool): if True, ignore scores and consider attention
|
||||||
as a binary mask.
|
as a binary mask.
|
||||||
Shape:
|
Shape:
|
||||||
alignments : batch x decoder_steps x encoder_steps
|
- alignments : :math:`[B, T_de, T_en]`
|
||||||
"""
|
"""
|
||||||
maxs = alignments.max(dim=1)[0]
|
maxs = alignments.max(dim=1)[0]
|
||||||
if binary:
|
if binary:
|
||||||
|
|
|
@ -45,12 +45,10 @@ def text2phone(text, language, use_espeak_phonemes=False):
|
||||||
# TO REVIEW : How to have a good implementation for this?
|
# TO REVIEW : How to have a good implementation for this?
|
||||||
if language == "zh-CN":
|
if language == "zh-CN":
|
||||||
ph = chinese_text_to_phonemes(text)
|
ph = chinese_text_to_phonemes(text)
|
||||||
print(" > Phonemes: {}".format(ph))
|
|
||||||
return ph
|
return ph
|
||||||
|
|
||||||
if language == "ja-jp":
|
if language == "ja-jp":
|
||||||
ph = japanese_text_to_phonemes(text)
|
ph = japanese_text_to_phonemes(text)
|
||||||
print(" > Phonemes: {}".format(ph))
|
|
||||||
return ph
|
return ph
|
||||||
|
|
||||||
if gruut.is_language_supported(language):
|
if gruut.is_language_supported(language):
|
||||||
|
@ -80,7 +78,6 @@ def text2phone(text, language, use_espeak_phonemes=False):
|
||||||
|
|
||||||
# Fix a few phonemes
|
# Fix a few phonemes
|
||||||
ph = ph.translate(GRUUT_TRANS_TABLE)
|
ph = ph.translate(GRUUT_TRANS_TABLE)
|
||||||
|
|
||||||
return ph
|
return ph
|
||||||
|
|
||||||
raise ValueError(f" [!] Language {language} is not supported for phonemization.")
|
raise ValueError(f" [!] Language {language} is not supported for phonemization.")
|
||||||
|
|
|
@ -2,8 +2,10 @@
|
||||||
# compatible with Julius https://github.com/julius-speech/segmentation-kit
|
# compatible with Julius https://github.com/julius-speech/segmentation-kit
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
import unicodedata
|
||||||
|
|
||||||
import MeCab
|
import MeCab
|
||||||
|
from num2words import num2words
|
||||||
|
|
||||||
_CONVRULES = [
|
_CONVRULES = [
|
||||||
# Conversion of 2 letters
|
# Conversion of 2 letters
|
||||||
|
@ -373,8 +375,93 @@ def text2kata(text: str) -> str:
|
||||||
return hira2kata("".join(res))
|
return hira2kata("".join(res))
|
||||||
|
|
||||||
|
|
||||||
|
_ALPHASYMBOL_YOMI = {
|
||||||
|
"#": "シャープ",
|
||||||
|
"%": "パーセント",
|
||||||
|
"&": "アンド",
|
||||||
|
"+": "プラス",
|
||||||
|
"-": "マイナス",
|
||||||
|
":": "コロン",
|
||||||
|
";": "セミコロン",
|
||||||
|
"<": "小なり",
|
||||||
|
"=": "イコール",
|
||||||
|
">": "大なり",
|
||||||
|
"@": "アット",
|
||||||
|
"a": "エー",
|
||||||
|
"b": "ビー",
|
||||||
|
"c": "シー",
|
||||||
|
"d": "ディー",
|
||||||
|
"e": "イー",
|
||||||
|
"f": "エフ",
|
||||||
|
"g": "ジー",
|
||||||
|
"h": "エイチ",
|
||||||
|
"i": "アイ",
|
||||||
|
"j": "ジェー",
|
||||||
|
"k": "ケー",
|
||||||
|
"l": "エル",
|
||||||
|
"m": "エム",
|
||||||
|
"n": "エヌ",
|
||||||
|
"o": "オー",
|
||||||
|
"p": "ピー",
|
||||||
|
"q": "キュー",
|
||||||
|
"r": "アール",
|
||||||
|
"s": "エス",
|
||||||
|
"t": "ティー",
|
||||||
|
"u": "ユー",
|
||||||
|
"v": "ブイ",
|
||||||
|
"w": "ダブリュー",
|
||||||
|
"x": "エックス",
|
||||||
|
"y": "ワイ",
|
||||||
|
"z": "ゼット",
|
||||||
|
"α": "アルファ",
|
||||||
|
"β": "ベータ",
|
||||||
|
"γ": "ガンマ",
|
||||||
|
"δ": "デルタ",
|
||||||
|
"ε": "イプシロン",
|
||||||
|
"ζ": "ゼータ",
|
||||||
|
"η": "イータ",
|
||||||
|
"θ": "シータ",
|
||||||
|
"ι": "イオタ",
|
||||||
|
"κ": "カッパ",
|
||||||
|
"λ": "ラムダ",
|
||||||
|
"μ": "ミュー",
|
||||||
|
"ν": "ニュー",
|
||||||
|
"ξ": "クサイ",
|
||||||
|
"ο": "オミクロン",
|
||||||
|
"π": "パイ",
|
||||||
|
"ρ": "ロー",
|
||||||
|
"σ": "シグマ",
|
||||||
|
"τ": "タウ",
|
||||||
|
"υ": "ウプシロン",
|
||||||
|
"φ": "ファイ",
|
||||||
|
"χ": "カイ",
|
||||||
|
"ψ": "プサイ",
|
||||||
|
"ω": "オメガ",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
_NUMBER_WITH_SEPARATOR_RX = re.compile("[0-9]{1,3}(,[0-9]{3})+")
|
||||||
|
_CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"}
|
||||||
|
_CURRENCY_RX = re.compile(r"([$¥£€])([0-9.]*[0-9])")
|
||||||
|
_NUMBER_RX = re.compile(r"[0-9]+(\.[0-9]+)?")
|
||||||
|
|
||||||
|
|
||||||
|
def japanese_convert_numbers_to_words(text: str) -> str:
|
||||||
|
res = _NUMBER_WITH_SEPARATOR_RX.sub(lambda m: m[0].replace(",", ""), text)
|
||||||
|
res = _CURRENCY_RX.sub(lambda m: m[2] + _CURRENCY_MAP.get(m[1], m[1]), res)
|
||||||
|
res = _NUMBER_RX.sub(lambda m: num2words(m[0], lang="ja"), res)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def japanese_convert_alpha_symbols_to_words(text: str) -> str:
|
||||||
|
return "".join([_ALPHASYMBOL_YOMI.get(ch, ch) for ch in text.lower()])
|
||||||
|
|
||||||
|
|
||||||
def japanese_text_to_phonemes(text: str) -> str:
|
def japanese_text_to_phonemes(text: str) -> str:
|
||||||
"""Convert Japanese text to phonemes."""
|
"""Convert Japanese text to phonemes."""
|
||||||
res = text2kata(text)
|
res = unicodedata.normalize("NFKC", text)
|
||||||
|
res = japanese_convert_numbers_to_words(res)
|
||||||
|
res = japanese_convert_alpha_symbols_to_words(res)
|
||||||
|
res = text2kata(res)
|
||||||
res = kata2phoneme(res)
|
res = kata2phoneme(res)
|
||||||
return res.replace(" ", "")
|
return res.replace(" ", "")
|
||||||
|
|
|
@ -49,6 +49,46 @@ def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False):
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def plot_pitch(pitch, spectrogram, ap=None, fig_size=(30, 10), output_fig=False):
|
||||||
|
"""Plot pitch curves on top of the spectrogram.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pitch (np.array): Pitch values.
|
||||||
|
spectrogram (np.array): Spectrogram values.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
pitch: :math:`(T,)`
|
||||||
|
spec: :math:`(C, T)`
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(spectrogram, torch.Tensor):
|
||||||
|
spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T
|
||||||
|
else:
|
||||||
|
spectrogram_ = spectrogram.T
|
||||||
|
spectrogram_ = spectrogram_.astype(np.float32) if spectrogram_.dtype == np.float16 else spectrogram_
|
||||||
|
if ap is not None:
|
||||||
|
spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
old_fig_size = plt.rcParams["figure.figsize"]
|
||||||
|
if fig_size is not None:
|
||||||
|
plt.rcParams["figure.figsize"] = fig_size
|
||||||
|
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
|
||||||
|
ax.imshow(spectrogram_, aspect="auto", origin="lower")
|
||||||
|
ax.set_xlabel("time")
|
||||||
|
ax.set_ylabel("spec_freq")
|
||||||
|
|
||||||
|
ax2 = ax.twinx()
|
||||||
|
ax2.plot(pitch, linewidth=5.0, color="red")
|
||||||
|
ax2.set_ylabel("F0")
|
||||||
|
|
||||||
|
plt.rcParams["figure.figsize"] = old_fig_size
|
||||||
|
if not output_fig:
|
||||||
|
plt.close()
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def visualize(
|
def visualize(
|
||||||
alignment,
|
alignment,
|
||||||
postnet_output,
|
postnet_output,
|
||||||
|
|
|
@ -2,6 +2,7 @@ from typing import Dict, Tuple
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pyworld as pw
|
||||||
import scipy.io.wavfile
|
import scipy.io.wavfile
|
||||||
import scipy.signal
|
import scipy.signal
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
@ -10,8 +11,6 @@ from torch import nn
|
||||||
|
|
||||||
from TTS.tts.utils.data import StandardScaler
|
from TTS.tts.utils.data import StandardScaler
|
||||||
|
|
||||||
# import pyworld as pw
|
|
||||||
|
|
||||||
|
|
||||||
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||||
"""Some of the audio processing funtions using Torch for faster batch processing.
|
"""Some of the audio processing funtions using Torch for faster batch processing.
|
||||||
|
@ -623,17 +622,47 @@ class AudioProcessor(object):
|
||||||
return 0, pad
|
return 0, pad
|
||||||
return pad // 2, pad // 2 + pad % 2
|
return pad // 2, pad // 2 + pad % 2
|
||||||
|
|
||||||
### Compute F0 ###
|
def compute_f0(self, x: np.ndarray) -> np.ndarray:
|
||||||
# TODO: pw causes some dep issues
|
"""Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
|
||||||
# def compute_f0(self, x):
|
|
||||||
# f0, t = pw.dio(
|
Args:
|
||||||
# x.astype(np.double),
|
x (np.ndarray): Waveform.
|
||||||
# fs=self.sample_rate,
|
|
||||||
# f0_ceil=self.mel_fmax,
|
Returns:
|
||||||
# frame_period=1000 * self.hop_length / self.sample_rate,
|
np.ndarray: Pitch.
|
||||||
# )
|
|
||||||
# f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
|
Examples:
|
||||||
# return f0
|
>>> WAV_FILE = filename = librosa.util.example_audio_file()
|
||||||
|
>>> from TTS.config import BaseAudioConfig
|
||||||
|
>>> from TTS.utils.audio import AudioProcessor
|
||||||
|
>>> conf = BaseAudioConfig(mel_fmax=8000)
|
||||||
|
>>> ap = AudioProcessor(**conf)
|
||||||
|
>>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
|
||||||
|
>>> pitch = ap.compute_f0(wav)
|
||||||
|
"""
|
||||||
|
f0, t = pw.dio(
|
||||||
|
x.astype(np.double),
|
||||||
|
fs=self.sample_rate,
|
||||||
|
f0_ceil=self.mel_fmax,
|
||||||
|
frame_period=1000 * self.hop_length / self.sample_rate,
|
||||||
|
)
|
||||||
|
f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
|
||||||
|
# pad = int((self.win_length / self.hop_length) / 2)
|
||||||
|
# f0 = [0.0] * pad + f0 + [0.0] * pad
|
||||||
|
# f0 = np.pad(f0, (pad, pad), mode="constant", constant_values=0)
|
||||||
|
# f0 = np.array(f0, dtype=np.float32)
|
||||||
|
|
||||||
|
# f01, _, _ = librosa.pyin(
|
||||||
|
# x,
|
||||||
|
# fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
|
||||||
|
# fmax=self.mel_fmax,
|
||||||
|
# frame_length=self.win_length,
|
||||||
|
# sr=self.sample_rate,
|
||||||
|
# fill_na=0.0,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# spec = self.melspectrogram(x)
|
||||||
|
return f0
|
||||||
|
|
||||||
### Audio Processing ###
|
### Audio Processing ###
|
||||||
def find_endpoint(self, wav: np.ndarray, threshold_db=-40, min_silence_sec=0.8) -> int:
|
def find_endpoint(self, wav: np.ndarray, threshold_db=-40, min_silence_sec=0.8) -> int:
|
||||||
|
|
|
@ -232,7 +232,7 @@ class Synthesizer(object):
|
||||||
|
|
||||||
# compute a new d_vector from the given clip.
|
# compute a new d_vector from the given clip.
|
||||||
if speaker_wav is not None:
|
if speaker_wav is not None:
|
||||||
speaker_embedding = self.speaker_manager.compute_d_vector_from_clip(speaker_wav)
|
speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav)
|
||||||
|
|
||||||
use_gl = self.vocoder_model is None
|
use_gl = self.vocoder_model is None
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,6 @@ def to_camel(text):
|
||||||
|
|
||||||
def setup_model(config: Coqpit):
|
def setup_model(config: Coqpit):
|
||||||
"""Load models directly from configuration."""
|
"""Load models directly from configuration."""
|
||||||
print(" > Vocoder Model: {}".format(config.model))
|
|
||||||
if "discriminator_model" in config and "generator_model" in config:
|
if "discriminator_model" in config and "generator_model" in config:
|
||||||
MyModel = importlib.import_module("TTS.vocoder.models.gan")
|
MyModel = importlib.import_module("TTS.vocoder.models.gan")
|
||||||
MyModel = getattr(MyModel, "GAN")
|
MyModel = getattr(MyModel, "GAN")
|
||||||
|
@ -28,6 +27,7 @@ def setup_model(config: Coqpit):
|
||||||
MyModel = getattr(MyModel, to_camel(config.model))
|
MyModel = getattr(MyModel, to_camel(config.model))
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
raise ValueError(f"Model {config.model} not exist!") from e
|
raise ValueError(f"Model {config.model} not exist!") from e
|
||||||
|
print(" > Vocoder Model: {}".format(config.model))
|
||||||
model = MyModel(config)
|
model = MyModel(config)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
|
@ -45,6 +45,7 @@
|
||||||
|
|
||||||
models/glow_tts.md
|
models/glow_tts.md
|
||||||
models/vits.md
|
models/vits.md
|
||||||
|
models/fast_pitch.md
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
from TTS.config import BaseAudioConfig, BaseDatasetConfig
|
||||||
|
from TTS.trainer import Trainer, TrainingArgs, init_training
|
||||||
|
from TTS.tts.configs import FastPitchConfig
|
||||||
|
from TTS.utils.manage import ModelManager
|
||||||
|
|
||||||
|
output_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
# init configs
|
||||||
|
dataset_config = BaseDatasetConfig(
|
||||||
|
name="ljspeech",
|
||||||
|
meta_file_train="metadata.csv",
|
||||||
|
# meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
|
||||||
|
path=os.path.join(output_path, "../LJSpeech-1.1/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_config = BaseAudioConfig(
|
||||||
|
sample_rate=22050,
|
||||||
|
do_trim_silence=True,
|
||||||
|
trim_db=60.0,
|
||||||
|
signal_norm=False,
|
||||||
|
mel_fmin=0.0,
|
||||||
|
mel_fmax=8000,
|
||||||
|
spec_gain=1.0,
|
||||||
|
log_func="np.log",
|
||||||
|
ref_level_db=20,
|
||||||
|
preemphasis=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
config = FastPitchConfig(
|
||||||
|
run_name="fast_pitch_ljspeech",
|
||||||
|
audio=audio_config,
|
||||||
|
batch_size=32,
|
||||||
|
eval_batch_size=16,
|
||||||
|
num_loader_workers=8,
|
||||||
|
num_eval_loader_workers=4,
|
||||||
|
compute_input_seq_cache=True,
|
||||||
|
compute_f0=True,
|
||||||
|
f0_cache_path=os.path.join(output_path, "f0_cache"),
|
||||||
|
run_eval=True,
|
||||||
|
test_delay_epochs=-1,
|
||||||
|
epochs=1000,
|
||||||
|
text_cleaner="english_cleaners",
|
||||||
|
use_phonemes=True,
|
||||||
|
use_espeak_phonemes=False,
|
||||||
|
phoneme_language="en-us",
|
||||||
|
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
|
||||||
|
print_step=50,
|
||||||
|
print_eval=False,
|
||||||
|
mixed_precision=False,
|
||||||
|
sort_by_audio_len=True,
|
||||||
|
max_seq_len=500000,
|
||||||
|
output_path=output_path,
|
||||||
|
datasets=[dataset_config],
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute alignments
|
||||||
|
if not config.model_args.use_aligner:
|
||||||
|
manager = ModelManager()
|
||||||
|
model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
|
||||||
|
# TODO: make compute_attention python callable
|
||||||
|
os.system(
|
||||||
|
f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true"
|
||||||
|
)
|
||||||
|
|
||||||
|
# train the model
|
||||||
|
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
|
||||||
|
trainer = Trainer(args, config, output_path, c_logger, tb_logger)
|
||||||
|
trainer.fit()
|
|
@ -25,3 +25,4 @@ unidic-lite==1.0.8
|
||||||
# gruut+supported langs
|
# gruut+supported langs
|
||||||
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0
|
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0
|
||||||
fsspec>=2021.04.0
|
fsspec>=2021.04.0
|
||||||
|
pyworld
|
|
@ -1,12 +1,16 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from TTS.config import BaseDatasetConfig
|
||||||
from TTS.utils.generic_utils import get_cuda
|
from TTS.utils.generic_utils import get_cuda
|
||||||
|
|
||||||
|
|
||||||
def get_device_id():
|
def get_device_id():
|
||||||
use_cuda, _ = get_cuda()
|
use_cuda, _ = get_cuda()
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
GPU_ID = "0"
|
if 'CUDA_VISIBLE_DEVICES' in os.environ and os.environ['CUDA_VISIBLE_DEVICES'] != "":
|
||||||
|
GPU_ID = os.environ['CUDA_VISIBLE_DEVICES'].split(',')[0]
|
||||||
|
else:
|
||||||
|
GPU_ID = "0"
|
||||||
else:
|
else:
|
||||||
GPU_ID = ""
|
GPU_ID = ""
|
||||||
return GPU_ID
|
return GPU_ID
|
||||||
|
@ -30,3 +34,7 @@ def get_tests_output_path():
|
||||||
def run_cli(command):
|
def run_cli(command):
|
||||||
exit_status = os.system(command)
|
exit_status = os.system(command)
|
||||||
assert exit_status == 0, f" [!] command `{command}` failed."
|
assert exit_status == 0, f" [!] command `{command}` failed."
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_data_config():
|
||||||
|
return BaseDatasetConfig(name="ljspeech", path="tests/data/ljspeech/", meta_file_train="metadata.csv")
|
||||||
|
|
|
@ -68,15 +68,15 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
break
|
break
|
||||||
text_input = data[0]
|
text_input = data['text']
|
||||||
text_lengths = data[1]
|
text_lengths = data['text_lengths']
|
||||||
speaker_name = data[2]
|
speaker_name = data['speaker_names']
|
||||||
linear_input = data[3]
|
linear_input = data['linear']
|
||||||
mel_input = data[4]
|
mel_input = data['mel']
|
||||||
mel_lengths = data[5]
|
mel_lengths = data['mel_lengths']
|
||||||
stop_target = data[6]
|
stop_target = data['stop_targets']
|
||||||
item_idx = data[7]
|
item_idx = data['item_idxs']
|
||||||
wavs = data[11]
|
wavs = data['waveform']
|
||||||
|
|
||||||
neg_values = text_input[text_input < 0]
|
neg_values = text_input[text_input < 0]
|
||||||
check_count = len(neg_values)
|
check_count = len(neg_values)
|
||||||
|
@ -113,14 +113,14 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
break
|
break
|
||||||
text_input = data[0]
|
text_input = data['text']
|
||||||
text_lengths = data[1]
|
text_lengths = data['text_lengths']
|
||||||
speaker_name = data[2]
|
speaker_name = data['speaker_names']
|
||||||
linear_input = data[3]
|
linear_input = data['linear']
|
||||||
mel_input = data[4]
|
mel_input = data['mel']
|
||||||
mel_lengths = data[5]
|
mel_lengths = data['mel_lengths']
|
||||||
stop_target = data[6]
|
stop_target = data['stop_targets']
|
||||||
item_idx = data[7]
|
item_idx = data['item_idxs']
|
||||||
|
|
||||||
avg_length = mel_lengths.numpy().mean()
|
avg_length = mel_lengths.numpy().mean()
|
||||||
assert avg_length >= last_length
|
assert avg_length >= last_length
|
||||||
|
@ -139,14 +139,14 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
break
|
break
|
||||||
text_input = data[0]
|
text_input = data['text']
|
||||||
text_lengths = data[1]
|
text_lengths = data['text_lengths']
|
||||||
speaker_name = data[2]
|
speaker_name = data['speaker_names']
|
||||||
linear_input = data[3]
|
linear_input = data['linear']
|
||||||
mel_input = data[4]
|
mel_input = data['mel']
|
||||||
mel_lengths = data[5]
|
mel_lengths = data['mel_lengths']
|
||||||
stop_target = data[6]
|
stop_target = data['stop_targets']
|
||||||
item_idx = data[7]
|
item_idx = data['item_idxs']
|
||||||
|
|
||||||
# check mel_spec consistency
|
# check mel_spec consistency
|
||||||
wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
|
wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
|
||||||
|
@ -188,14 +188,14 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
break
|
break
|
||||||
text_input = data[0]
|
text_input = data['text']
|
||||||
text_lengths = data[1]
|
text_lengths = data['text_lengths']
|
||||||
speaker_name = data[2]
|
speaker_name = data['speaker_names']
|
||||||
linear_input = data[3]
|
linear_input = data['linear']
|
||||||
mel_input = data[4]
|
mel_input = data['mel']
|
||||||
mel_lengths = data[5]
|
mel_lengths = data['mel_lengths']
|
||||||
stop_target = data[6]
|
stop_target = data['stop_targets']
|
||||||
item_idx = data[7]
|
item_idx = data['item_idxs']
|
||||||
|
|
||||||
if mel_lengths[0] > mel_lengths[1]:
|
if mel_lengths[0] > mel_lengths[1]:
|
||||||
idx = 0
|
idx = 0
|
||||||
|
|
|
@ -181,3 +181,10 @@ class TestAudio(unittest.TestCase):
|
||||||
mel_norm = ap.melspectrogram(wav)
|
mel_norm = ap.melspectrogram(wav)
|
||||||
mel_denorm = ap.denormalize(mel_norm)
|
mel_denorm = ap.denormalize(mel_norm)
|
||||||
assert abs(mel_reference - mel_denorm).max() < 1e-4
|
assert abs(mel_reference - mel_denorm).max() < 1e-4
|
||||||
|
|
||||||
|
def test_compute_f0(self): # pylint: disable=no-self-use
|
||||||
|
ap = AudioProcessor(**conf)
|
||||||
|
wav = ap.load_wav(WAV_FILE)
|
||||||
|
pitch = ap.compute_f0(wav)
|
||||||
|
mel = ap.melspectrogram(wav)
|
||||||
|
assert pitch.shape[0] == mel.shape[1]
|
||||||
|
|
|
@ -5,11 +5,13 @@ from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes
|
||||||
_TEST_CASES = """
|
_TEST_CASES = """
|
||||||
どちらに行きますか?/dochiraniikimasuka?
|
どちらに行きますか?/dochiraniikimasuka?
|
||||||
今日は温泉に、行きます。/kyo:waoNseNni,ikimasu.
|
今日は温泉に、行きます。/kyo:waoNseNni,ikimasu.
|
||||||
「A」から「Z」までです。/AkaraZmadedesu.
|
「A」から「Z」までです。/e:karazeqtomadedesu.
|
||||||
そうですね!/so:desune!
|
そうですね!/so:desune!
|
||||||
クジラは哺乳類です。/kujirawahonyu:ruidesu.
|
クジラは哺乳類です。/kujirawahonyu:ruidesu.
|
||||||
ヴィディオを見ます。/bidioomimasu.
|
ヴィディオを見ます。/bidioomimasu.
|
||||||
ky o: w a o N s e N n i , i k i m a s u ./kyo:waoNseNni,ikimasu.
|
今日は8月22日です/kyo:wahachigatsuniju:ninichidesu
|
||||||
|
xyzとαβγ/eqkusuwaizeqtotoarufabe:tagaNma
|
||||||
|
値段は$12.34です/nedaNwaju:niteNsaNyoNdorudesu
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch as T
|
||||||
|
|
||||||
|
from TTS.tts.models.fast_pitch import FastPitch, FastPitchArgs, average_pitch
|
||||||
|
# pylint: disable=unused-variable
|
||||||
|
|
||||||
|
|
||||||
|
class AveragePitchTests(unittest.TestCase):
|
||||||
|
def test_in_out(self): # pylint: disable=no-self-use
|
||||||
|
pitch = T.rand(1, 1, 128)
|
||||||
|
|
||||||
|
durations = T.randint(1, 5, (1, 21))
|
||||||
|
coeff = 128.0 / durations.sum()
|
||||||
|
durations = T.round(durations * coeff)
|
||||||
|
diff = 128.0 - durations.sum()
|
||||||
|
durations[0, -1] += diff
|
||||||
|
durations = durations.long()
|
||||||
|
|
||||||
|
pitch_avg = average_pitch(pitch, durations)
|
||||||
|
|
||||||
|
index = 0
|
||||||
|
for idx, dur in enumerate(durations[0]):
|
||||||
|
assert abs(pitch_avg[0, 0, idx] - pitch[0, 0, index : index + dur.item()].mean()) < 1e-5
|
||||||
|
index += dur
|
||||||
|
|
||||||
|
|
||||||
|
def expand_encoder_outputs_test():
|
||||||
|
model = FastPitch(FastPitchArgs(num_chars=10))
|
||||||
|
|
||||||
|
inputs = T.rand(2, 5, 57)
|
||||||
|
durations = T.randint(1, 4, (2, 57))
|
||||||
|
|
||||||
|
x_mask = T.ones(2, 1, 57)
|
||||||
|
y_mask = T.ones(2, 1, durations.sum(1).max())
|
||||||
|
|
||||||
|
expanded, _ = model.expand_encoder_outputs(inputs, durations, x_mask, y_mask)
|
||||||
|
|
||||||
|
for b in range(durations.shape[0]):
|
||||||
|
index = 0
|
||||||
|
for idx, dur in enumerate(durations[b]):
|
||||||
|
diff = (
|
||||||
|
expanded[b, :, index : index + dur.item()]
|
||||||
|
- inputs[b, :, idx].repeat(dur.item()).view(expanded[b, :, index : index + dur.item()].shape)
|
||||||
|
).sum()
|
||||||
|
assert abs(diff) < 1e-6, diff
|
||||||
|
index += dur
|
Loading…
Reference in New Issue