mirror of https://github.com/coqui-ai/TTS.git
Fastspeech2 (#2073)
* added EnergyDataset
* add energy to Dataset
* add compute_energy
* added energy params
* added energy to forward_tts
* added plot_avg_energy for visualisation
* Update forward_tts.py
* create file
* added fastspeech2 recipe
* add fastspeech2 config
* removed energy from fast pitch
* add energy loss to forward tts
* Update fastspeech2_config.py
* change run_name
* Update numpy_transforms.py
* fix typo
* fix typo
* fix typo
* linting issues
* use_energy default value --> False
* Update numpy_transforms.py
* linting fixes
* fix typo
* linting fix
* linting fix
* fix
* fixes
* fixes
* lint fix
* lint fixes
* added training test
* wrong import
* wrong import
* trailing whitespace
* style fix
* changed class name because of error
* class name change
* class name change
* change class name
* fixed styles
This commit is contained in:
parent 14d45b5347
commit bc422f2f3c
@@ -100,6 +100,13 @@ class FastPitchConfig(BaseTTSConfig):
         max_seq_len (int):
             Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
 
+        # dataset configs
+        compute_f0 (bool):
+            Compute pitch. Defaults to True.
+
+        f0_cache_path (str):
+            Pitch cache path. Defaults to None.
+
     """
 
     model: str = "fast_pitch"
@@ -0,0 +1,198 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from TTS.tts.configs.shared_configs import BaseTTSConfig
+from TTS.tts.models.forward_tts import ForwardTTSArgs
+
+
+@dataclass
+class Fastspeech2Config(BaseTTSConfig):
+    """Configure `ForwardTTS` as a FastSpeech2 model.
+
+    Example:
+
+        >>> from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
+        >>> config = Fastspeech2Config()
+
+    Args:
+        model (str):
+            Model name used for selecting the right model at initialization. Defaults to `fastspeech2`.
+
+        base_model (str):
+            Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate
+            the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
+
+        model_args (Coqpit):
+            Model class arguments. Check `ForwardTTSArgs` for more details. Defaults to `ForwardTTSArgs()`.
+
+        data_dep_init_steps (int):
+            Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
+            Activation Normalization that pre-computes normalization stats at the beginning and uses the same values
+            for the rest. Defaults to 10.
+
+        speakers_file (str):
+            Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to
+            speaker names. Defaults to `None`.
+
+        use_speaker_embedding (bool):
+            Enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
+            in the multi-speaker mode. Defaults to False.
+
+        use_d_vector_file (bool):
+            Enable / disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
+
+        d_vector_file (str):
+            Path to the file including pre-computed speaker embeddings. Defaults to None.
+
+        d_vector_dim (int):
+            Dimension of the external speaker embeddings. Defaults to 0.
+
+        optimizer (str):
+            Name of the model optimizer. Defaults to `Adam`.
+
+        optimizer_params (dict):
+            Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`.
+
+        lr_scheduler (str):
+            Name of the learning rate scheduler. Defaults to `NoamLR`.
+
+        lr_scheduler_params (dict):
+            Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`.
+
+        lr (float):
+            Initial learning rate. Defaults to `1e-4`.
+
+        grad_clip (float):
+            Gradient norm clipping value. Defaults to `5.0`.
+
+        spec_loss_type (str):
+            Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
+
+        duration_loss_type (str):
+            Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
+
+        use_ssim_loss (bool):
+            Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True.
+
+        wd (float):
+            Weight decay coefficient. Defaults to `1e-7`.
+
+        ssim_loss_alpha (float):
+            Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.
+
+        spec_loss_alpha (float):
+            Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0.
+
+        aligner_loss_alpha (float):
+            Weight for the aligner loss. If set 0, disables the aligner loss. Defaults to 1.0.
+
+        pitch_loss_alpha (float):
+            Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 0.1.
+
+        energy_loss_alpha (float):
+            Weight for the energy predictor's loss. If set 0, disables the energy predictor. Defaults to 0.1.
+
+        dur_loss_alpha (float):
+            Weight for the duration predictor's loss. If set 0, disables the duration loss. Defaults to 0.1.
+
+        binary_align_loss_alpha (float):
+            Weight for the binary loss. If set 0, disables the binary loss. Defaults to 0.1.
+
+        binary_loss_warmup_epochs (int):
+            Number of epochs to gradually increase the binary loss impact. Defaults to 150.
+
+        min_seq_len (int):
+            Minimum input sequence length to be used at training.
+
+        max_seq_len (int):
+            Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
+
+        # dataset configs
+        compute_f0 (bool):
+            Compute pitch. Defaults to True.
+
+        f0_cache_path (str):
+            Pitch cache path. Defaults to None.
+
+        compute_energy (bool):
+            Compute energy. Defaults to True.
+
+        energy_cache_path (str):
+            Energy cache path. Defaults to None.
+    """
+
+    model: str = "fastspeech2"
+    base_model: str = "forward_tts"
+
+    # model specific params
+    model_args: ForwardTTSArgs = ForwardTTSArgs()
+
+    # multi-speaker settings
+    num_speakers: int = 0
+    speakers_file: str = None
+    use_speaker_embedding: bool = False
+    use_d_vector_file: bool = False
+    d_vector_file: str = None
+    d_vector_dim: int = 0
+
+    # optimizer parameters
+    optimizer: str = "Adam"
+    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
+    lr_scheduler: str = "NoamLR"
+    lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
+    lr: float = 1e-4
+    grad_clip: float = 5.0
+
+    # loss params
+    spec_loss_type: str = "mse"
+    duration_loss_type: str = "mse"
+    use_ssim_loss: bool = True
+    ssim_loss_alpha: float = 1.0
+    spec_loss_alpha: float = 1.0
+    aligner_loss_alpha: float = 1.0
+    pitch_loss_alpha: float = 0.1
+    energy_loss_alpha: float = 0.1
+    dur_loss_alpha: float = 0.1
+    binary_align_loss_alpha: float = 0.1
+    binary_loss_warmup_epochs: int = 150
+
+    # overrides
+    min_seq_len: int = 13
+    max_seq_len: int = 200
+    r: int = 1  # DO NOT CHANGE
+
+    # dataset configs
+    compute_f0: bool = True
+    f0_cache_path: str = None
+    compute_energy: bool = True
+    energy_cache_path: str = None
+
+    # testing
+    test_sentences: List[str] = field(
+        default_factory=lambda: [
+            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+            "Be a voice, not an echo.",
+            "I'm sorry Dave. I'm afraid I can't do that.",
+            "This cake is great. It's so delicious and moist.",
+            "Prior to November 22, 1963.",
+        ]
+    )
+
+    def __post_init__(self):
+        # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there.
+        if self.num_speakers > 0:
+            self.model_args.num_speakers = self.num_speakers
+
+        # speaker embedding settings
+        if self.use_speaker_embedding:
+            self.model_args.use_speaker_embedding = True
+        if self.speakers_file:
+            self.model_args.speakers_file = self.speakers_file
+
+        # d-vector settings
+        if self.use_d_vector_file:
+            self.model_args.use_d_vector_file = True
+        if self.d_vector_dim is not None and self.d_vector_dim > 0:
+            self.model_args.d_vector_dim = self.d_vector_dim
+        if self.d_vector_file:
+            self.model_args.d_vector_file = self.d_vector_file
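As a quick orientation, a minimal sketch of using the new config (assuming this commit is installed as the `TTS` package; nothing below is a line from the diff itself):

    from TTS.tts.configs.fastspeech2_config import Fastspeech2Config

    config = Fastspeech2Config(compute_f0=True, compute_energy=True)
    print(config.model)       # "fastspeech2"
    print(config.base_model)  # "forward_tts" -> resolved to the shared ForwardTTS implementation
    # The energy predictor itself is gated by the model args and defaults to off in this commit:
    print(config.model_args.use_energy)  # False unless enabled explicitly

This also shows why `base_model` exists: 🐸 TTS instantiates `ForwardTTS`, and only the config distinguishes FastSpeech2 from FastPitch or SpeedySpeech.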
@@ -217,6 +217,9 @@ class BaseTTSConfig(BaseTrainingConfig):
         compute_f0 (int):
             (Not in use yet).
 
+        compute_energy (int):
+            (Not in use yet).
+
         compute_linear_spec (bool):
             If True data loader computes and returns linear spectrograms alongside the other data.

@@ -312,6 +315,7 @@ class BaseTTSConfig(BaseTrainingConfig):
     min_text_len: int = 1
     max_text_len: int = float("inf")
     compute_f0: bool = False
+    compute_energy: bool = False
     compute_linear_spec: bool = False
     precompute_num_workers: int = 0
     use_noise_augment: bool = False
@@ -11,6 +11,7 @@ from torch.utils.data import Dataset
 from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy
 
 # to prevent too many open files error as suggested here
 # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936

@@ -50,7 +51,9 @@ class TTSDataset(Dataset):
         samples: List[Dict] = None,
         tokenizer: "TTSTokenizer" = None,
         compute_f0: bool = False,
+        compute_energy: bool = False,
         f0_cache_path: str = None,
+        energy_cache_path: str = None,
         return_wav: bool = False,
         batch_group_size: int = 0,
         min_text_len: int = 0,

@@ -84,8 +87,12 @@ class TTSDataset(Dataset):
             compute_f0 (bool): compute f0 if True. Defaults to False.
 
+            compute_energy (bool): compute energy if True. Defaults to False.
+
             f0_cache_path (str): Path to store f0 cache. Defaults to None.
 
+            energy_cache_path (str): Path to store energy cache. Defaults to None.
+
             return_wav (bool): Return the waveform of the sample. Defaults to False.
 
             batch_group_size (int): Range of batch randomization after sorting

@@ -128,7 +135,9 @@ class TTSDataset(Dataset):
         self.compute_linear_spec = compute_linear_spec
         self.return_wav = return_wav
         self.compute_f0 = compute_f0
+        self.compute_energy = compute_energy
         self.f0_cache_path = f0_cache_path
+        self.energy_cache_path = energy_cache_path
         self.min_audio_len = min_audio_len
         self.max_audio_len = max_audio_len
         self.min_text_len = min_text_len

@@ -155,7 +164,10 @@ class TTSDataset(Dataset):
             self.f0_dataset = F0Dataset(
                 self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
             )
+        if compute_energy:
+            self.energy_dataset = EnergyDataset(
+                self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers
+            )
         if self.verbose:
             self.print_logs()

@@ -211,6 +223,12 @@ class TTSDataset(Dataset):
         assert item["audio_unique_name"] == out_dict["audio_unique_name"]
         return out_dict
 
+    def get_energy(self, idx):
+        out_dict = self.energy_dataset[idx]
+        item = self.samples[idx]
+        assert item["audio_unique_name"] == out_dict["audio_unique_name"]
+        return out_dict
+
     @staticmethod
     def get_attn_mask(attn_file):
         return np.load(attn_file)

@@ -252,12 +270,16 @@ class TTSDataset(Dataset):
         f0 = None
         if self.compute_f0:
             f0 = self.get_f0(idx)["f0"]
+        energy = None
+        if self.compute_energy:
+            energy = self.get_energy(idx)["energy"]
 
         sample = {
             "raw_text": raw_text,
             "token_ids": token_ids,
             "wav": wav,
             "pitch": f0,
+            "energy": energy,
             "attn": attn,
             "item_idx": item["audio_file"],
             "speaker_name": item["speaker_name"],

@@ -490,7 +512,13 @@ class TTSDataset(Dataset):
             pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 x T
         else:
             pitch = None
+        # format energy
+        if self.compute_energy:
+            energy = prepare_data(batch["energy"])
+            assert mel.shape[1] == energy.shape[1], f"[!] {mel.shape} vs {energy.shape}"
+            energy = torch.FloatTensor(energy)[:, None, :].contiguous()  # B x 1 x T
+        else:
+            energy = None
         # format attention masks
         attns = None
         if batch["attn"][0] is not None:

@@ -519,6 +547,7 @@ class TTSDataset(Dataset):
             "waveform": wav_padded,
             "raw_text": batch["raw_text"],
             "pitch": pitch,
+            "energy": energy,
             "language_ids": language_ids,
             "audio_unique_names": batch["audio_unique_name"],
         }

@@ -777,3 +806,155 @@ class F0Dataset:
         print("\n")
         print(f"{indent}> F0Dataset ")
         print(f"{indent}| > Number of instances : {len(self.samples)}")
+
+
+class EnergyDataset:
+    """Energy Dataset for computing energy from wav files on CPU.
+
+    Pre-compute energy values for all the samples at initialization if `cache_path` is not None or already present. It
+    also computes the mean and std of energy values if `normalize_energy` is True.
+
+    Args:
+        samples (Union[List[List], List[Dict]]):
+            List of samples. Each sample is a list or a dict.
+
+        ap (AudioProcessor):
+            AudioProcessor to compute energy from wav files.
+
+        cache_path (str):
+            Path to cache energy values. If `cache_path` is already present or None, it skips the pre-computation.
+            Defaults to None.
+
+        precompute_num_workers (int):
+            Number of workers used for pre-computing the energy values. Defaults to 0.
+
+        normalize_energy (bool):
+            Whether to normalize energy values by mean and std. Defaults to True.
+    """
+
+    def __init__(
+        self,
+        samples: Union[List[List], List[Dict]],
+        ap: "AudioProcessor",
+        verbose=False,
+        cache_path: str = None,
+        precompute_num_workers=0,
+        normalize_energy=True,
+    ):
+        self.samples = samples
+        self.ap = ap
+        self.verbose = verbose
+        self.cache_path = cache_path
+        self.normalize_energy = normalize_energy
+        self.pad_id = 0.0
+        self.mean = None
+        self.std = None
+        if cache_path is not None and not os.path.exists(cache_path):
+            os.makedirs(cache_path)
+            self.precompute(precompute_num_workers)
+        if normalize_energy:
+            self.load_stats(cache_path)
+
+    def __getitem__(self, idx):
+        item = self.samples[idx]
+        energy = self.compute_or_load(item["audio_file"])
+        if self.normalize_energy:
+            assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
+            energy = self.normalize(energy)
+        return {"audio_file": item["audio_file"], "energy": energy}
+
+    def __len__(self):
+        return len(self.samples)
+
+    def precompute(self, num_workers=0):
+        print("[*] Pre-computing energies...")
+        with tqdm.tqdm(total=len(self)) as pbar:
+            batch_size = num_workers if num_workers > 0 else 1
+            # we do not normalize at preprocessing
+            normalize_energy = self.normalize_energy
+            self.normalize_energy = False
+            dataloader = torch.utils.data.DataLoader(
+                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
+            )
+            computed_data = []
+            for batch in dataloader:
+                energy = batch["energy"]
+                computed_data.append([e for e in energy])
+                pbar.update(batch_size)
+            self.normalize_energy = normalize_energy
+
+        if self.normalize_energy:
+            computed_data = [tensor for batch in computed_data for tensor in batch]  # flatten
+            energy_mean, energy_std = self.compute_energy_stats(computed_data)
+            energy_stats = {"mean": energy_mean, "std": energy_std}
+            np.save(os.path.join(self.cache_path, "energy_stats"), energy_stats, allow_pickle=True)
+
+    def get_pad_id(self):
+        return self.pad_id
+
+    @staticmethod
+    def create_energy_file_path(wav_file, cache_path):
+        file_name = os.path.splitext(os.path.basename(wav_file))[0]
+        energy_file = os.path.join(cache_path, file_name + "_energy.npy")
+        return energy_file
+
+    @staticmethod
+    def _compute_and_save_energy(ap, wav_file, energy_file=None):
+        wav = ap.load_wav(wav_file)
+        energy = calculate_energy(wav)
+        if energy_file:
+            np.save(energy_file, energy)
+        return energy
+
+    @staticmethod
+    def compute_energy_stats(energy_vecs):
+        nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in energy_vecs])
+        mean, std = np.mean(nonzeros), np.std(nonzeros)
+        return mean, std
+
+    def load_stats(self, cache_path):
+        stats_path = os.path.join(cache_path, "energy_stats.npy")
+        stats = np.load(stats_path, allow_pickle=True).item()
+        self.mean = stats["mean"].astype(np.float32)
+        self.std = stats["std"].astype(np.float32)
+
+    def normalize(self, energy):
+        zero_idxs = np.where(energy == 0.0)[0]
+        energy = energy - self.mean
+        energy = energy / self.std
+        energy[zero_idxs] = 0.0
+        return energy
+
+    def denormalize(self, energy):
+        zero_idxs = np.where(energy == 0.0)[0]
+        energy *= self.std
+        energy += self.mean
+        energy[zero_idxs] = 0.0
+        return energy
+
+    def compute_or_load(self, wav_file):
+        """Compute energy and return a numpy array of energy values."""
+        energy_file = self.create_energy_file_path(wav_file, self.cache_path)
+        if not os.path.exists(energy_file):
+            energy = self._compute_and_save_energy(self.ap, wav_file, energy_file)
+        else:
+            energy = np.load(energy_file)
+        return energy.astype(np.float32)
+
+    def collate_fn(self, batch):
+        audio_file = [item["audio_file"] for item in batch]
+        energies = [item["energy"] for item in batch]
+        energy_lens = [len(item["energy"]) for item in batch]
+        energy_lens_max = max(energy_lens)
+        # energies are float-valued, so pad into a FloatTensor
+        energies_torch = torch.FloatTensor(len(energies), energy_lens_max).fill_(self.get_pad_id())
+        for i, energy_len in enumerate(energy_lens):
+            energies_torch[i, :energy_len] = torch.FloatTensor(energies[i])
+        return {"audio_file": audio_file, "energy": energies_torch, "energy_lens": energy_lens}
+
+    def print_logs(self, level: int = 0) -> None:
+        indent = "\t" * level
+        print("\n")
+        print(f"{indent}> EnergyDataset ")
+        print(f"{indent}| > Number of instances : {len(self.samples)}")
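For orientation, a standalone use of the new dataset might look like the sketch below; the sample dict and cache directory are hypothetical, and the module path is inferred from the diff context rather than stated in it:

    from TTS.config.shared_configs import BaseAudioConfig
    from TTS.tts.datasets.dataset import EnergyDataset
    from TTS.utils.audio import AudioProcessor

    ap = AudioProcessor(**BaseAudioConfig())
    samples = [{"audio_file": "wavs/LJ001-0001.wav", "audio_unique_name": "LJ001-0001"}]  # hypothetical sample
    energy_ds = EnergyDataset(samples, ap, cache_path="energy_cache", normalize_energy=True)
    item = energy_ds[0]  # {"audio_file": ..., "energy": np.ndarray of per-frame values}

On first construction with a fresh `cache_path`, the dataset precomputes all energies, saves `energy_stats.npy`, and then normalizes values on access.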
@@ -801,6 +801,10 @@ class ForwardTTSLoss(nn.Module):
             self.pitch_loss = MSELossMasked(False)
             self.pitch_loss_alpha = c.pitch_loss_alpha
 
+        if c.model_args.use_energy:
+            self.energy_loss = MSELossMasked(False)
+            self.energy_loss_alpha = c.energy_loss_alpha
+
         if c.use_ssim_loss:
             self.ssim = SSIMLoss() if c.use_ssim_loss else None
             self.ssim_loss_alpha = c.ssim_loss_alpha

@@ -826,6 +830,8 @@ class ForwardTTSLoss(nn.Module):
         dur_target,
         pitch_output,
         pitch_target,
+        energy_output,
+        energy_target,
         input_lens,
         alignment_logprob=None,
         alignment_hard=None,

@@ -855,6 +861,11 @@ class ForwardTTSLoss(nn.Module):
             loss = loss + self.pitch_loss_alpha * pitch_loss
             return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss
 
+        if hasattr(self, "energy_loss") and self.energy_loss_alpha > 0:
+            energy_loss = self.energy_loss(energy_output.transpose(1, 2), energy_target.transpose(1, 2), input_lens)
+            loss = loss + self.energy_loss_alpha * energy_loss
+            return_dict["loss_energy"] = self.energy_loss_alpha * energy_loss
+
         if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0:
             aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
             loss = loss + self.aligner_loss_alpha * aligner_loss
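The energy term mirrors the pitch term: a masked MSE over per-token averaged energy, weighted by `energy_loss_alpha`. As a self-contained illustration of what a masked MSE computes (a re-implementation sketch, not the library's `MSELossMasked`):

    import torch

    def masked_mse(output, target, lens):
        # output/target: [B, T_en, 1] (after the transpose above); lens: [B]
        mask = torch.arange(output.shape[1])[None, :] < lens[:, None]  # [B, T_en], True on valid tokens
        sq_err = ((output - target).squeeze(-1) ** 2) * mask
        return sq_err.sum() / mask.sum()  # average only over unpadded positions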
@@ -15,7 +15,7 @@ from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
-from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
+from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
 from TTS.utils.io import load_fsspec

@@ -42,6 +42,9 @@ class ForwardTTSArgs(Coqpit):
         use_pitch (bool):
             Use pitch predictor to learn the pitch. Defaults to True.
 
+        use_energy (bool):
+            Use energy predictor to learn the energy. Defaults to False.
+
         duration_predictor_hidden_channels (int):
             Number of hidden channels in the duration predictor. Defaults to 256.

@@ -63,6 +66,18 @@ class ForwardTTSArgs(Coqpit):
         pitch_embedding_kernel_size (int):
             Kernel size of the projection layer in the pitch predictor. Defaults to 3.
 
+        energy_predictor_hidden_channels (int):
+            Number of hidden channels in the energy predictor. Defaults to 256.
+
+        energy_predictor_dropout_p (float):
+            Dropout rate for the energy predictor. Defaults to 0.1.
+
+        energy_predictor_kernel_size (int):
+            Kernel size of conv layers in the energy predictor. Defaults to 3.
+
+        energy_embedding_kernel_size (int):
+            Kernel size of the projection layer in the energy predictor. Defaults to 3.
+
         positional_encoding (bool):
             Whether to use positional encoding. Defaults to True.

@@ -114,14 +129,25 @@ class ForwardTTSArgs(Coqpit):
     out_channels: int = 80
     hidden_channels: int = 384
     use_aligner: bool = True
+
+    # pitch params
     use_pitch: bool = True
     pitch_predictor_hidden_channels: int = 256
     pitch_predictor_kernel_size: int = 3
     pitch_predictor_dropout_p: float = 0.1
     pitch_embedding_kernel_size: int = 3
+
+    # energy params
+    use_energy: bool = False
+    energy_predictor_hidden_channels: int = 256
+    energy_predictor_kernel_size: int = 3
+    energy_predictor_dropout_p: float = 0.1
+    energy_embedding_kernel_size: int = 3
+
+    # duration params
     duration_predictor_hidden_channels: int = 256
     duration_predictor_kernel_size: int = 3
     duration_predictor_dropout_p: float = 0.1
+
     positional_encoding: bool = True
     poisitonal_encoding_use_scale: bool = True
     length_scale: int = 1

@@ -158,7 +184,7 @@ class ForwardTTS(BaseTTS):
         - FastPitch
         - SpeedySpeech
         - FastSpeech
-        - TODO: FastSpeech2 (requires average speech energy predictor)
+        - FastSpeech2 (requires average speech energy predictor)
 
     Args:
         config (Coqpit): Model coqpit class.

@@ -187,6 +213,7 @@ class ForwardTTS(BaseTTS):
         self.max_duration = self.args.max_duration
         self.use_aligner = self.args.use_aligner
         self.use_pitch = self.args.use_pitch
+        self.use_energy = self.args.use_energy
         self.binary_loss_weight = 0.0
 
         self.length_scale = (

@@ -234,6 +261,20 @@ class ForwardTTS(BaseTTS):
                 padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
             )
 
+        if self.args.use_energy:
+            self.energy_predictor = DurationPredictor(
+                self.args.hidden_channels + self.embedded_speaker_dim,
+                self.args.energy_predictor_hidden_channels,
+                self.args.energy_predictor_kernel_size,
+                self.args.energy_predictor_dropout_p,
+            )
+            self.energy_emb = nn.Conv1d(
+                1,
+                self.args.hidden_channels,
+                kernel_size=self.args.energy_embedding_kernel_size,
+                padding=int((self.args.energy_embedding_kernel_size - 1) / 2),
+            )
+
         if self.args.use_aligner:
             self.aligner = AlignmentNetwork(
                 in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels

@@ -440,6 +481,42 @@ class ForwardTTS(BaseTTS):
             o_pitch_emb = self.pitch_emb(o_pitch)
         return o_pitch_emb, o_pitch
 
+    def _forward_energy_predictor(
+        self,
+        o_en: torch.FloatTensor,
+        x_mask: torch.IntTensor,
+        energy: torch.FloatTensor = None,
+        dr: torch.IntTensor = None,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        """Energy predictor forward pass.
+
+        1. Predict energy from encoder outputs.
+        2. In training - Compute average energy values for each input character from the ground truth energy values.
+        3. Embed average energy values.
+
+        Args:
+            o_en (torch.FloatTensor): Encoder output.
+            x_mask (torch.IntTensor): Input sequence mask.
+            energy (torch.FloatTensor, optional): Ground truth energy values. Defaults to None.
+            dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
+
+        Returns:
+            Tuple[torch.FloatTensor, torch.FloatTensor]: Energy embedding, energy prediction.
+
+        Shapes:
+            - o_en: :math:`(B, C, T_{en})`
+            - x_mask: :math:`(B, 1, T_{en})`
+            - energy: :math:`(B, 1, T_{de})`
+            - dr: :math:`(B, T_{en})`
+        """
+        o_energy = self.energy_predictor(o_en, x_mask)
+        if energy is not None:
+            avg_energy = average_over_durations(energy, dr)
+            o_energy_emb = self.energy_emb(avg_energy)
+            return o_energy_emb, o_energy, avg_energy
+        o_energy_emb = self.energy_emb(o_energy)
+        return o_energy_emb, o_energy
+
     def _forward_aligner(
         self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor
     ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:

@@ -502,6 +579,7 @@ class ForwardTTS(BaseTTS):
         y: torch.FloatTensor = None,
         dr: torch.IntTensor = None,
         pitch: torch.FloatTensor = None,
+        energy: torch.FloatTensor = None,
         aux_input: Dict = {"d_vectors": None, "speaker_ids": None},  # pylint: disable=unused-argument
     ) -> Dict:
         """Model's forward pass.

@@ -513,6 +591,7 @@ class ForwardTTS(BaseTTS):
             y (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None.
             dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None.
             pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None.
+            energy (torch.FloatTensor): Energy values for each spectrogram frame. Only used when the energy predictor is on. Defaults to None.
             aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`.
 
         Shapes:

@@ -556,6 +635,12 @@ class ForwardTTS(BaseTTS):
         if self.args.use_pitch:
             o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr)
             o_en = o_en + o_pitch_emb
+        # energy predictor pass
+        o_energy = None
+        avg_energy = None
+        if self.args.use_energy:
+            o_energy_emb, o_energy, avg_energy = self._forward_energy_predictor(o_en, x_mask, energy, dr)
+            o_en = o_en + o_energy_emb
         # decoder pass
         o_de, attn = self._forward_decoder(
             o_en, dr, x_mask, y_lengths, g=None

@@ -567,6 +652,8 @@ class ForwardTTS(BaseTTS):
             "attn_durations": o_attn,  # for visualization [B, T_en, T_de']
             "pitch_avg": o_pitch,
             "pitch_avg_gt": avg_pitch,
+            "energy_avg": o_energy,
+            "energy_avg_gt": avg_energy,
             "alignments": attn,  # [B, T_de, T_en]
             "alignment_soft": alignment_soft,
             "alignment_mas": alignment_mas,

@@ -604,12 +691,18 @@ class ForwardTTS(BaseTTS):
         if self.args.use_pitch:
             o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
             o_en = o_en + o_pitch_emb
+        # energy predictor pass
+        o_energy = None
+        if self.args.use_energy:
+            o_energy_emb, o_energy = self._forward_energy_predictor(o_en, x_mask)
+            o_en = o_en + o_energy_emb
         # decoder pass
         o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None)
         outputs = {
             "model_outputs": o_de,
             "alignments": attn,
             "pitch": o_pitch,
+            "energy": o_energy,
             "durations_log": o_dr_log,
         }
         return outputs

@@ -620,6 +713,7 @@ class ForwardTTS(BaseTTS):
         mel_input = batch["mel_input"]
         mel_lengths = batch["mel_lengths"]
         pitch = batch["pitch"] if self.args.use_pitch else None
+        energy = batch["energy"] if self.args.use_energy else None
         d_vectors = batch["d_vectors"]
         speaker_ids = batch["speaker_ids"]
         durations = batch["durations"]

@@ -627,7 +721,14 @@ class ForwardTTS(BaseTTS):
 
         # forward pass
         outputs = self.forward(
-            text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
+            text_input,
+            text_lengths,
+            mel_lengths,
+            y=mel_input,
+            dr=durations,
+            pitch=pitch,
+            energy=energy,
+            aux_input=aux_input,
         )
         # use aligner's output as the duration target
         if self.use_aligner:

@@ -643,6 +744,8 @@ class ForwardTTS(BaseTTS):
             dur_target=durations,
             pitch_output=outputs["pitch_avg"] if self.use_pitch else None,
             pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
+            energy_output=outputs["energy_avg"] if self.use_energy else None,
+            energy_target=outputs["energy_avg_gt"] if self.use_energy else None,
             input_lens=text_lengths,
             alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
             alignment_soft=outputs["alignment_soft"],

@@ -683,6 +786,17 @@ class ForwardTTS(BaseTTS):
         }
         figures.update(pitch_figures)
 
+        # plot energy figures
+        if self.args.use_energy:
+            energy_avg = abs(outputs["energy_avg_gt"][0, 0].data.cpu().numpy())
+            energy_avg_hat = abs(outputs["energy_avg"][0, 0].data.cpu().numpy())
+            chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy())
+            energy_figures = {
+                "energy_ground_truth": plot_avg_energy(energy_avg, chars, output_fig=False),
+                "energy_avg_predicted": plot_avg_energy(energy_avg_hat, chars, output_fig=False),
+            }
+            figures.update(energy_figures)
+
         # plot the attention mask computed from the predicted durations
         if "attn_durations" in outputs:
             alignments_hat = outputs["attn_durations"][0].data.cpu().numpy()
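Both the pitch and energy paths lean on `average_over_durations` to collapse frame-level targets to one value per input token before the loss above. A reference re-implementation, written here only to pin down the semantics (an assumption about behavior, not the library's code):

    import torch

    def average_over_durations_ref(values: torch.Tensor, durs: torch.Tensor) -> torch.Tensor:
        # values: [B, 1, T_de] frame-level values; durs: [B, T_en] integer durations with durs.sum(1) == T_de
        out = torch.zeros(values.shape[0], 1, durs.shape[1])
        for b in range(values.shape[0]):
            start = 0
            for t, d in enumerate(durs[b].tolist()):
                d = int(d)
                if d > 0:
                    out[b, 0, t] = values[b, 0, start : start + d].mean()  # mean over the token's frame span
                start += d
        return out

With this view, `_forward_energy_predictor` trains the predictor to regress these per-token averages, and at inference the predicted averages are embedded and added back into the encoder output.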
@@ -123,6 +123,39 @@ def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False):
     return fig
 
 
+def plot_avg_energy(energy, chars, fig_size=(30, 10), output_fig=False):
+    """Plot energy curves on top of the input characters.
+
+    Args:
+        energy (np.array): energy values.
+        chars (str): Characters to place on the x-axis.
+
+    Shapes:
+        energy: :math:`(T,)`
+    """
+    old_fig_size = plt.rcParams["figure.figsize"]
+    if fig_size is not None:
+        plt.rcParams["figure.figsize"] = fig_size
+
+    fig, ax = plt.subplots()
+
+    x = np.array(range(len(chars)))
+    my_xticks = chars
+    plt.xticks(x, my_xticks)
+
+    ax.set_xlabel("characters")
+    ax.set_ylabel("freq")
+
+    ax2 = ax.twinx()
+    ax2.plot(energy, linewidth=5.0, color="red")
+    ax2.set_ylabel("energy")
+
+    plt.rcParams["figure.figsize"] = old_fig_size
+    if not output_fig:
+        plt.close()
+    return fig
+
+
 def visualize(
     alignment,
     postnet_output,
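A quick way to exercise the new helper outside a training run (synthetic values; the output filename is arbitrary):

    import numpy as np
    from TTS.tts.utils.visual import plot_avg_energy

    chars = "hello"
    energy = np.abs(np.random.randn(len(chars)))  # one averaged energy value per character
    fig = plot_avg_energy(energy, chars, fig_size=(12, 4), output_fig=True)
    fig.savefig("avg_energy.png")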
@@ -4,7 +4,7 @@ import librosa
 import numpy as np
 import scipy
 import soundfile as sf
-from librosa import pyin
+from librosa import magphase, pyin
 
 # For using kwargs
 # pylint: disable=unused-argument

@@ -303,6 +303,27 @@ def compute_f0(
     return f0
 
 
+def compute_energy(y: np.ndarray, **kwargs) -> np.ndarray:
+    """Compute the energy of a waveform using the same parameters used for computing the melspectrogram.
+
+    Args:
+        y (np.ndarray): Waveform. Shape :math:`[T_wav,]`
+
+    Returns:
+        np.ndarray: energy. Shape :math:`[T_energy,]`. :math:`T_energy == T_wav / hop_length`
+
+    Examples:
+        >>> WAV_FILE = librosa.util.example_audio_file()
+        >>> from TTS.config import BaseAudioConfig
+        >>> from TTS.utils.audio import AudioProcessor
+        >>> conf = BaseAudioConfig()
+        >>> ap = AudioProcessor(**conf)
+        >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
+        >>> energy = compute_energy(wav)
+    """
+    x = stft(y=y, **kwargs)
+    mag, _ = magphase(x)
+    energy = np.sqrt(np.sum(mag**2, axis=0))
+    return energy
+
+
 ### Audio Processing ###
 def find_endpoint(
     *,
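Outside the package, the same per-frame energy can be reproduced with librosa alone; the `n_fft`/`hop_length` values below are illustrative assumptions, not taken from the diff:

    import librosa
    import numpy as np

    wav, sr = librosa.load(librosa.example("trumpet"), sr=22050)
    mag = np.abs(librosa.stft(wav, n_fft=1024, hop_length=256))  # [n_freq, T_frames]
    energy = np.sqrt(np.sum(mag**2, axis=0))  # L2 norm of each frame's spectrum, shape [T_frames]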
@@ -0,0 +1,102 @@
+import os
+
+from trainer import Trainer, TrainerArgs
+
+from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig
+from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.models.forward_tts import ForwardTTS
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.manage import ModelManager
+
+output_path = os.path.dirname(os.path.abspath(__file__))
+
+# init configs
+dataset_config = BaseDatasetConfig(
+    formatter="ljspeech",
+    meta_file_train="metadata.csv",
+    # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
+    path=os.path.join(output_path, "../LJSpeech-1.1/"),
+)
+
+audio_config = BaseAudioConfig(
+    sample_rate=22050,
+    do_trim_silence=True,
+    trim_db=60.0,
+    signal_norm=False,
+    mel_fmin=0.0,
+    mel_fmax=8000,
+    spec_gain=1.0,
+    log_func="np.log",
+    ref_level_db=20,
+    preemphasis=0.0,
+)
+
+config = Fastspeech2Config(
+    run_name="fastspeech2_ljspeech",
+    audio=audio_config,
+    batch_size=32,
+    eval_batch_size=16,
+    num_loader_workers=8,
+    num_eval_loader_workers=4,
+    compute_input_seq_cache=True,
+    compute_f0=True,
+    f0_cache_path=os.path.join(output_path, "f0_cache"),
+    compute_energy=True,
+    energy_cache_path=os.path.join(output_path, "energy_cache"),
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1000,
+    text_cleaner="english_cleaners",
+    use_phonemes=True,
+    phoneme_language="en-us",
+    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
+    precompute_num_workers=4,
+    print_step=50,
+    print_eval=False,
+    mixed_precision=False,
+    max_seq_len=500000,
+    output_path=output_path,
+    datasets=[dataset_config],
+)
+
+# compute alignments
+if not config.model_args.use_aligner:
+    manager = ModelManager()
+    model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
+    # TODO: make compute_attention_masks.py a python callable
+    os.system(
+        f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true"
+    )
+
+# INITIALIZE THE AUDIO PROCESSOR
+# Audio processor is used for feature extraction and audio I/O.
+# It mainly serves the dataloader and the training loggers.
+ap = AudioProcessor.init_from_config(config)
+
+# INITIALIZE THE TOKENIZER
+# Tokenizer is used to convert text to sequences of token IDs.
+# If characters are not defined in the config, default characters are passed to the config.
+tokenizer, config = TTSTokenizer.init_from_config(config)
+
+# LOAD DATA SAMPLES
+# Each sample is a list of ```[text, audio_file_path, speaker_name]```
+# You can define your custom sample loader returning the list of samples.
+# Or define your custom formatter and pass it to the `load_tts_samples`.
+# Check `TTS.tts.datasets.load_tts_samples` for more details.
+train_samples, eval_samples = load_tts_samples(
+    dataset_config,
+    eval_split=True,
+    eval_split_max_size=config.eval_split_max_size,
+    eval_split_size=config.eval_split_size,
+)
+
+# init the model
+model = ForwardTTS(config, ap, tokenizer, speaker_manager=None)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
+)
+trainer.fit()
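One caveat worth flagging: the recipe turns on dataset-side energy extraction (`compute_energy=True`), but the predictor itself is gated by `model_args.use_energy`, which `ForwardTTSArgs` defaults to False in this commit. A run that actually trains the energy predictor would presumably also set (an assumption about intended use, not a line from the diff):

    config.model_args.use_energy = True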
@@ -0,0 +1,95 @@
+import glob
+import json
+import os
+import shutil
+
+from trainer import get_last_checkpoint
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.config.shared_configs import BaseAudioConfig
+from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
+
+config_path = os.path.join(get_tests_output_path(), "fastspeech2_speaker_emb_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+audio_config = BaseAudioConfig(
+    sample_rate=22050,
+    do_trim_silence=True,
+    trim_db=60.0,
+    signal_norm=False,
+    mel_fmin=0.0,
+    mel_fmax=8000,
+    spec_gain=1.0,
+    log_func="np.log",
+    ref_level_db=20,
+    preemphasis=0.0,
+)
+
+config = Fastspeech2Config(
+    audio=audio_config,
+    batch_size=8,
+    eval_batch_size=8,
+    num_loader_workers=0,
+    num_eval_loader_workers=0,
+    text_cleaner="english_cleaners",
+    use_phonemes=True,
+    phoneme_language="en-us",
+    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+    f0_cache_path="tests/data/ljspeech/f0_cache/",
+    compute_f0=True,
+    compute_energy=True,
+    energy_cache_path="tests/data/ljspeech/f0_cache/",
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1,
+    print_step=1,
+    print_eval=True,
+    use_speaker_embedding=True,
+    test_sentences=[
+        "Be a voice, not an echo.",
+    ],
+)
+config.audio.do_trim_silence = True
+config.use_speaker_embedding = True
+config.model_args.use_speaker_embedding = True
+config.audio.trim_db = 60
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+    f"--coqpit.output_path {output_path} "
+    "--coqpit.datasets.0.formatter ljspeech_test "
+    "--coqpit.datasets.0.meta_file_train metadata.csv "
+    "--coqpit.datasets.0.meta_file_val metadata.csv "
+    "--coqpit.datasets.0.path tests/data/ljspeech "
+    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
+    "--coqpit.test_delay_epochs 0"
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# Inference using TTS API
+continue_config_path = os.path.join(continue_path, "config.json")
+continue_restore_path, _ = get_last_checkpoint(continue_path)
+out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
+speaker_id = "ljspeech-1"
+continue_speakers_path = os.path.join(continue_path, "speakers.json")
+
+# Check integrity of the config
+with open(continue_config_path, "r", encoding="utf-8") as f:
+    config_loaded = json.load(f)
+assert config_loaded["characters"] is not None
+assert config_loaded["output_path"] in continue_path
+assert config_loaded["test_delay_epochs"] == 0
+
+# Load the model and run inference
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
+run_cli(inference_command)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)
@@ -0,0 +1,94 @@
+import glob
+import json
+import os
+import shutil
+
+from trainer import get_last_checkpoint
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.config.shared_configs import BaseAudioConfig
+from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+audio_config = BaseAudioConfig(
+    sample_rate=22050,
+    do_trim_silence=True,
+    trim_db=60.0,
+    signal_norm=False,
+    mel_fmin=0.0,
+    mel_fmax=8000,
+    spec_gain=1.0,
+    log_func="np.log",
+    ref_level_db=20,
+    preemphasis=0.0,
+)
+
+config = Fastspeech2Config(
+    audio=audio_config,
+    batch_size=8,
+    eval_batch_size=8,
+    num_loader_workers=0,
+    num_eval_loader_workers=0,
+    text_cleaner="english_cleaners",
+    use_phonemes=True,
+    phoneme_language="en-us",
+    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+    f0_cache_path="tests/data/ljspeech/f0_cache/",
+    compute_f0=True,
+    compute_energy=True,
+    energy_cache_path="tests/data/ljspeech/f0_cache/",
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1,
+    print_step=1,
+    print_eval=True,
+    test_sentences=[
+        "Be a voice, not an echo.",
+    ],
+    use_speaker_embedding=False,
+)
+config.audio.do_trim_silence = True
+config.use_speaker_embedding = False
+config.model_args.use_speaker_embedding = False
+config.audio.trim_db = 60
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+    f"--coqpit.output_path {output_path} "
+    "--coqpit.datasets.0.formatter ljspeech "
+    "--coqpit.datasets.0.meta_file_train metadata.csv "
+    "--coqpit.datasets.0.meta_file_val metadata.csv "
+    "--coqpit.datasets.0.path tests/data/ljspeech "
+    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
+    "--coqpit.test_delay_epochs 0"
+)
+
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# Inference using TTS API
+continue_config_path = os.path.join(continue_path, "config.json")
+continue_restore_path, _ = get_last_checkpoint(continue_path)
+out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
+
+# Check integrity of the config
+with open(continue_config_path, "r", encoding="utf-8") as f:
+    config_loaded = json.load(f)
+assert config_loaded["characters"] is not None
+assert config_loaded["output_path"] in continue_path
+assert config_loaded["test_delay_epochs"] == 0
+
+# Load the model and run inference
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
+run_cli(inference_command)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)