Merge pull request #793 from coqui-ai/dev

v0.2.2
This commit is contained in:
Eren Gölge 2021-09-06 19:25:34 +02:00 committed by GitHub
commit dc2ace3ca0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 1747 additions and 180 deletions

View File

@ -4,14 +4,14 @@
🐸TTS comes with [pretrained models](https://github.com/coqui-ai/TTS/wiki/Released-Models) and tools for measuring dataset quality, and it is already used in **20+ languages** for products and research projects.
[![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/main.yml/badge.svg)](https://github.com/coqui-ai/TTS/actions)
[![License](<https://img.shields.io/badge/License-MPL%202.0-brightgreen.svg>)](https://opensource.org/licenses/MPL-2.0)
[![Docs](<https://readthedocs.org/projects/tts/badge/?version=latest&style=plastic>)](https://tts.readthedocs.io/en/latest/)
[![PyPI version](https://badge.fury.io/py/TTS.svg)](https://badge.fury.io/py/TTS)
[![Covenant](https://camo.githubusercontent.com/7d620efaa3eac1c5b060ece5d6aacfcc8b81a74a04d05cd0398689c01c4463bb/68747470733a2f2f696d672e736869656c64732e696f2f62616467652f436f6e7472696275746f72253230436f76656e616e742d76322e3025323061646f707465642d6666363962342e737667)](https://github.com/coqui-ai/TTS/blob/master/CODE_OF_CONDUCT.md)
[![Downloads](https://pepy.tech/badge/tts)](https://pepy.tech/project/tts)
[![Gitter](https://badges.gitter.im/coqui-ai/TTS.svg)](https://gitter.im/coqui-ai/TTS?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
[![DOI](https://zenodo.org/badge/265612440.svg)](https://zenodo.org/badge/latestdoi/265612440)
📰 [**Subscribe to 🐸Coqui.ai Newsletter**](https://coqui.ai/?subscription=true)
@ -151,3 +151,5 @@ If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](ht
|- vocoder/ (Vocoder models.)
|- (same)
```
<img src="https://static.scarf.sh/a.png?x-pxid=503c242f-a253-4fb8-8071-ce1dc1e89999" />

View File

@ -62,7 +62,16 @@
"default_vocoder": null,
"commit": "3900448",
"author": "Eren Gölge @erogol",
"license": "",
"license": "TBD",
"contact": "egolge@coqui.com"
},
"fast_pitch": {
"description": "FastPitch model trained on LJSpeech using the Aligner Network",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.2/tts_models--en--ljspeech--fast_pitch.zip",
"default_vocoder": "vocoder_models/en/ljspeech/hifigan_v2",
"commit": "b27b3ba",
"author": "Eren Gölge @erogol",
"license": "TBD",
"contact": "egolge@coqui.com"
}
},
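
For reference, the new `fast_pitch` entry above can be resolved and downloaded from Python; a minimal sketch assuming the `ModelManager` helper and the packaged `.models.json` location (verify both against the installed release):

```python
# Sketch only: download the fast_pitch release listed above.
# Assumed CLI equivalent: tts --model_name tts_models/en/ljspeech/fast_pitch --text "Hello." --out_path out.wav
from pathlib import Path

import TTS
from TTS.utils.manage import ModelManager

models_json = Path(TTS.__file__).parent / ".models.json"  # packaged model catalog (path assumed)
manager = ModelManager(models_json)
model_path, config_path, model_item = manager.download_model("tts_models/en/ljspeech/fast_pitch")
print(model_path, config_path, model_item["default_vocoder"])
```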

View File

@ -1 +1 @@
0.2.1
0.2.2

View File

@ -8,12 +8,12 @@ import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.models import setup_model
from TTS.tts.utils.io import load_checkpoint
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
from TTS.utils.io import load_checkpoint
if __name__ == "__main__":
# pylint: disable=bad-option-value
@ -27,7 +27,7 @@ Example run:
CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py
--model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
--config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
--dataset_metafile /root/LJSpeech-1.1/metadata.csv
--dataset_metafile metadata.csv
--data_path /root/LJSpeech-1.1/
--batch_size 32
--dataset ljspeech
@ -76,8 +76,7 @@ Example run:
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
# TODO: handle multi-speaker
model = setup_model(C)
model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
model.eval()
model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True)
# data loader
preprocessor = importlib.import_module("TTS.tts.datasets.formatters")
@ -127,9 +126,9 @@ Example run:
mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda()
mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)
model_outputs = model.forward(text_input, text_lengths, mel_input)
alignments = alignments.detach()
alignments = model_outputs["alignments"].detach()
for idx, alignment in enumerate(alignments):
item_idx = item_idxs[idx]
# interpolate if r > 1
@ -149,10 +148,12 @@ Example run:
alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
# set file paths
wav_file_name = os.path.basename(item_idx)
align_file_name = os.path.splitext(wav_file_name)[0] + ".npy"
align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy"
file_path = item_idx.replace(wav_file_name, align_file_name)
# save output
file_paths.append([item_idx, file_path])
wav_file_abs_path = os.path.abspath(item_idx)
file_abs_path = os.path.abspath(file_path)
file_paths.append([wav_file_abs_path, file_abs_path])
np.save(file_path, alignment)
# output metafile

View File

@ -77,14 +77,14 @@ def set_filename(wav_path, out_path):
def format_data(data):
# setup input data
text_input = data[0]
text_lengths = data[1]
mel_input = data[4]
mel_lengths = data[5]
item_idx = data[7]
d_vectors = data[8]
speaker_ids = data[9]
attn_mask = data[10]
text_input = data['text']
text_lengths = data['text_lengths']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
item_idx = data['item_idxs']
d_vectors = data['d_vectors']
speaker_ids = data['speaker_ids']
attn_mask = data['attns']
avg_text_length = torch.mean(text_lengths.float())
avg_spec_length = torch.mean(mel_lengths.float())
@ -132,9 +132,8 @@ def inference(
speaker_c = speaker_ids
elif d_vectors is not None:
speaker_c = d_vectors
outputs = model.inference_with_MAS(
text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c}
text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}
)
model_output = outputs["model_outputs"]
model_output = model_output.transpose(1, 2).detach().cpu().numpy()

View File

@ -12,60 +12,89 @@ class BaseAudioConfig(Coqpit):
Args:
fft_size (int):
Number of STFT frequency levels, i.e. the size of the linear spectrogram frame. Defaults to 1024.
win_length (int):
Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match
```fft_size```. Defaults to 1024.
hop_length (int):
Number of audio samples between adjacent STFT columns. Defaults to 1024.
frame_shift_ms (int):
Set ```hop_length``` based on milliseconds and sampling rate.
frame_length_ms (int):
Set ```win_length``` based on milliseconds and sampling rate.
stft_pad_mode (str):
Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'.
sample_rate (int):
Audio sampling rate. Defaults to 22050.
resample (bool):
Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```.
preemphasis (float):
Preemphasis coefficient. Defaults to 0.0.
ref_level_db (int):
Reference dB level to rebase the audio signal and ignore levels below it. 20 dB is assumed to be the sound of air.
Defaults to 20.
do_sound_norm (bool):
Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.
log_func (str):
Numpy log function used for amplitude to DB conversion. Defaults to 'np.log10'.
do_trim_silence (bool):
Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
do_amp_to_db_linear (bool, optional):
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
do_amp_to_db_mel (bool, optional):
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
trim_db (int):
Silence threshold used for silence trimming. Defaults to 45.
power (float):
Exponent used for expanding spectrogram levels before running Griffin-Lim. It helps to reduce the
artifacts in the synthesized voice. Defaults to 1.5.
griffin_lim_iters (int):
Number of Griffin-Lim iterations. Defaults to 60.
num_mels (int):
Number of mel-basis filters, i.e. the dimension of each mel-spectrogram frame. Defaults to 80.
mel_fmin (float): Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices.
It needs to be adjusted for a dataset. Defaults to 0.
mel_fmax (float):
Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset.
spec_gain (int):
Gain applied when converting amplitude to DB. Defaults to 20.
signal_norm (bool):
enable/disable signal normalization. Defaults to True.
min_level_db (int):
minimum db threshold for the computed melspectrograms. Defaults to -100.
symmetric_norm (bool):
enable/disable symmetric normalization. If set to True, normalization is performed in the range [-k, k], else
[0, k]. Defaults to True.
max_norm (float):
```k``` defining the normalization range. Defaults to 4.0.
clip_norm (bool):
enable/disable clipping the out-of-range values in the normalized audio signal. Defaults to True.
stats_path (str):
Path to the computed stats file. Defaults to None.
"""
@ -147,15 +176,20 @@ class BaseDatasetConfig(Coqpit):
Args:
name (str):
Dataset name that defines the preprocessor in use. Defaults to None.
path (str):
Root path to the dataset files. Defaults to None.
meta_file_train (str):
Name of the dataset meta file, or a list of speakers to be ignored during training for multi-speaker datasets.
Defaults to None.
unused_speakers (List):
List of speaker IDs that are not used during training. Defaults to None.
meta_file_val (str):
Name of the dataset meta file that defines the instances used at validation.
meta_file_attn_mask (str):
Path to the file that lists the attention mask files used with models that require attention masks to
train the duration predictor.
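
A short sketch of a dataset entry using the fields documented above (paths are placeholders and the import path is assumed):

```python
# Sketch: one dataset entry for config.datasets.
from TTS.config.shared_configs import BaseDatasetConfig  # import path assumed

dataset_config = BaseDatasetConfig(
    name="ljspeech",                 # selects the matching formatter/preprocessor
    path="/data/LJSpeech-1.1/",      # root path of the dataset files
    meta_file_train="metadata.csv",  # training meta file
    meta_file_val=None,              # leave empty to auto-split an eval set
    meta_file_attn_mask=None,        # only needed for duration-predictor training
)
```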
@ -298,7 +332,7 @@ class BaseTrainingConfig(Coqpit):
keep_all_best: bool = False
keep_after: int = 10000
# dataloading
num_loader_workers: int = None
num_loader_workers: int = 0
num_eval_loader_workers: int = 0
use_noise_augment: bool = False
# paths

View File

@ -271,6 +271,14 @@ class Trainer:
# setup scheduler
self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer)
if self.scheduler is not None:
if self.args.continue_path:
if isinstance(self.scheduler, list):
for scheduler in self.scheduler:
scheduler.last_epoch = self.restore_step
else:
self.scheduler.last_epoch = self.restore_step
# DISTRUBUTED
if self.num_gpus > 1:
self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)
@ -291,7 +299,6 @@ class Trainer:
Returns:
nn.Module: initialized model.
"""
# TODO: better model setup
try:
model = setup_vocoder_model(config)
except ModuleNotFoundError:

View File

@ -0,0 +1,122 @@
from dataclasses import dataclass, field
from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.fast_pitch import FastPitchArgs
@dataclass
class FastPitchConfig(BaseTTSConfig):
"""Defines parameters for Speedy Speech (feed-forward encoder-decoder) based models.
Example:
>>> from TTS.tts.configs import FastPitchConfig
>>> config = FastPitchConfig()
Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
model_args (Coqpit):
Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`.
data_dep_init_steps (int):
Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
Activation Normalization that pre-computes normalization stats at the beginning and uses the same values
for the rest of the training. Defaults to 10.
use_speaker_embedding (bool):
enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
in the multi-speaker mode. Defaults to False.
use_d_vector_file (bool):
enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
d_vector_file (str):
Path to the file including pre-computed speaker embeddings. Defaults to None.
noam_schedule (bool):
enable / disable the use of Noam LR scheduler. Defaults to False.
warmup_steps (int):
Number of warm-up steps for the Noam scheduler. Defaults 4000.
lr (float):
Initial learning rate. Defaults to `1e-3`.
wd (float):
Weight decay coefficient. Defaults to `1e-7`.
ssim_loss_alpha (float):
Weight for the SSIM loss. If set to 0, disables the SSIM loss. Defaults to 1.0.
dur_loss_alpha (float):
Weight for the duration predictor's loss. If set to 0, disables the duration loss. Defaults to 1.0.
spec_loss_alpha (float):
Weight for the L1 spectrogram loss. If set to 0, disables the L1 loss. Defaults to 1.0.
pitch_loss_alpha (float):
Weight for the pitch predictor's loss. If set to 0, disables the pitch predictor. Defaults to 1.0.
binary_align_loss_alpha (float):
Weight for the binary alignment loss. If set to 0, disables the binary loss. Defaults to 1.0.
binary_align_loss_start_step (int):
Start binary alignment loss after this many steps. Defaults to 20000.
min_seq_len (int):
Minimum input sequence length to be used at training.
max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
"""
model: str = "fast_pitch"
# model specific params
model_args: FastPitchArgs = field(default_factory=FastPitchArgs)
# multi-speaker settings
use_speaker_embedding: bool = False
use_d_vector_file: bool = False
d_vector_file: str = None
d_vector_dim: int = 0
# optimizer parameters
optimizer: str = "Adam"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
lr_scheduler: str = "NoamLR"
lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
lr: float = 1e-4
grad_clip: float = 5.0
# loss params
ssim_loss_alpha: float = 1.0
dur_loss_alpha: float = 1.0
spec_loss_alpha: float = 1.0
pitch_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 1.0
binary_align_loss_start_step: int = 20000
# overrides
min_seq_len: int = 13
max_seq_len: int = 200
r: int = 1 # DO NOT CHANGE
# dataset configs
compute_f0: bool = True
f0_cache_path: str = None
# testing
test_sentences: List[str] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963.",
]
)
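
A minimal sketch of instantiating the new config and overriding a few of the documented fields (the import follows the docstring example above):

```python
# Sketch: build a FastPitch training config with F0 caching enabled.
from TTS.tts.configs import FastPitchConfig

config = FastPitchConfig(
    compute_f0=True,
    f0_cache_path="f0_cache",            # where PitchExtractor stores *_pitch.npy files
    binary_align_loss_start_step=20000,  # start the binary alignment loss after 20k steps
    lr=1e-4,
)
print(config.model, config.model_args.use_aligner)
```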

View File

@ -103,6 +103,7 @@ class BaseTTSConfig(BaseTrainingConfig):
"""Shared parameters among all the tts models.
Args:
audio (BaseAudioConfig):
Audio processor config object instance.
@ -140,11 +141,14 @@ class BaseTTSConfig(BaseTrainingConfig):
loss_masking (bool):
enable / disable masking loss values against padded segments of samples in a batch.
sort_by_audio_len (bool):
If true, the dataloader sorts the data by audio length, else it sorts by the input text length. Defaults to `False`.
min_seq_len (int):
Minimum input sequence length to be used at training.
Minimum sequence length to be used at training.
max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
Maximum sequence length to be used at training. Larger values result in more VRAM usage.
compute_f0 (int):
(Not in use yet).
@ -197,6 +201,7 @@ class BaseTTSConfig(BaseTrainingConfig):
batch_group_size: int = 0
loss_masking: bool = None
# dataloading
sort_by_audio_len: bool = False
min_seq_len: int = 1
max_seq_len: int = float("inf")
compute_f0: bool = False

View File

@ -67,11 +67,14 @@ class VitsConfig(BaseTTSConfig):
compute_linear_spec (bool):
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
sort_by_audio_len (bool):
If true, the dataloader sorts the data by audio length, else it sorts by the input text length. Defaults to `True`.
min_seq_len (int):
Minimum text length to be considered for training. Defaults to `13`.
Minimum sequence length to be considered for training. Defaults to `0`.
max_seq_len (int):
Maximum text length to be considered for training. Defaults to `500`.
Maximum sequence length to be considered for training. Defaults to `500000`.
r (int):
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.

View File

@ -22,6 +22,8 @@ class TTSDataset(Dataset):
compute_linear_spec: bool,
ap: AudioProcessor,
meta_data: List[List],
compute_f0: bool = False,
f0_cache_path: str = None,
characters: Dict = None,
custom_symbols: List = None,
add_blank: bool = False,
@ -40,8 +42,7 @@ class TTSDataset(Dataset):
):
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
If you need something different, you can either override or create a new class as the dataset is
initialized by the model.
If you need something different, you can inherit and override.
Args:
outputs_per_step (int): Number of time frames predicted per step.
@ -54,6 +55,10 @@ class TTSDataset(Dataset):
meta_data (list): List of dataset instances.
compute_f0 (bool): compute f0 if True. Defaults to False.
f0_cache_path (str): Path to store f0 cache. Defaults to None.
characters (dict): `dict` of custom text characters used for converting texts to sequences.
custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
@ -78,8 +83,8 @@ class TTSDataset(Dataset):
use_phonemes (bool): If true, input text converted to phonemes. Defaults to false.
phoneme_cache_path (str): Path to cache phoneme features. It writes computed phonemes to files to use in
the coming iterations. Defaults to None.
phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
separate file. Defaults to None.
phoneme_language (str): One of the languages supported by the phonemizer interface. Defaults to `en-us`.
@ -103,6 +108,8 @@ class TTSDataset(Dataset):
self.cleaners = text_cleaner
self.compute_linear_spec = compute_linear_spec
self.return_wav = return_wav
self.compute_f0 = compute_f0
self.f0_cache_path = f0_cache_path
self.min_seq_len = min_seq_len
self.max_seq_len = max_seq_len
self.ap = ap
@ -119,8 +126,12 @@ class TTSDataset(Dataset):
self.verbose = verbose
self.input_seq_computed = False
self.rescue_item_idx = 1
self.pitch_computed = False
if use_phonemes and not os.path.isdir(phoneme_cache_path):
os.makedirs(phoneme_cache_path, exist_ok=True)
if compute_f0:
self.pitch_extractor = PitchExtractor(self.items, verbose=verbose)
if self.verbose:
print("\n > DataLoader initialization")
print(" | > Use phonemes: {}".format(self.use_phonemes))
@ -236,10 +247,16 @@ class TTSDataset(Dataset):
# TODO: find a better fix
return self.load_data(self.rescue_item_idx)
pitch = None
if self.compute_f0:
pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))
sample = {
"raw_text": raw_text,
"text": text,
"wav": wav,
"pitch": pitch,
"attn": attn,
"item_idx": self.items[idx][1],
"speaker_name": speaker_name,
@ -256,8 +273,8 @@ class TTSDataset(Dataset):
return phonemes
def compute_input_seq(self, num_workers=0):
"""compute input sequences separately. Call it before
passing dataset to data loader."""
"""Compute the input sequences with multi-processing.
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
if not self.use_phonemes:
if self.verbose:
print(" | > Computing input sequences ...")
@ -339,6 +356,11 @@ class TTSDataset(Dataset):
temp_items = new_items[offset:end_offset]
random.shuffle(temp_items)
new_items[offset:end_offset] = temp_items
if len(new_items) == 0:
raise RuntimeError(" [!] No items left after filtering.")
# update items to the new sorted items
self.items = new_items
# logging
@ -359,11 +381,23 @@ class TTSDataset(Dataset):
def __getitem__(self, idx):
return self.load_data(idx)
@staticmethod
def _sort_batch(batch, text_lengths):
"""Sort the batch by the input text length for RNN efficiency.
Args:
batch (Dict): Batch returned by `__getitem__`.
text_lengths (List[int]): Lengths of the input character sequences.
"""
text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
batch = [batch[idx] for idx in ids_sorted_decreasing]
return batch, text_lengths, ids_sorted_decreasing
def collate_fn(self, batch):
r"""
Perform preprocessing and create a final data batch:
1. Sort batch instances by text-length
2. Convert Audio signal to Spectrograms.
2. Convert Audio signal to features.
3. PAD sequences wrt r.
4. Load to Torch.
"""
@ -371,30 +405,27 @@ class TTSDataset(Dataset):
# Puts each data field into a tensor with outer dimension batch size
if isinstance(batch[0], collections.abc.Mapping):
text_lenghts = np.array([len(d["text"]) for d in batch])
text_lengths = np.array([len(d["text"]) for d in batch])
# sort items with text input length for RNN efficiency
text_lenghts, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lenghts), dim=0, descending=True)
batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths)
wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing]
# convert list of dicts to dict of lists
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
# get pre-computed d-vectors
if self.d_vector_mapping is not None:
wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing]
wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing]
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
else:
d_vectors = None
# get numerical speaker ids from speaker names
if self.speaker_id_mapping:
speaker_ids = [self.speaker_id_mapping[sn] for sn in speaker_names]
speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
else:
speaker_ids = None
# compute features
mel = [self.ap.melspectrogram(w).astype("float32") for w in wav]
mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]
mel_lengths = [m.shape[1] for m in mel]
@ -413,7 +444,7 @@ class TTSDataset(Dataset):
stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
# PAD sequences with longest instance in the batch
text = prepare_data(text).astype(np.int32)
text = prepare_data(batch["text"]).astype(np.int32)
# PAD features with longest instance
mel = prepare_tensor(mel, self.outputs_per_step)
@ -422,7 +453,7 @@ class TTSDataset(Dataset):
mel = mel.transpose(0, 2, 1)
# convert things to pytorch
text_lenghts = torch.LongTensor(text_lenghts)
text_lengths = torch.LongTensor(text_lengths)
text = torch.LongTensor(text)
mel = torch.FloatTensor(mel).contiguous()
mel_lengths = torch.LongTensor(mel_lengths)
@ -436,7 +467,7 @@ class TTSDataset(Dataset):
# compute linear spectrogram
if self.compute_linear_spec:
linear = [self.ap.spectrogram(w).astype("float32") for w in wav]
linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
linear = prepare_tensor(linear, self.outputs_per_step)
linear = linear.transpose(0, 2, 1)
assert mel.shape[1] == linear.shape[1]
@ -447,23 +478,33 @@ class TTSDataset(Dataset):
# format waveforms
wav_padded = None
if self.return_wav:
wav_lengths = [w.shape[0] for w in wav]
wav_lengths = [w.shape[0] for w in batch["wav"]]
max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
wav_lengths = torch.LongTensor(wav_lengths)
wav_padded = torch.zeros(len(batch), 1, max_wav_len)
for i, w in enumerate(wav):
wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
for i, w in enumerate(batch["wav"]):
mel_length = mel_lengths_adjusted[i]
w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
w = w[: mel_length * self.ap.hop_length]
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
wav_padded.transpose_(1, 2)
# compute f0
# TODO: compare perf in collate_fn vs in load_data
if self.compute_f0:
pitch = prepare_data(batch["pitch"])
assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 x T
else:
pitch = None
# collate attention alignments
if batch[0]["attn"] is not None:
attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
if batch["attn"][0] is not None:
attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
for idx, attn in enumerate(attns):
pad2 = mel.shape[1] - attn.shape[1]
pad1 = text.shape[1] - attn.shape[0]
assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
attn = np.pad(attn, [[0, pad1], [0, pad2]])
attns[idx] = attn
attns = prepare_tensor(attns, self.outputs_per_step)
@ -471,21 +512,22 @@ class TTSDataset(Dataset):
else:
attns = None
# TODO: return dictionary
return (
text,
text_lenghts,
speaker_names,
linear,
mel,
mel_lengths,
stop_targets,
item_idxs,
d_vectors,
speaker_ids,
attns,
wav_padded,
raw_text,
)
return {
"text": text,
"text_lengths": text_lengths,
"speaker_names": batch["speaker_name"],
"linear": linear,
"mel": mel,
"mel_lengths": mel_lengths,
"stop_targets": stop_targets,
"item_idxs": batch["item_idx"],
"d_vectors": d_vectors,
"speaker_ids": speaker_ids,
"attns": attns,
"waveform": wav_padded,
"raw_text": batch["raw_text"],
"pitch": pitch,
}
raise TypeError(
(
@ -495,3 +537,110 @@ class TTSDataset(Dataset):
)
)
)
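
Since `collate_fn` now returns a dictionary instead of a positional tuple, downstream code reads batches by key. A hedged sketch of consuming one batch (keys taken from the return statement above; `loader` is assumed to be a `torch.utils.data.DataLoader` built over a `TTSDataset`):

```python
# Sketch: reading the dict-style batch produced by TTSDataset.collate_fn.
for batch in loader:
    text = batch["text"]                  # [B, T_in] padded character/phoneme IDs
    text_lengths = batch["text_lengths"]  # [B]
    mel = batch["mel"]                    # [B, T_out, num_mels]
    mel_lengths = batch["mel_lengths"]    # [B]
    pitch = batch["pitch"]                # [B, 1, T_out] or None when compute_f0 is False
    d_vectors = batch["d_vectors"]        # list of speaker embeddings or None
    break
```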
class PitchExtractor:
"""Pitch Extractor for computing F0 from wav files.
Args:
items (List[List]): Dataset samples.
verbose (bool): Whether to print the progress.
"""
def __init__(
self,
items: List[List],
verbose=False,
):
self.items = items
self.verbose = verbose
self.mean = None
self.std = None
@staticmethod
def create_pitch_file_path(wav_file, cache_path):
file_name = os.path.splitext(os.path.basename(wav_file))[0]
pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
return pitch_file
@staticmethod
def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
wav = ap.load_wav(wav_file)
pitch = ap.compute_f0(wav)
if pitch_file:
np.save(pitch_file, pitch)
return pitch
@staticmethod
def compute_pitch_stats(pitch_vecs):
nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs])
mean, std = np.mean(nonzeros), np.std(nonzeros)
return mean, std
def normalize_pitch(self, pitch):
zero_idxs = np.where(pitch == 0.0)[0]
pitch = pitch - self.mean
pitch = pitch / self.std
pitch[zero_idxs] = 0.0
return pitch
def denormalize_pitch(self, pitch):
zero_idxs = np.where(pitch == 0.0)[0]
pitch *= self.std
pitch += self.mean
pitch[zero_idxs] = 0.0
return pitch
@staticmethod
def load_or_compute_pitch(ap, wav_file, cache_path):
"""
compute pitch and return a numpy array of pitch values
"""
pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
if not os.path.exists(pitch_file):
pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
else:
pitch = np.load(pitch_file)
return pitch.astype(np.float32)
@staticmethod
def _pitch_worker(args):
item = args[0]
ap = args[1]
cache_path = args[2]
_, wav_file, *_ = item
pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
if not os.path.exists(pitch_file):
pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
return pitch
return None
def compute_pitch(self, ap, cache_path, num_workers=0):
"""Compute the input sequences with multi-processing.
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
if not os.path.exists(cache_path):
os.makedirs(cache_path, exist_ok=True)
if self.verbose:
print(" | > Computing pitch features ...")
if num_workers == 0:
pitch_vecs = []
for _, item in enumerate(tqdm.tqdm(self.items)):
pitch_vecs += [self._pitch_worker([item, ap, cache_path])]
else:
with Pool(num_workers) as p:
pitch_vecs = list(
tqdm.tqdm(
p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]),
total=len(self.items),
)
)
pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs)
pitch_stats = {"mean": pitch_mean, "std": pitch_std}
np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
def load_pitch_stats(self, cache_path):
stats_path = os.path.join(cache_path, "pitch_stats.npy")
stats = np.load(stats_path, allow_pickle=True).item()
self.mean = stats["mean"].astype(np.float32)
self.std = stats["std"].astype(np.float32)
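
A usage sketch for the new `PitchExtractor` (the `items` list and `config` object are assumed to come from `load_meta_data` and a training config respectively; paths are placeholders):

```python
# Sketch: pre-compute and cache F0 features before training.
from TTS.tts.datasets.TTSDataset import PitchExtractor
from TTS.utils.audio import AudioProcessor

ap = AudioProcessor(**config.audio)              # audio settings assumed from a BaseAudioConfig
extractor = PitchExtractor(items, verbose=True)  # items: [[text, wav_path, speaker], ...]
extractor.compute_pitch(ap, "f0_cache", num_workers=4)  # writes *_pitch.npy and pitch_stats.npy
extractor.load_pitch_stats("f0_cache")                   # sets extractor.mean / extractor.std
pitch = PitchExtractor.load_or_compute_pitch(ap, items[0][1], "f0_cache")
pitch = extractor.normalize_pitch(pitch.astype("float32"))
```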

View File

@ -1,6 +1,7 @@
import sys
from collections import Counter
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
@ -30,7 +31,17 @@ def split_dataset(items):
return items[:eval_split_size], items[eval_split_size:]
def load_meta_data(datasets, eval_split=True):
def load_meta_data(datasets: List[Dict], eval_split=True) -> Tuple[List[List], List[List]]:
"""Parse the dataset, load the samples as a list and load the attention alignments if provided.
Args:
datasets (List[Dict]): A list of dataset dictionaries or dataset configs.
eval_split (bool, optional): If true, create an evaluation split. If an eval split is not provided explicitly, it is
generated automatically from the training data. Defaults to True.
Returns:
Tuple[List[List], List[List]]: training and evaluation splits of the dataset.
"""
meta_data_train_all = []
meta_data_eval_all = [] if eval_split else None
for dataset in datasets:
@ -51,7 +62,7 @@ def load_meta_data(datasets, eval_split=True):
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
meta_data_eval_all += meta_data_eval
meta_data_train_all += meta_data_train
# load attention masks for duration predictor training
# load attention masks for the duration predictor training
if dataset.meta_file_attn_mask:
meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
for idx, ins in enumerate(meta_data_train_all):
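
A short usage sketch for `load_meta_data` (import path assumed; `config.datasets` is a list of dataset configs like the one sketched earlier):

```python
# Sketch: build train/eval sample lists from the dataset configs.
from TTS.tts.datasets import load_meta_data  # import path assumed

meta_data_train, meta_data_eval = load_meta_data(config.datasets, eval_split=True)
print(len(meta_data_train), len(meta_data_eval))
print(meta_data_train[0])  # typically [text, wav_file_path, speaker_name]
```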

View File

@ -0,0 +1,81 @@
from typing import Tuple
import torch
from torch import nn
class AlignmentNetwork(torch.nn.Module):
"""Aligner Network for learning alignment between the input text and the model output with Gaussian Attention.
::
query -> conv1d -> relu -> conv1d -> relu -> conv1d -> L2_dist -> softmax -> alignment
key -> conv1d -> relu -> conv1d -----------------------^
Args:
in_query_channels (int): Number of channels in the query network. Defaults to 80.
in_key_channels (int): Number of channels in the key network. Defaults to 512.
attn_channels (int): Number of inner channels in the attention layers. Defaults to 80.
temperature (float): Temperature for the softmax. Defaults to 0.0005.
"""
def __init__(
self,
in_query_channels=80,
in_key_channels=512,
attn_channels=80,
temperature=0.0005,
):
super().__init__()
self.temperature = temperature
self.softmax = torch.nn.Softmax(dim=3)
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.key_layer = nn.Sequential(
nn.Conv1d(
in_key_channels,
in_key_channels * 2,
kernel_size=3,
padding=1,
bias=True,
),
torch.nn.ReLU(),
nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True),
)
self.query_layer = nn.Sequential(
nn.Conv1d(
in_query_channels,
in_query_channels * 2,
kernel_size=3,
padding=1,
bias=True,
),
torch.nn.ReLU(),
nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True),
torch.nn.ReLU(),
nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True),
)
def forward(
self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None
) -> Tuple[torch.tensor, torch.tensor]:
"""Forward pass of the aligner encoder.
Shapes:
- queries: :math:`[B, C, T_de]`
- keys: :math:`[B, C_emb, T_en]`
- mask: :math:`[B, T_de]`
Output:
attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask.
attn_logp (torch.tensor): :math:`[B, 1, T_en, T_de]` log probabilities.
"""
key_out = self.key_layer(keys)
query_out = self.query_layer(queries)
attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2
attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True)
if attn_prior is not None:
attn_logp = self.log_softmax(attn_logp) + torch.log(attn_prior[:, None] + 1e-8)
if mask is not None:
attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
attn = self.softmax(attn_logp)
return attn, attn_logp
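
A small shape-check sketch for the new `AlignmentNetwork` (tensor sizes are arbitrary; the mask is passed the same way FastPitch passes its text mask later in this commit):

```python
# Sketch: run the aligner on random tensors to inspect the output shapes.
import torch
from TTS.tts.layers.generic.aligner import AlignmentNetwork

aligner = AlignmentNetwork(in_query_channels=80, in_key_channels=512)
queries = torch.randn(2, 80, 120)   # e.g. mel frames          [B, C, T_de]
keys = torch.randn(2, 512, 33)      # e.g. character embeddings [B, C_emb, T_en]
x_mask = torch.ones(2, 1, 33)       # key (text) mask, as used by FastPitch._forward_aligner
attn, attn_logp = aligner(queries, keys, mask=x_mask, attn_prior=None)
print(attn.shape, attn_logp.shape)  # soft attention and its log-probabilities
```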

View File

@ -7,17 +7,23 @@ from torch import nn
class PositionalEncoding(nn.Module):
"""Sinusoidal positional encoding for non-recurrent neural networks.
Implementation based on "Attention Is All You Need"
Args:
channels (int): embedding size
dropout (float): dropout parameter
dropout_p (float): dropout rate applied to the output.
max_len (int): maximum sequence length.
use_scale (bool): whether to use a learnable scaling coefficient.
"""
def __init__(self, channels, dropout_p=0.0, max_len=5000):
def __init__(self, channels, dropout_p=0.0, max_len=5000, use_scale=False):
super().__init__()
if channels % 2 != 0:
raise ValueError(
"Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels)
)
self.use_scale = use_scale
if use_scale:
self.scale = torch.nn.Parameter(torch.ones(1))
pe = torch.zeros(max_len, channels)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels)
@ -49,9 +55,15 @@ class PositionalEncoding(nn.Module):
pos_enc = self.pe[:, :, : x.size(2)] * mask
else:
pos_enc = self.pe[:, :, : x.size(2)]
x = x + pos_enc
if self.use_scale:
x = x + self.scale * pos_enc
else:
x = x + pos_enc
else:
x = x + self.pe[:, :, first_idx:last_idx]
if self.use_scale:
x = x + self.scale * self.pe[:, :, first_idx:last_idx]
else:
x = x + self.pe[:, :, first_idx:last_idx]
if hasattr(self, "dropout"):
x = self.dropout(x)
return x
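
A quick sketch of the new `use_scale` option added above (input layout `[B, C, T]` as used elsewhere in this commit):

```python
# Sketch: positional encoding with a learnable scale coefficient.
import torch
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding

pos_encoder = PositionalEncoding(channels=384, dropout_p=0.1, use_scale=True)
x = torch.randn(2, 384, 50)               # [B, C, T]
y = pos_encoder(x)                        # x + scale * positional encoding
print(y.shape, pos_encoder.scale.item())  # scale is initialized to 1.0
```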

View File

@ -15,17 +15,19 @@ class FFTransformer(nn.Module):
self.norm1 = nn.LayerNorm(in_out_channels)
self.norm2 = nn.LayerNorm(in_out_channels)
self.dropout = nn.Dropout(dropout_p)
self.dropout1 = nn.Dropout(dropout_p)
self.dropout2 = nn.Dropout(dropout_p)
def forward(self, src, src_mask=None, src_key_padding_mask=None):
"""😦 ugly looking with all the transposing"""
src = src.permute(2, 0, 1)
src2, enc_align = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
src = src + self.dropout1(src2)
src = self.norm1(src + src2)
# T x B x D -> B x D x T
src = src.permute(1, 2, 0)
src2 = self.conv2(F.relu(self.conv1(src)))
src2 = self.dropout(src2)
src2 = self.dropout2(src2)
src = src + src2
src = src.transpose(1, 2)
src = self.norm2(src)
@ -52,8 +54,8 @@ class FFTransformerBlock(nn.Module):
"""
TODO: handle multi-speaker
Shapes:
x: [B, C, T]
mask: [B, 1, T] or [B, T]
- x: :math:`[B, C, T]`
- mask: :math:`[B, 1, T] or [B, T]`
"""
if mask is not None and mask.ndim == 3:
mask = mask.squeeze(1)
@ -65,3 +67,21 @@ class FFTransformerBlock(nn.Module):
alignments.append(align.unsqueeze(1))
alignments = torch.cat(alignments, 1)
return x
class FFTDurationPredictor(nn.Module):
def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None):  # pylint: disable=unused-argument
super().__init__()
self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p)
self.proj = nn.Linear(in_channels, 1)
def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument
"""
Shapes:
- x: :math:`[B, C, T]`
- mask: :math:`[B, 1, T]`
TODO: Handle the cond input
"""
x = self.fft(x, mask=mask)
x = self.proj(x)
return x

View File

@ -21,8 +21,10 @@ def convert_pad_shape(pad_shape):
def generate_path(duration, mask):
"""
duration: [b, t_x]
mask: [b, t_x, t_y]
Shapes:
- duration: :math:`[B, T_en]`
- mask: :math:`[B, T_en, T_de]`
- path: :math:`[B, T_en, T_de]`
"""
device = duration.device
b, t_x, t_y = mask.shape
@ -45,8 +47,9 @@ def maximum_path(value, mask):
def maximum_path_cython(value, mask):
"""Cython optimised version.
value: [b, t_x, t_y]
mask: [b, t_x, t_y]
Shapes:
- value: :math:`[B, T_en, T_de]`
- mask: :math:`[B, T_en, T_de]`
"""
value = value * mask
device = value.device

View File

@ -69,9 +69,9 @@ class MSELossMasked(nn.Module):
length: A Variable containing a LongTensor of size (batch,)
which contains the length of each data in a batch.
Shapes:
x: B x T X D
target: B x T x D
length: B
- x: :math:`[B, T, D]`
- target: :math:`[B, T, D]`
- length: :math:`B`
Returns:
loss: An average loss value in range [0, 1] masked by the length.
"""
@ -658,3 +658,109 @@ class VitsDiscriminatorLoss(nn.Module):
loss = loss + return_dict["loss_disc"]
return_dict["loss"] = loss
return return_dict
class ForwardSumLoss(nn.Module):
def __init__(self, blank_logprob=-1):
super().__init__()
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
self.blank_logprob = blank_logprob
def forward(self, attn_logprob, in_lens, out_lens):
key_lens = in_lens
query_lens = out_lens
attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob)
total_loss = 0.0
for bid in range(attn_logprob.shape[0]):
target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1]
curr_logprob = self.log_softmax(curr_logprob[None])[0]
loss = self.ctc_loss(
curr_logprob,
target_seq,
input_lengths=query_lens[bid : bid + 1],
target_lengths=key_lens[bid : bid + 1],
)
total_loss = total_loss + loss
total_loss = total_loss / attn_logprob.shape[0]
return total_loss
class FastPitchLoss(nn.Module):
def __init__(self, c):
super().__init__()
self.spec_loss = MSELossMasked(False)
self.ssim = SSIMLoss()
self.dur_loss = MSELossMasked(False)
self.pitch_loss = MSELossMasked(False)
if c.model_args.use_aligner:
self.aligner_loss = ForwardSumLoss()
self.spec_loss_alpha = c.spec_loss_alpha
self.ssim_loss_alpha = c.ssim_loss_alpha
self.dur_loss_alpha = c.dur_loss_alpha
self.pitch_loss_alpha = c.pitch_loss_alpha
self.aligner_loss_alpha = c.aligner_loss_alpha
self.binary_alignment_loss_alpha = c.binary_align_loss_alpha
@staticmethod
def _binary_alignment_loss(alignment_hard, alignment_soft):
"""Binary loss that forces soft alignments to match the hard alignments as
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
"""
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
return -log_sum / alignment_hard.sum()
def forward(
self,
decoder_output,
decoder_target,
decoder_output_lens,
dur_output,
dur_target,
pitch_output,
pitch_target,
input_lens,
alignment_logprob=None,
alignment_hard=None,
alignment_soft=None,
):
loss = 0
return_dict = {}
if self.ssim_loss_alpha > 0:
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
loss = loss + self.ssim_loss_alpha * ssim_loss
return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss
if self.spec_loss_alpha > 0:
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
loss = loss + self.spec_loss_alpha * spec_loss
return_dict["loss_spec"] = self.spec_loss_alpha * spec_loss
if self.dur_loss_alpha > 0:
log_dur_tgt = torch.log(dur_target.float() + 1)
dur_loss = self.dur_loss(dur_output[:, :, None], log_dur_tgt[:, :, None], input_lens)
loss = loss + self.dur_loss_alpha * dur_loss
return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss
if self.pitch_loss_alpha > 0:
pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens)
loss = loss + self.pitch_loss_alpha * pitch_loss
return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss
if self.aligner_loss_alpha > 0:
aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
loss = loss + self.aligner_loss_alpha * aligner_loss
return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
return_dict["loss"] = loss
return return_dict
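
To make the binary alignment term above concrete, a toy check (standalone, not library code) of `_binary_alignment_loss`: it is the negative mean log of the soft-attention mass that falls on the hard MAS alignment.

```python
# Toy check: soft probabilities gathered at hard-alignment positions, averaged in log space.
import torch

alignment_hard = torch.tensor([[[1.0, 0.0, 0.0],
                                [0.0, 1.0, 0.0]]])   # [B, T_en, T_de], one-hot rows from MAS
alignment_soft = torch.tensor([[[0.7, 0.2, 0.1],
                                [0.3, 0.6, 0.1]]])   # soft attention, each row sums to 1
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
loss = -log_sum / alignment_hard.sum()
print(loss)  # -(log 0.7 + log 0.6) / 2 ≈ 0.434
```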

View File

@ -13,7 +13,6 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
@ -355,9 +354,6 @@ class AlignTTS(BaseTTS):
phase=self.phase,
)
# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
loss_dict["align_error"] = align_error
return outputs, loss_dict
def train_log(

View File

@ -78,7 +78,9 @@ class BaseTacotron(BaseTTS):
@staticmethod
def _format_aux_input(aux_input: Dict) -> Dict:
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
if aux_input:
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
return None
#############################
# INIT FUNCTIONS

View File

@ -104,18 +104,19 @@ class BaseTTS(BaseModel):
Dict: [description]
"""
# setup input batch
text_input = batch[0]
text_lengths = batch[1]
speaker_names = batch[2]
linear_input = batch[3]
mel_input = batch[4]
mel_lengths = batch[5]
stop_targets = batch[6]
item_idx = batch[7]
d_vectors = batch[8]
speaker_ids = batch[9]
attn_mask = batch[10]
waveform = batch[11]
text_input = batch["text"]
text_lengths = batch["text_lengths"]
speaker_names = batch["speaker_names"]
linear_input = batch["linear"]
mel_input = batch["mel"]
mel_lengths = batch["mel_lengths"]
stop_targets = batch["stop_targets"]
item_idx = batch["item_idxs"]
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
attn_mask = batch["attns"]
waveform = batch["waveform"]
pitch = batch["pitch"]
max_text_length = torch.max(text_lengths.float())
max_spec_length = torch.max(mel_lengths.float())
@ -162,6 +163,7 @@ class BaseTTS(BaseModel):
"max_spec_length": float(max_spec_length),
"item_idx": item_idx,
"waveform": waveform,
"pitch": pitch,
}
def get_data_loader(
@ -199,6 +201,8 @@ class BaseTTS(BaseModel):
outputs_per_step=config.r if "r" in config else 1,
text_cleaner=config.text_cleaner,
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
compute_f0=config.get("compute_f0", False),
f0_cache_path=config.get("f0_cache_path", None),
meta_data=data_items,
ap=ap,
characters=config.characters,
@ -245,6 +249,21 @@ class BaseTTS(BaseModel):
# sort input sequences from short to long
dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
# compute pitch frames and write to files.
if config.compute_f0 and rank in [None, 0]:
if not os.path.exists(config.f0_cache_path):
dataset.pitch_extractor.compute_pitch(
ap, config.get("f0_cache_path", None), config.num_loader_workers
)
# halt DDP processes for the main process to finish computing the F0 cache
if num_gpus > 1:
dist.barrier()
# load pitch stats computed above by all the workers
if config.compute_f0:
dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None

View File

@ -0,0 +1,697 @@
from dataclasses import dataclass, field
from typing import Dict, Tuple
import torch
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.generic.aligner import AlignmentNetwork
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
from TTS.utils.audio import AudioProcessor
@dataclass
class FastPitchArgs(Coqpit):
"""Fast Pitch Model arguments.
Args:
num_chars (int):
Number of characters in the vocabulary. Defaults to 100.
out_channels (int):
Number of output channels. Defaults to 80.
hidden_channels (int):
Number of base hidden channels of the model. Defaults to 512.
num_speakers (int):
Number of speakers for the speaker embedding layer. Defaults to 0.
duration_predictor_hidden_channels (int):
Number of hidden channels in the duration predictor. Defaults to 256.
duration_predictor_dropout_p (float):
Dropout rate for the duration predictor. Defaults to 0.1.
duration_predictor_kernel_size (int):
Kernel size of conv layers in the duration predictor. Defaults to 3.
pitch_predictor_hidden_channels (int):
Number of hidden channels in the pitch predictor. Defaults to 256.
pitch_predictor_dropout_p (float):
Dropout rate for the pitch predictor. Defaults to 0.1.
pitch_predictor_kernel_size (int):
Kernel size of conv layers in the pitch predictor. Defaults to 3.
pitch_embedding_kernel_size (int):
Kernel size of the projection layer in the pitch predictor. Defaults to 3.
positional_encoding (bool):
Whether to use positional encoding. Defaults to True.
positional_encoding_use_scale (bool):
Whether to use a learnable scale coeff in the positional encoding. Defaults to True.
length_scale (int):
Length scale that multiplies the predicted durations. Larger values result in slower speech. Defaults to 1.0.
encoder_type (str):
Type of the encoder module. One of the encoders available in :class:`TTS.tts.layers.feed_forward.encoder`.
Defaults to `fftransformer` as in the paper.
encoder_params (dict):
Parameters of the encoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}```
decoder_type (str):
Type of the decoder module. One of the decoders available in :class:`TTS.tts.layers.feed_forward.decoder`.
Defaults to `fftransformer` as in the paper.
decoder_params (str):
Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}```
use_d_vector (bool):
Whether to use precomputed d-vectors for multi-speaker training. Defaults to False.
d_vector_dim (int):
Number of channels of the d-vectors. Defaults to 0.
detach_duration_predictor (bool):
Detach the input to the duration predictor from the earlier computation graph so that the duration loss
does not pass to the earlier layers. Defaults to True.
max_duration (int):
Maximum duration accepted by the model. Defaults to 75.
use_aligner (bool):
Use aligner network to learn the text to speech alignment. Defaults to True.
"""
num_chars: int = None
out_channels: int = 80
hidden_channels: int = 384
num_speakers: int = 0
duration_predictor_hidden_channels: int = 256
duration_predictor_kernel_size: int = 3
duration_predictor_dropout_p: float = 0.1
pitch_predictor_hidden_channels: int = 256
pitch_predictor_kernel_size: int = 3
pitch_predictor_dropout_p: float = 0.1
pitch_embedding_kernel_size: int = 3
positional_encoding: bool = True
positional_encoding_use_scale: bool = True
length_scale: int = 1
encoder_type: str = "fftransformer"
encoder_params: dict = field(
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
)
decoder_type: str = "fftransformer"
decoder_params: dict = field(
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
)
use_d_vector: bool = False
d_vector_dim: int = 0
detach_duration_predictor: bool = False
max_duration: int = 75
use_aligner: bool = True
class FastPitch(BaseTTS):
"""FastPitch model. Very similart to SpeedySpeech model but with pitch prediction.
Paper::
https://arxiv.org/abs/2006.06873
Paper abstract::
We present FastPitch, a fully-parallel text-to-speech model based on FastSpeech, conditioned on fundamental
frequency contours. The model predicts pitch contours during inference. By altering these predictions,
the generated speech can be more expressive, better match the semantic of the utterance, and in the end
more engaging to the listener. Uniformly increasing or decreasing pitch with FastPitch generates speech
that resembles the voluntary modulation of voice. Conditioning on frequency contours improves the overall
quality of synthesized speech, making it comparable to state-of-the-art. It does not introduce an overhead,
and FastPitch retains the favorable, fully-parallel Transformer architecture, with over 900x real-time
factor for mel-spectrogram synthesis of a typical utterance."
Args:
config (Coqpit): Model coqpit class.
Examples:
>>> from TTS.tts.models.fast_pitch import FastPitch, FastPitchArgs
>>> config = FastPitchArgs()
>>> model = FastPitch(config)
"""
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit):
super().__init__()
# don't use isinstance to avoid recursive imports
if config.__class__.__name__ == "FastPitchConfig":
if "characters" in config:
# loading from FastPitchConfig
_, self.config, num_chars = self.get_characters(config)
config.model_args.num_chars = num_chars
self.args = self.config.model_args
else:
# loading from FastPitchArgs
self.config = config
self.args = config.model_args
elif isinstance(config, FastPitchArgs):
self.args = config
self.config = config
else:
raise ValueError("config must be either a VitsConfig or Vitsself.args")
self.max_duration = self.args.max_duration
self.use_aligner = self.args.use_aligner
self.use_binary_alignment_loss = False
self.length_scale = (
float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale
)
self.emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
self.encoder = Encoder(
self.args.hidden_channels,
self.args.hidden_channels,
self.args.encoder_type,
self.args.encoder_params,
self.args.d_vector_dim,
)
if self.args.positional_encoding:
self.pos_encoder = PositionalEncoding(self.args.hidden_channels)
self.decoder = Decoder(
self.args.out_channels,
self.args.hidden_channels,
self.args.decoder_type,
self.args.decoder_params,
)
self.duration_predictor = DurationPredictor(
self.args.hidden_channels + self.args.d_vector_dim,
self.args.duration_predictor_hidden_channels,
self.args.duration_predictor_kernel_size,
self.args.duration_predictor_dropout_p,
)
self.pitch_predictor = DurationPredictor(
self.args.hidden_channels + self.args.d_vector_dim,
self.args.pitch_predictor_hidden_channels,
self.args.pitch_predictor_kernel_size,
self.args.pitch_predictor_dropout_p,
)
self.pitch_emb = nn.Conv1d(
1,
self.args.hidden_channels,
kernel_size=self.args.pitch_embedding_kernel_size,
padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
)
if self.args.num_speakers > 1 and not self.args.use_d_vector:
# speaker embedding layer
self.emb_g = nn.Embedding(self.args.num_speakers, self.args.d_vector_dim)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
if self.args.d_vector_dim > 0 and self.args.d_vector_dim != self.args.hidden_channels:
self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
if self.args.use_aligner:
self.aligner = AlignmentNetwork(
in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels
)
@staticmethod
def generate_attn(dr, x_mask, y_mask=None):
"""Generate an attention mask from the durations.
Shapes
- dr: :math:`(B, T_{en})`
- x_mask: :math:`(B, T_{en})`
- y_mask: :math:`(B, T_{de})`
"""
# compute decode mask from the durations
if y_mask is None:
y_lengths = dr.sum(1).long()
y_lengths[y_lengths < 1] = 1
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
return attn
def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
"""Generate attention alignment map from durations and
expand encoder outputs
Shapes
- en: :math:`(B, D_{en}, T_{en})`
- dr: :math:`(B, T_{en})`
- x_mask: :math:`(B, T_{en})`
- y_mask: :math:`(B, T_{de})`
Examples:
- encoder output: :math:`[a,b,c,d]`
- durations: :math:`[1, 3, 2, 1]`
- expanded: :math:`[a, b, b, b, c, c, d]`
- attention map: :math:`[[0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0]]`
"""
attn = self.generate_attn(dr, x_mask, y_mask)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2)
return o_en_ex, attn
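
The docstring example above can be checked with a standalone snippet (illustrative only; the model itself builds the expansion through `generate_attn` and a matmul):

```python
# Illustrative check of duration-based expansion: [a, b, c, d] with durations [1, 3, 2, 1].
import torch

en = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])   # [B=1, D_en=1, T_en=4] standing in for a, b, c, d
dr = torch.tensor([1, 3, 2, 1])               # durations per input token
o_en_ex = torch.repeat_interleave(en, dr, dim=2)
print(o_en_ex)  # tensor([[[1., 2., 2., 2., 3., 3., 4.]]])  -> [a, b, b, b, c, c, d]
```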
def format_durations(self, o_dr_log, x_mask):
"""Format predicted durations.
1. Convert to linear scale from log scale
2. Apply the length scale for speed adjustment
3. Apply masking.
4. Cast 0 durations to 1.
5. Round the duration values.
Args:
o_dr_log: Log scale durations.
x_mask: Input text mask.
Shapes:
- o_dr_log: :math:`(B, T_{de})`
- x_mask: :math:`(B, T_{en})`
"""
o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale
o_dr[o_dr < 1] = 1.0
o_dr = torch.round(o_dr)
return o_dr
@staticmethod
def _concat_speaker_embedding(o_en, g):
g_exp = g.expand(-1, -1, o_en.size(-1)) # [B, C, T_en]
o_en = torch.cat([o_en, g_exp], 1)
return o_en
def _sum_speaker_embedding(self, x, g):
# project g to decoder dim.
if hasattr(self, "proj_g"):
g = self.proj_g(g)
return x + g
def _forward_encoder(
self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Encoding forward pass.
1. Embed speaker IDs if multi-speaker mode.
2. Embed character sequences.
3. Run the encoder network.
4. Concat speaker embedding to the encoder output for the duration predictor.
Args:
x (torch.LongTensor): Input sequence IDs.
x_mask (torch.FloatTensor): Input sequence mask.
g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None.
Returns:
Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor, torch.tensor]:
encoder output, encoder output for the duration predictor, input sequence mask, speaker embeddings,
character embeddings
Shapes:
- x: :math:`(B, T_{en})`
- x_mask: :math:`(B, 1, T_{en})`
- g: :math:`(B, C)`
"""
if hasattr(self, "emb_g"):
g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1]
if g is not None:
g = g.unsqueeze(-1)
# [B, T, C]
x_emb = self.emb(x)
# encoder pass
o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
# speaker conditioning for duration predictor
if g is not None:
o_en_dp = self._concat_speaker_embedding(o_en, g)
else:
o_en_dp = o_en
return o_en, o_en_dp, x_mask, g, x_emb
def _forward_decoder(
self,
o_en: torch.FloatTensor,
dr: torch.IntTensor,
x_mask: torch.FloatTensor,
y_lengths: torch.IntTensor,
g: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""Decoding forward pass.
1. Compute the decoder output mask
2. Expand encoder output with the durations.
3. Apply position encoding.
4. Add speaker embeddings if multi-speaker mode.
5. Run the decoder.
Args:
o_en (torch.FloatTensor): Encoder output.
dr (torch.IntTensor): Ground truth durations or alignment network durations.
x_mask (torch.IntTensor): Input sequence mask.
y_lengths (torch.IntTensor): Output sequence lengths.
g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings.
Returns:
Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations.
"""
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
# expand o_en with durations
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
# positional encoding
if hasattr(self, "pos_encoder"):
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
# speaker embedding
if g is not None:
o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
# decoder pass
o_de = self.decoder(o_en_ex, y_mask, g=g)
return o_de.transpose(1, 2), attn.transpose(1, 2)
def _forward_pitch_predictor(
self,
o_en: torch.FloatTensor,
x_mask: torch.IntTensor,
pitch: torch.FloatTensor = None,
dr: torch.IntTensor = None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""Pitch predictor forward pass.
1. Predict pitch from encoder outputs.
2. In training - Compute average pitch values for each input character from the ground truth pitch values.
3. Embed average pitch values.
Args:
o_en (torch.FloatTensor): Encoder output.
x_mask (torch.IntTensor): Input sequence mask.
pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
Returns:
Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction.
Shapes:
- o_en: :math:`(B, C, T_{en})`
- x_mask: :math:`(B, 1, T_{en})`
- pitch: :math:`(B, 1, T_{de})`
- dr: :math:`(B, T_{en})`
"""
o_pitch = self.pitch_predictor(o_en, x_mask)
if pitch is not None:
avg_pitch = average_pitch(pitch, dr)
o_pitch_emb = self.pitch_emb(avg_pitch)
return o_pitch_emb, o_pitch, avg_pitch
o_pitch_emb = self.pitch_emb(o_pitch)
return o_pitch_emb, o_pitch
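# Call sketch (illustration only): in training, both `pitch` and `dr` are given, so the
# method returns three values,
#   o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr)
# while at inference only the predicted pitch path runs,
#   o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)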
def _forward_aligner(
self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor
) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Aligner forward pass.
1. Compute a mask to apply to the attention map.
2. Run the alignment network.
3. Apply MAS to compute the hard alignment map.
4. Compute the durations from the hard alignment map.
Args:
x (torch.FloatTensor): Input sequence.
y (torch.FloatTensor): Output sequence.
x_mask (torch.IntTensor): Input sequence mask.
y_mask (torch.IntTensor): Output sequence mask.
Returns:
Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials,
hard alignment map.
Shapes:
- x: :math:`[B, T_en, C_en]`
- y: :math:`[B, T_de, C_de]`
- x_mask: :math:`[B, 1, T_en]`
- y_mask: :math:`[B, 1, T_de]`
- o_alignment_dur: :math:`[B, T_en]`
- alignment_soft: :math:`[B, T_en, T_de]`
- alignment_logprob: :math:`[B, 1, T_de, T_en]`
- alignment_mas: :math:`[B, T_en, T_de]`
"""
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None)
alignment_mas = maximum_path(
alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
)
o_alignment_dur = torch.sum(alignment_mas, -1).int()
alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas
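# Hedged sanity-check sketch (illustration only): MAS assigns every valid decoder frame
# to exactly one input token, so the hard durations recovered from the alignment should
# sum to the decoder lengths, e.g.
#   o_alignment_dur, *_ = self._forward_aligner(x_emb, y, x_mask, y_mask)
#   # o_alignment_dur.sum(1) is expected to equal the unmasked decoder lengths.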
def forward(
self,
x: torch.LongTensor,
x_lengths: torch.LongTensor,
y_lengths: torch.LongTensor,
y: torch.FloatTensor = None,
dr: torch.IntTensor = None,
pitch: torch.FloatTensor = None,
aux_input: Dict = {"d_vectors": 0, "speaker_ids": None}, # pylint: disable=unused-argument
) -> Dict:
"""Model's forward pass.
Args:
x (torch.LongTensor): Input character sequences.
x_lengths (torch.LongTensor): Input sequence lengths.
y_lengths (torch.LongTensor): Output sequence lengths. Defaults to None.
y (torch.FloatTensor): Spectrogram frames. Defaults to None.
dr (torch.IntTensor): Character durations over the spectrogram frames. Defaults to None.
pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Defaults to None.
aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": 0, "speaker_ids": None}`.
Shapes:
- x: :math:`[B, T_max]`
- x_lengths: :math:`[B]`
- y_lengths: :math:`[B]`
- y: :math:`[B, T_max2]`
- dr: :math:`[B, T_max]`
- g: :math:`[B, C]`
- pitch: :math:`[B, 1, T]`
"""
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
# compute sequence masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(y.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(y.dtype)
# encoder pass
o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
# duration predictor pass
if self.args.detach_duration_predictor:
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
else:
o_dr_log = self.duration_predictor(o_en_dp, x_mask)
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
# generate attn mask from predicted durations
o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
# aligner pass
if self.use_aligner:
o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner(
x_emb, y, x_mask, y_mask
)
dr = o_alignment_dur
# pitch predictor pass
o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr)
o_en = o_en + o_pitch_emb
# decoder pass
o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g)
outputs = {
"model_outputs": o_de,
"durations_log": o_dr_log.squeeze(1),
"durations": o_dr.squeeze(1),
"attn_durations": o_attn, # for visualization
"pitch_avg": o_pitch,
"pitch_avg_gt": avg_pitch,
"alignments": attn,
"alignment_soft": alignment_soft.transpose(1, 2),
"alignment_mas": alignment_mas.transpose(1, 2),
"o_alignment_dur": o_alignment_dur,
"alignment_logprob": alignment_logprob,
"x_mask": x_mask,
"y_mask": y_mask,
}
return outputs
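# Hedged smoke-test sketch (illustration only, with hypothetical sizes B, T_en, T_de and
# the aligner enabled): a forward pass can be exercised roughly as
#   x = torch.randint(0, num_chars, (B, T_en)); x_lengths = torch.tensor([T_en] * B)
#   y = torch.rand(B, T_de, num_mels);          y_lengths = torch.tensor([T_de] * B)
#   outputs = model.forward(x, x_lengths, y_lengths, y=y, pitch=torch.rand(B, 1, T_de),
#                           aux_input={"d_vectors": None, "speaker_ids": None})
# and outputs["model_outputs"] is expected to come back as [B, T_de, num_mels].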
@torch.no_grad()
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
"""Model's inference pass.
Args:
x (torch.LongTensor): Input character sequence.
aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`.
Shapes:
- x: [B, T_max]
- x_lengths: [B]
- g: [B, C]
"""
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float()
# encoder pass
o_en, o_en_dp, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
# duration predictor pass
o_dr_log = self.duration_predictor(o_en_dp, x_mask)
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1)
# pitch predictor pass
o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)
o_en = o_en + o_pitch_emb
# decoder pass
o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=g)
outputs = {
"model_outputs": o_de,
"alignments": attn,
"pitch": o_pitch,
"durations_log": o_dr_log,
}
return outputs
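# Hedged usage sketch (illustration only): a bare single-speaker inference call would be
#   x = torch.randint(0, num_chars, (1, T_en))
#   outputs = model.inference(x, aux_input={"d_vectors": None, "speaker_ids": None})
#   mel = outputs["model_outputs"]  # [1, T_de, num_mels], T_de set by the predicted durations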
def train_step(self, batch: dict, criterion: nn.Module):
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
pitch = batch["pitch"]
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
durations = batch["durations"]
aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
# forward pass
outputs = self.forward(
text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
)
# use aligner's output as the duration target
if self.use_aligner:
durations = outputs["o_alignment_dur"]
# use float32 in AMP
with autocast(enabled=False):
# compute loss
loss_dict = criterion(
decoder_output=outputs["model_outputs"],
decoder_target=mel_input,
decoder_output_lens=mel_lengths,
dur_output=outputs["durations_log"],
dur_target=durations,
pitch_output=outputs["pitch_avg"],
pitch_target=outputs["pitch_avg_gt"],
input_lens=text_lengths,
alignment_logprob=outputs["alignment_logprob"],
alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
)
# compute duration error
durations_pred = outputs["durations"]
duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum()
loss_dict["duration_error"] = duration_error
return outputs, loss_dict
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use
model_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
mel_input = batch["mel_input"]
pitch = batch["pitch"]
pitch_avg_expanded, _ = self.expand_encoder_outputs(
outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
)
pred_spec = model_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy()
align_img = alignments[0].data.cpu().numpy()
pitch = pitch[0, 0].data.cpu().numpy()
# TODO: denormalize before plotting
pitch = abs(pitch)
pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
figures = {
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
"alignment": plot_alignment(align_img, output_fig=False),
"pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
"pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
}
# plot the attention mask computed from the predicted durations
if "attn_durations" in outputs:
alignments_hat = outputs["attn_durations"][0].data.cpu().numpy()
figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False)
# Sample audio
train_audio = ap.inv_melspectrogram(pred_spec.T)
return figures, {"audio": train_audio}
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
return self.train_log(ap, batch, outputs)
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
def get_criterion(self):
from TTS.tts.layers.losses import FastPitchLoss # pylint: disable=import-outside-toplevel
return FastPitchLoss(self.config)
def on_train_step_start(self, trainer):
"""Enable binary alignment loss when needed"""
if trainer.total_steps_done > self.config.binary_align_loss_start_step:
self.use_binary_alignment_loss = True
def average_pitch(pitch, durs):
"""Compute the average pitch value for each input character based on the durations.
Shapes:
- pitch: :math:`[B, 1, T_de]`
- durs: :math:`[B, T_en]`
"""
durs_cums_ends = torch.cumsum(durs, dim=1).long()
durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0))
pitch_nonzero_cums = torch.nn.functional.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0))
pitch_cums = torch.nn.functional.pad(torch.cumsum(pitch, dim=2), (1, 0))
bs, l = durs_cums_ends.size()
n_formants = pitch.size(1)
dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, l)
dce = durs_cums_ends[:, None, :].expand(bs, n_formants, l)
pitch_sums = (torch.gather(pitch_cums, 2, dce) - torch.gather(pitch_cums, 2, dcs)).float()
pitch_nelems = (torch.gather(pitch_nonzero_cums, 2, dce) - torch.gather(pitch_nonzero_cums, 2, dcs)).float()
pitch_avg = torch.where(pitch_nelems == 0.0, pitch_nelems, pitch_sums / pitch_nelems)
return pitch_avg
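# Worked example (illustration only): with pitch = [[[2., 4., 0., 6.]]] and durs = [[2, 2]],
# the first character covers frames [2., 4.] -> average 3.0, and the second covers [0., 6.]
# where the zero (unvoiced) frame is excluded by the nonzero count -> average 6.0,
# so average_pitch returns [[[3., 6.]]].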

View File

@ -10,7 +10,6 @@ from TTS.tts.layers.glow_tts.encoder import Encoder
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@ -110,6 +109,10 @@ class GlowTTS(BaseTTS):
# init speaker manager
self.speaker_manager = get_speaker_manager(config, data=data)
self.num_speakers = self.speaker_manager.num_speakers
if config.use_d_vector_file:
self.external_d_vector_dim = config.d_vector_dim
else:
self.external_d_vector_dim = 0
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
self.embedded_speaker_dim = self.c_in_channels
@ -130,22 +133,22 @@ class GlowTTS(BaseTTS):
return y_mean, y_log_scale, o_attn_dur
def forward(
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None}
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
): # pylint: disable=dangerous-default-value
"""
Shapes:
- x: :math:`[B, T]`
- x_lenghts::math:` B`
- x_lenghts::math:`B`
- y: :math:`[B, T, C]`
- y_lengths::math:` B`
- y_lengths::math:`B`
- g: :math:`[B, C] or B`
"""
y = y.transpose(1, 2)
y_max_length = y.size(2)
# norm speaker embeddings
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
if g is not None:
if self.d_vector_dim:
if self.use_speaker_embedding or self.use_d_vector_file:
if not self.use_d_vector_file:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
@ -182,7 +185,7 @@ class GlowTTS(BaseTTS):
@torch.no_grad()
def inference_with_MAS(
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None}
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
): # pylint: disable=dangerous-default-value
"""
It's similar to the teacher forcing in Tacotron.
@ -199,12 +202,11 @@ class GlowTTS(BaseTTS):
y_max_length = y.size(2)
# norm speaker embeddings
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
if g is not None:
if self.external_d_vector_dim:
if self.use_speaker_embedding or self.use_d_vector_file:
if not self.use_d_vector_file:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# drop residual frames wrt num_squeeze and set y_lengths.
@ -244,7 +246,7 @@ class GlowTTS(BaseTTS):
@torch.no_grad()
def decoder_inference(
self, y, y_lengths=None, aux_input={"d_vectors": None}
self, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
): # pylint: disable=dangerous-default-value
"""
Shapes:
@ -276,7 +278,7 @@ class GlowTTS(BaseTTS):
return outputs
@torch.no_grad()
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None}): # pylint: disable=dangerous-default-value
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids":None}): # pylint: disable=dangerous-default-value
x_lengths = aux_input["x_lengths"]
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
@ -327,8 +329,9 @@ class GlowTTS(BaseTTS):
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors})
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors, "speaker_ids":speaker_ids})
loss_dict = criterion(
outputs["model_outputs"],
@ -341,9 +344,6 @@ class GlowTTS(BaseTTS):
text_lengths,
)
# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
loss_dict["align_error"] = align_error
return outputs, loss_dict
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use

View File

@ -41,9 +41,9 @@ def rand_segment(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4
x_lengths = T
max_idxs = x_lengths - segment_size + 1
assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
ids_str = (torch.rand([B]).type_as(x) * max_idxs).long()
ret = segment(x, ids_str, segment_size)
return ret, ids_str
segment_indices = (torch.rand([B]).type_as(x) * max_idxs).long()
ret = segment(x, segment_indices, segment_size)
return ret, segment_indices
@dataclass

View File

@ -7,7 +7,7 @@ def alignment_diagonal_score(alignments, binary=False):
binary (bool): if True, ignore scores and consider attention
as a binary mask.
Shape:
alignments : batch x decoder_steps x encoder_steps
- alignments : :math:`[B, T_de, T_en]`
"""
maxs = alignments.max(dim=1)[0]
if binary:

View File

@ -45,12 +45,10 @@ def text2phone(text, language, use_espeak_phonemes=False):
# TO REVIEW : How to have a good implementation for this?
if language == "zh-CN":
ph = chinese_text_to_phonemes(text)
print(" > Phonemes: {}".format(ph))
return ph
if language == "ja-jp":
ph = japanese_text_to_phonemes(text)
print(" > Phonemes: {}".format(ph))
return ph
if gruut.is_language_supported(language):
@ -80,7 +78,6 @@ def text2phone(text, language, use_espeak_phonemes=False):
# Fix a few phonemes
ph = ph.translate(GRUUT_TRANS_TABLE)
return ph
raise ValueError(f" [!] Language {language} is not supported for phonemization.")

View File

@ -2,8 +2,10 @@
# compatible with Julius https://github.com/julius-speech/segmentation-kit
import re
import unicodedata
import MeCab
from num2words import num2words
_CONVRULES = [
# Conversion of 2 letters
@ -373,8 +375,93 @@ def text2kata(text: str) -> str:
return hira2kata("".join(res))
_ALPHASYMBOL_YOMI = {
"#": "シャープ",
"%": "パーセント",
"&": "アンド",
"+": "プラス",
"-": "マイナス",
":": "コロン",
";": "セミコロン",
"<": "小なり",
"=": "イコール",
">": "大なり",
"@": "アット",
"a": "エー",
"b": "ビー",
"c": "シー",
"d": "ディー",
"e": "イー",
"f": "エフ",
"g": "ジー",
"h": "エイチ",
"i": "アイ",
"j": "ジェー",
"k": "ケー",
"l": "エル",
"m": "エム",
"n": "エヌ",
"o": "オー",
"p": "ピー",
"q": "キュー",
"r": "アール",
"s": "エス",
"t": "ティー",
"u": "ユー",
"v": "ブイ",
"w": "ダブリュー",
"x": "エックス",
"y": "ワイ",
"z": "ゼット",
"α": "アルファ",
"β": "ベータ",
"γ": "ガンマ",
"δ": "デルタ",
"ε": "イプシロン",
"ζ": "ゼータ",
"η": "イータ",
"θ": "シータ",
"ι": "イオタ",
"κ": "カッパ",
"λ": "ラムダ",
"μ": "ミュー",
"ν": "ニュー",
"ξ": "クサイ",
"ο": "オミクロン",
"π": "パイ",
"ρ": "ロー",
"σ": "シグマ",
"τ": "タウ",
"υ": "ウプシロン",
"φ": "ファイ",
"χ": "カイ",
"ψ": "プサイ",
"ω": "オメガ",
}
_NUMBER_WITH_SEPARATOR_RX = re.compile("[0-9]{1,3}(,[0-9]{3})+")
_CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"}
_CURRENCY_RX = re.compile(r"([$¥£€])([0-9.]*[0-9])")
_NUMBER_RX = re.compile(r"[0-9]+(\.[0-9]+)?")
def japanese_convert_numbers_to_words(text: str) -> str:
res = _NUMBER_WITH_SEPARATOR_RX.sub(lambda m: m[0].replace(",", ""), text)
res = _CURRENCY_RX.sub(lambda m: m[2] + _CURRENCY_MAP.get(m[1], m[1]), res)
res = _NUMBER_RX.sub(lambda m: num2words(m[0], lang="ja"), res)
return res
def japanese_convert_alpha_symbols_to_words(text: str) -> str:
return "".join([_ALPHASYMBOL_YOMI.get(ch, ch) for ch in text.lower()])
def japanese_text_to_phonemes(text: str) -> str:
"""Convert Japanese text to phonemes."""
res = text2kata(text)
res = unicodedata.normalize("NFKC", text)
res = japanese_convert_numbers_to_words(res)
res = japanese_convert_alpha_symbols_to_words(res)
res = text2kata(res)
res = kata2phoneme(res)
return res.replace(" ", "")
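# Illustration (drawn from the test cases added elsewhere in this PR): with the number
# and symbol handling above, japanese_text_to_phonemes("今日は温泉に行きます") is expected
# to return "kyo:waoNseNni,ikimasu.", and alphabetic input such as "AからZまでです"
# becomes "e:karazeqtomadedesu.".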

View File

@ -49,6 +49,46 @@ def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False):
return fig
def plot_pitch(pitch, spectrogram, ap=None, fig_size=(30, 10), output_fig=False):
"""Plot pitch curves on top of the spectrogram.
Args:
pitch (np.array): Pitch values.
spectrogram (np.array): Spectrogram values.
Shapes:
pitch: :math:`(T,)`
spec: :math:`(C, T)`
"""
if isinstance(spectrogram, torch.Tensor):
spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T
else:
spectrogram_ = spectrogram.T
spectrogram_ = spectrogram_.astype(np.float32) if spectrogram_.dtype == np.float16 else spectrogram_
if ap is not None:
spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access
old_fig_size = plt.rcParams["figure.figsize"]
if fig_size is not None:
plt.rcParams["figure.figsize"] = fig_size
fig, ax = plt.subplots()
ax.imshow(spectrogram_, aspect="auto", origin="lower")
ax.set_xlabel("time")
ax.set_ylabel("spec_freq")
ax2 = ax.twinx()
ax2.plot(pitch, linewidth=5.0, color="red")
ax2.set_ylabel("F0")
plt.rcParams["figure.figsize"] = old_fig_size
if not output_fig:
plt.close()
return fig
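# Hedged usage sketch (illustration only): given an AudioProcessor `ap` and a waveform,
#   pitch = ap.compute_f0(wav)
#   mel = ap.melspectrogram(wav)
#   fig = plot_pitch(pitch, mel, ap, output_fig=True)
# overlays the F0 contour on the (denormalized) spectrogram, as FastPitch.train_log does
# earlier in this diff.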
def visualize(
alignment,
postnet_output,

View File

@ -2,6 +2,7 @@ from typing import Dict, Tuple
import librosa
import numpy as np
import pyworld as pw
import scipy.io.wavfile
import scipy.signal
import soundfile as sf
@ -10,8 +11,6 @@ from torch import nn
from TTS.tts.utils.data import StandardScaler
# import pyworld as pw
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""Some of the audio processing funtions using Torch for faster batch processing.
@ -623,17 +622,47 @@ class AudioProcessor(object):
return 0, pad
return pad // 2, pad // 2 + pad % 2
### Compute F0 ###
# TODO: pw causes some dep issues
# def compute_f0(self, x):
# f0, t = pw.dio(
# x.astype(np.double),
# fs=self.sample_rate,
# f0_ceil=self.mel_fmax,
# frame_period=1000 * self.hop_length / self.sample_rate,
# )
# f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
# return f0
def compute_f0(self, x: np.ndarray) -> np.ndarray:
"""Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
Args:
x (np.ndarray): Waveform.
Returns:
np.ndarray: Pitch.
Examples:
>>> WAV_FILE = filename = librosa.util.example_audio_file()
>>> from TTS.config import BaseAudioConfig
>>> from TTS.utils.audio import AudioProcessor
>>> conf = BaseAudioConfig(mel_fmax=8000)
>>> ap = AudioProcessor(**conf)
>>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
>>> pitch = ap.compute_f0(wav)
"""
f0, t = pw.dio(
x.astype(np.double),
fs=self.sample_rate,
f0_ceil=self.mel_fmax,
frame_period=1000 * self.hop_length / self.sample_rate,
)
f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
# pad = int((self.win_length / self.hop_length) / 2)
# f0 = [0.0] * pad + f0 + [0.0] * pad
# f0 = np.pad(f0, (pad, pad), mode="constant", constant_values=0)
# f0 = np.array(f0, dtype=np.float32)
# f01, _, _ = librosa.pyin(
# x,
# fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
# fmax=self.mel_fmax,
# frame_length=self.win_length,
# sr=self.sample_rate,
# fill_na=0.0,
# )
# spec = self.melspectrogram(x)
return f0
### Audio Processing ###
def find_endpoint(self, wav: np.ndarray, threshold_db=-40, min_silence_sec=0.8) -> int:

View File

@ -232,7 +232,7 @@ class Synthesizer(object):
# compute a new d_vector from the given clip.
if speaker_wav is not None:
speaker_embedding = self.speaker_manager.compute_d_vector_from_clip(speaker_wav)
speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav)
use_gl = self.vocoder_model is None

View File

@ -11,7 +11,6 @@ def to_camel(text):
def setup_model(config: Coqpit):
"""Load models directly from configuration."""
print(" > Vocoder Model: {}".format(config.model))
if "discriminator_model" in config and "generator_model" in config:
MyModel = importlib.import_module("TTS.vocoder.models.gan")
MyModel = getattr(MyModel, "GAN")
@ -28,6 +27,7 @@ def setup_model(config: Coqpit):
MyModel = getattr(MyModel, to_camel(config.model))
except ModuleNotFoundError as e:
raise ValueError(f"Model {config.model} not exist!") from e
print(" > Vocoder Model: {}".format(config.model))
model = MyModel(config)
return model

View File

@ -45,6 +45,7 @@
models/glow_tts.md
models/vits.md
models/fast_pitch.md
.. toctree::
:maxdepth: 2

View File

@ -0,0 +1,70 @@
import os
from TTS.config import BaseAudioConfig, BaseDatasetConfig
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.tts.configs import FastPitchConfig
from TTS.utils.manage import ModelManager
output_path = os.path.dirname(os.path.abspath(__file__))
# init configs
dataset_config = BaseDatasetConfig(
name="ljspeech",
meta_file_train="metadata.csv",
# meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
path=os.path.join(output_path, "../LJSpeech-1.1/"),
)
audio_config = BaseAudioConfig(
sample_rate=22050,
do_trim_silence=True,
trim_db=60.0,
signal_norm=False,
mel_fmin=0.0,
mel_fmax=8000,
spec_gain=1.0,
log_func="np.log",
ref_level_db=20,
preemphasis=0.0,
)
config = FastPitchConfig(
run_name="fast_pitch_ljspeech",
audio=audio_config,
batch_size=32,
eval_batch_size=16,
num_loader_workers=8,
num_eval_loader_workers=4,
compute_input_seq_cache=True,
compute_f0=True,
f0_cache_path=os.path.join(output_path, "f0_cache"),
run_eval=True,
test_delay_epochs=-1,
epochs=1000,
text_cleaner="english_cleaners",
use_phonemes=True,
use_espeak_phonemes=False,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
print_step=50,
print_eval=False,
mixed_precision=False,
sort_by_audio_len=True,
max_seq_len=500000,
output_path=output_path,
datasets=[dataset_config],
)
# compute alignments
if not config.model_args.use_aligner:
manager = ModelManager()
model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
# TODO: make compute_attention python callable
os.system(
f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true"
)
# train the model
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger)
trainer.fit()
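# Hedged usage note (script path assumed, not part of the original recipe): this recipe is
# typically launched from the repository root with something like
#   CUDA_VISIBLE_DEVICES=0 python recipes/ljspeech/fast_pitch/train_fast_pitch.py
# after extracting LJSpeech-1.1 next to the recipe folder, as `dataset_config` above expects.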

View File

@ -25,3 +25,4 @@ unidic-lite==1.0.8
# gruut+supported langs
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0
fsspec>=2021.04.0
pyworld

View File

@ -1,12 +1,16 @@
import os
from TTS.config import BaseDatasetConfig
from TTS.utils.generic_utils import get_cuda
def get_device_id():
use_cuda, _ = get_cuda()
if use_cuda:
GPU_ID = "0"
if 'CUDA_VISIBLE_DEVICES' in os.environ and os.environ['CUDA_VISIBLE_DEVICES'] != "":
GPU_ID = os.environ['CUDA_VISIBLE_DEVICES'].split(',')[0]
else:
GPU_ID = "0"
else:
GPU_ID = ""
return GPU_ID
@ -30,3 +34,7 @@ def get_tests_output_path():
def run_cli(command):
exit_status = os.system(command)
assert exit_status == 0, f" [!] command `{command}` failed."
def get_test_data_config():
return BaseDatasetConfig(name="ljspeech", path="tests/data/ljspeech/", meta_file_train="metadata.csv")

View File

@ -68,15 +68,15 @@ class TestTTSDataset(unittest.TestCase):
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
speaker_name = data[2]
linear_input = data[3]
mel_input = data[4]
mel_lengths = data[5]
stop_target = data[6]
item_idx = data[7]
wavs = data[11]
text_input = data['text']
text_lengths = data['text_lengths']
speaker_name = data['speaker_names']
linear_input = data['linear']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
stop_target = data['stop_targets']
item_idx = data['item_idxs']
wavs = data['waveform']
neg_values = text_input[text_input < 0]
check_count = len(neg_values)
@ -113,14 +113,14 @@ class TestTTSDataset(unittest.TestCase):
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
speaker_name = data[2]
linear_input = data[3]
mel_input = data[4]
mel_lengths = data[5]
stop_target = data[6]
item_idx = data[7]
text_input = data['text']
text_lengths = data['text_lengths']
speaker_name = data['speaker_names']
linear_input = data['linear']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
stop_target = data['stop_targets']
item_idx = data['item_idxs']
avg_length = mel_lengths.numpy().mean()
assert avg_length >= last_length
@ -139,14 +139,14 @@ class TestTTSDataset(unittest.TestCase):
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
speaker_name = data[2]
linear_input = data[3]
mel_input = data[4]
mel_lengths = data[5]
stop_target = data[6]
item_idx = data[7]
text_input = data['text']
text_lengths = data['text_lengths']
speaker_name = data['speaker_names']
linear_input = data['linear']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
stop_target = data['stop_targets']
item_idx = data['item_idxs']
# check mel_spec consistency
wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
@ -188,14 +188,14 @@ class TestTTSDataset(unittest.TestCase):
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
speaker_name = data[2]
linear_input = data[3]
mel_input = data[4]
mel_lengths = data[5]
stop_target = data[6]
item_idx = data[7]
text_input = data['text']
text_lengths = data['text_lengths']
speaker_name = data['speaker_names']
linear_input = data['linear']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
stop_target = data['stop_targets']
item_idx = data['item_idxs']
if mel_lengths[0] > mel_lengths[1]:
idx = 0

View File

@ -181,3 +181,10 @@ class TestAudio(unittest.TestCase):
mel_norm = ap.melspectrogram(wav)
mel_denorm = ap.denormalize(mel_norm)
assert abs(mel_reference - mel_denorm).max() < 1e-4
def test_compute_f0(self): # pylint: disable=no-self-use
ap = AudioProcessor(**conf)
wav = ap.load_wav(WAV_FILE)
pitch = ap.compute_f0(wav)
mel = ap.melspectrogram(wav)
assert pitch.shape[0] == mel.shape[1]

View File

@ -5,11 +5,13 @@ from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes
_TEST_CASES = """
どちらに行きますか/dochiraniikimasuka?
今日は温泉に行きます/kyo:waoNseNni,ikimasu.
AからZまでです/AkaraZmadedesu.
AからZまでです/e:karazeqtomadedesu.
そうですね/so:desune!
クジラは哺乳類です/kujirawahonyu:ruidesu.
ヴィディオを見ます/bidioomimasu.
ky o: w a o N s e N n i , i k i m a s u ./kyo:waoNseNni,ikimasu.
今日は8月22日です/kyo:wahachigatsuniju:ninichidesu
xyzとαβγ/eqkusuwaizeqtotoarufabe:tagaNma
値段は$12.34です/nedaNwaju:niteNsaNyoNdorudesu
"""

View File

@ -0,0 +1,47 @@
import unittest
import torch as T
from TTS.tts.models.fast_pitch import FastPitch, FastPitchArgs, average_pitch
# pylint: disable=unused-variable
class AveragePitchTests(unittest.TestCase):
def test_in_out(self): # pylint: disable=no-self-use
pitch = T.rand(1, 1, 128)
durations = T.randint(1, 5, (1, 21))
coeff = 128.0 / durations.sum()
durations = T.round(durations * coeff)
diff = 128.0 - durations.sum()
durations[0, -1] += diff
durations = durations.long()
pitch_avg = average_pitch(pitch, durations)
index = 0
for idx, dur in enumerate(durations[0]):
assert abs(pitch_avg[0, 0, idx] - pitch[0, 0, index : index + dur.item()].mean()) < 1e-5
index += dur
def expand_encoder_outputs_test():
model = FastPitch(FastPitchArgs(num_chars=10))
inputs = T.rand(2, 5, 57)
durations = T.randint(1, 4, (2, 57))
x_mask = T.ones(2, 1, 57)
y_mask = T.ones(2, 1, durations.sum(1).max())
expanded, _ = model.expand_encoder_outputs(inputs, durations, x_mask, y_mask)
for b in range(durations.shape[0]):
index = 0
for idx, dur in enumerate(durations[b]):
diff = (
expanded[b, :, index : index + dur.item()]
- inputs[b, :, idx].repeat(dur.item()).view(expanded[b, :, index : index + dur.item()].shape)
).sum()
assert abs(diff) < 1e-6, diff
index += dur