mirror of https://github.com/coqui-ai/TTS.git

Add pitch predictor

parent 5859e6474c
commit 6a573065f4

@@ -74,6 +74,7 @@ def extract_aligments(
    np.save(align_file_path, alignment)


def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data, speaker_manager

@@ -121,6 +121,7 @@ class VitsConfig(BaseTTSConfig):
    disc_latent_loss_alpha: float = 5.0
    gen_latent_loss_alpha: float = 5.0
    feat_latent_loss_alpha: float = 108.0
    pitch_loss_alpha: float = 5.0

    # data loader params
    return_wav: bool = True

@@ -155,6 +156,10 @@ class VitsConfig(BaseTTSConfig):
    d_vector_file: str = None
    d_vector_dim: int = None

    # dataset configs
    compute_f0: bool = False
    f0_cache_path: str = None

    def __post_init__(self):
        for key, val in self.model_args.items():
            if hasattr(self, key):

@@ -122,6 +122,7 @@ class TTSDataset(Dataset):
        self.return_wav = return_wav
        self.compute_f0 = compute_f0
        self.f0_cache_path = f0_cache_path
        self.precompute_num_workers = precompute_num_workers
        self.min_audio_len = min_audio_len
        self.max_audio_len = max_audio_len
        self.min_text_len = min_text_len

@@ -585,6 +585,7 @@ class AlignTTSLoss(nn.Module):
class VitsGeneratorLoss(nn.Module):
    def __init__(self, c: Coqpit):
        super().__init__()
        self.pitch_loss_alpha = c.pitch_loss_alpha
        self.kl_loss_alpha = c.kl_loss_alpha
        self.gen_loss_alpha = c.gen_loss_alpha
        self.feat_loss_alpha = c.feat_loss_alpha

@@ -680,6 +681,7 @@ class VitsGeneratorLoss(nn.Module):
        scores_disc_mp=None,
        feats_disc_mp=None,
        feats_disc_zp=None,
        pitch_loss=None,
    ):
        """
        Shapes:

@@ -755,6 +757,11 @@ class VitsGeneratorLoss(nn.Module):
        loss += kl_vae_loss
        return_dict["loss_kl_vae"] = kl_vae_loss

        if pitch_loss is not None:
            pitch_loss = pitch_loss * self.pitch_loss_alpha
            loss += pitch_loss
            return_dict["pitch_loss"] = pitch_loss

        # pass losses to the dict
        return_dict["loss_gen"] = loss_gen
        return_dict["loss_kl"] = loss_kl

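A tiny illustration (not part of the commit; the tensor values are placeholders) of how the new term is weighted before entering the total generator loss, using the pitch_loss_alpha default of 5.0 from the config hunk above:

import torch

pitch_loss_alpha = 5.0              # VitsConfig default from the hunk above
loss = torch.tensor(2.0)            # running generator loss (placeholder value)
pitch_loss = torch.tensor(0.3)      # raw pitch loss returned by the model (placeholder value)

pitch_loss = pitch_loss * pitch_loss_alpha
loss = loss + pitch_loss
print(loss, pitch_loss)             # tensor(3.5000) tensor(1.5000)
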
@@ -1,6 +1,7 @@
import math
import os
import numpy as np
import pyworld as pw
from dataclasses import dataclass, field, replace
from itertools import chain
from typing import Dict, List, Tuple, Union

@@ -17,7 +18,7 @@ from torch.utils.data import DataLoader
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.tts.configs.shared_configs import CharactersConfig
-from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
+from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, F0Dataset
from TTS.tts.layers.generic.classifier import ReversalClassifier
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer

@@ -26,6 +27,8 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE, ResNetProsodyEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import prepare_data
from TTS.tts.utils.helpers import average_over_durations
from TTS.tts.utils.emotions import EmotionManager
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.languages import LanguageManager

@@ -204,18 +207,68 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
    spec = amp_to_db(spec)
    return spec


def compute_f0(x: np.ndarray, sample_rate, hop_length, pitch_fmax=800.0) -> np.ndarray:
    """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.

    Args:
        x (np.ndarray): Waveform.

    Returns:
        np.ndarray: Pitch.
    """
    # align F0 length to the spectrogram length
    if len(x) % hop_length == 0:
        x = np.pad(x, (0, hop_length // 2), mode="reflect")

    f0, t = pw.dio(
        x.astype(np.double),
        fs=sample_rate,
        f0_ceil=pitch_fmax,
        frame_period=1000 * hop_length / sample_rate,
    )
    f0 = pw.stonemask(x.astype(np.double), f0, t, sample_rate)
    return f0


##############################
# DATASET
##############################


class VITSF0Dataset(F0Dataset):
    def __init__(self, config, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config

    def compute_or_load(self, wav_file):
        """compute pitch and return a numpy array of pitch values"""
        pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
        if not os.path.exists(pitch_file):
            pitch = self._compute_and_save_pitch(wav_file, pitch_file)
        else:
            pitch = np.load(pitch_file)
        return pitch.astype(np.float32)

    def _compute_and_save_pitch(self, wav_file, pitch_file=None):
        print(wav_file, pitch_file)
        wav, _ = load_audio(wav_file)
        pitch = compute_f0(wav.squeeze().numpy(), self.config.audio.sample_rate, self.config.audio.hop_length)
        if pitch_file:
            np.save(pitch_file, pitch)
        return pitch


class VitsDataset(TTSDataset):
-    def __init__(self, model_args, *args, **kwargs):
+    def __init__(self, model_args, config, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pad_id = self.tokenizer.characters.pad_id
        self.model_args = model_args

        self.f0_dataset = VITSF0Dataset(
            config, samples=self.samples, ap=self.ap, cache_path=self.f0_cache_path, precompute_num_workers=self.precompute_num_workers
        )

    def __getitem__(self, idx):
        item = self.samples[idx]
        raw_text = item["text"]

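As a quick sanity check, here is a small standalone sketch (not part of the commit) showing how compute_f0 produces roughly one F0 value per spectrogram frame; the sample rate and hop length are assumed example values.

import numpy as np
import pyworld as pw

sample_rate, hop_length = 22050, 256          # assumed example audio settings
x = np.random.randn(sample_rate)              # 1 second of placeholder audio

# same padding trick as compute_f0 above, so the frame count matches the mel frames
if len(x) % hop_length == 0:
    x = np.pad(x, (0, hop_length // 2), mode="reflect")

f0, t = pw.dio(x.astype(np.double), fs=sample_rate, f0_ceil=800.0, frame_period=1000 * hop_length / sample_rate)
f0 = pw.stonemask(x.astype(np.double), f0, t, sample_rate)
print(f0.shape)  # roughly (len(x) // hop_length + 1,) frames, one F0 value per frame
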
@@ -229,6 +282,11 @@ class VitsDataset(TTSDataset):
        token_ids = self.get_token_ids(idx, item["text"])

        # get f0 values
        f0 = None
        if self.compute_f0:
            f0 = self.get_f0(idx)["f0"]

        # after phonemization the text length may change
        # this is a shameful 🤭 hack to prevent longer phonemes
        # TODO: find a better fix

@@ -244,6 +302,7 @@ class VitsDataset(TTSDataset):
            "wav_file": wav_filename,
            "speaker_name": item["speaker_name"],
            "language_name": item["language"],
            "pitch": f0,
        }

    @property

@@ -297,6 +356,14 @@ class VitsDataset(TTSDataset):
            wav = batch["wav"][i]
            wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav)

        # format F0
        if self.compute_f0:
            pitch = prepare_data(batch["pitch"])
            pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 x T
        else:
            pitch = None

        return {
            "tokens": token_padded,
            "token_lens": token_lens,

@@ -304,6 +371,7 @@ class VitsDataset(TTSDataset):
            "waveform": wav_padded,  # (B x T)
            "waveform_lens": wav_lens,  # (B)
            "waveform_rel_lens": wav_rel_lens,
            "pitch": pitch,
            "speaker_names": batch["speaker_name"],
            "language_names": batch["language_name"],
            "audio_files": batch["wav_file"],

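The collate step above relies on prepare_data to zero-pad the per-utterance F0 arrays to the batch maximum before they are stacked into the (B, 1, T) pitch tensor. A small hedged sketch of that behaviour (the array values are placeholders, and the padded shape is an assumption about prepare_data):

import numpy as np
import torch
from TTS.tts.utils.data import prepare_data

batch_pitch = [np.array([100.0, 120.0]), np.array([90.0, 95.0, 101.0])]   # placeholder F0 tracks
pitch = prepare_data(batch_pitch)                           # zero-padded to the longest item -> (B, T_max)
pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()   # B x 1 x T, as in the collate above
print(pitch.shape)                                          # torch.Size([2, 1, 3])
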
@@ -464,6 +532,21 @@ class VitsArgs(Coqpit):
        encoder_model_path (str):
            Path to the speaker encoder checkpoint file, used for SCL. Defaults to "".

        use_pitch (bool):
            Use a pitch predictor to learn the pitch. Defaults to False.

        pitch_predictor_hidden_channels (int):
            Number of hidden channels in the pitch predictor. Defaults to 256.

        pitch_predictor_dropout_p (float):
            Dropout rate for the pitch predictor. Defaults to 0.1.

        pitch_predictor_kernel_size (int):
            Kernel size of the conv layers in the pitch predictor. Defaults to 3.

        pitch_embedding_kernel_size (int):
            Kernel size of the projection layer in the pitch predictor. Defaults to 3.

        condition_dp_on_speaker (bool):
            Condition the duration predictor on the speaker embedding. Defaults to True.

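Taken together, the new options can be switched on from a VitsConfig. A hedged sketch (the values are the documented defaults, and the cache path is only an example):

from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig(compute_f0=True, f0_cache_path="f0_cache/")  # dataset-side F0 extraction
config.model_args.use_pitch = True                               # enable the pitch predictor
config.model_args.pitch_predictor_hidden_channels = 256
config.model_args.pitch_predictor_kernel_size = 3
config.model_args.pitch_predictor_dropout_p = 0.1
config.model_args.pitch_embedding_kernel_size = 3
config.model_args.condition_dp_on_speaker = True
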
@@ -571,6 +654,14 @@ class VitsArgs(Coqpit):
    use_latent_discriminator: bool = False
    use_avg_feature_on_latent_discriminator: bool = False

    # Pitch predictor
    use_pitch: bool = False
    pitch_predictor_hidden_channels: int = 256
    pitch_predictor_kernel_size: int = 3
    pitch_predictor_dropout_p: float = 0.1
    pitch_embedding_kernel_size: int = 3
    detach_pp_input: bool = False

    detach_dp_input: bool = True
    use_language_embedding: bool = False
    embedded_language_dim: int = 4

@@ -717,6 +808,21 @@ class Vits(BaseTTS):
                language_emb_dim=self.embedded_language_dim,
            )

        if self.args.use_pitch:
            self.pitch_predictor = DurationPredictor(
                self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
                self.args.pitch_predictor_hidden_channels,
                self.args.pitch_predictor_kernel_size,
                self.args.pitch_predictor_dropout_p,
                cond_channels=dp_cond_embedding_dim,
            )
            self.pitch_emb = nn.Conv1d(
                1,
                self.args.hidden_channels,
                kernel_size=self.args.pitch_embedding_kernel_size,
                padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
            )

        if self.args.use_prosody_encoder:
            if self.args.use_pros_enc_input_as_pros_emb:
                self.prosody_embedding_squeezer = nn.Linear(

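For intuition, a minimal shape sketch (not from the commit; hidden_channels and the batch/length values are assumptions) of the pitch embedding layer defined above, which maps a 1-channel averaged-F0 track to hidden_channels while keeping the token-level length:

import torch
import torch.nn as nn

hidden_channels, kernel_size = 192, 3                      # assumed values
pitch_emb = nn.Conv1d(1, hidden_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)

avg_pitch = torch.randn(2, 1, 50)                          # (B, 1, T_en): one averaged F0 value per input token
print(pitch_emb(avg_pitch).shape)                          # torch.Size([2, 192, 50])
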
@@ -1085,6 +1191,47 @@ class Vits(BaseTTS):
        g = speaker_ids if speaker_ids is not None else d_vectors
        return g

    def forward_pitch_predictor(
        self,
        o_en: torch.FloatTensor,
        x_mask: torch.IntTensor,
        pitch: torch.FloatTensor = None,
        dr: torch.IntTensor = None,
        g_pp: torch.IntTensor = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Pitch predictor forward pass.

        1. Predict pitch from the encoder outputs.
        2. In training, compute average pitch values for each input character from the ground truth pitch values.
        3. Embed the average pitch values.

        Args:
            o_en (torch.FloatTensor): Encoder output.
            x_mask (torch.IntTensor): Input sequence mask.
            pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
            dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
            g_pp (torch.IntTensor, optional): Speaker/prosody embedding to condition the pitch predictor. Defaults to None.

        Returns:
            Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction.

        Shapes:
            - o_en: :math:`(B, C, T_{en})`
            - x_mask: :math:`(B, 1, T_{en})`
            - pitch: :math:`(B, 1, T_{de})`
            - dr: :math:`(B, T_{en})`
        """
        o_pitch = self.pitch_predictor(
            o_en,
            x_mask,
            g=g_pp.detach() if self.args.detach_pp_input and g_pp is not None else g_pp,
        )
        print(o_pitch.shape, pitch.shape, dr.shape)
        avg_pitch = average_over_durations(pitch, dr.squeeze())
        o_pitch_emb = self.pitch_emb(avg_pitch)
        pitch_loss = torch.sum(torch.sum((o_pitch_emb - o_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
        return pitch_loss

    def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
        # find the alignment path
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)

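The docstring above hinges on average_over_durations collapsing the frame-level F0 track into one value per input token. A toy sketch of the intended shapes (the numeric values and the exact averaging behaviour are assumptions based on the docstring):

import torch
from TTS.tts.utils.helpers import average_over_durations

pitch = torch.tensor([[[100.0, 110.0, 120.0, 200.0, 210.0]]])  # (B=1, 1, T_de=5) frame-level F0
durs = torch.tensor([[3, 2]])                                   # (B=1, T_en=2), durations summing to T_de
avg_pitch = average_over_durations(pitch, durs)                 # (B, 1, T_en)
print(avg_pitch)  # expected ~ [[[110., 205.]]], the mean F0 over each token's frames
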
@@ -1144,6 +1291,7 @@ class Vits(BaseTTS):
        y: torch.tensor,
        y_lengths: torch.tensor,
        waveform: torch.tensor,
        pitch: torch.tensor,
        aux_input={
            "d_vectors": None,
            "speaker_ids": None,

@@ -1284,6 +1432,9 @@ class Vits(BaseTTS):
        outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)

        if self.args.use_pitch:
            pitch_loss = self.forward_pitch_predictor(x, x_mask, pitch, attn.sum(3), g_dp)

        # expand prior
        m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
        logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])

@@ -1359,6 +1510,7 @@ class Vits(BaseTTS):
                "loss_prosody_enc_emo_classifier": l_pros_emotion,
                "loss_text_enc_spk_rev_classifier": l_text_speaker,
                "loss_text_enc_emo_classifier": l_text_emotion,
                "pitch_loss": pitch_loss,
            }
        )
        return outputs

@@ -1636,6 +1788,7 @@ class Vits(BaseTTS):
            emotion_embeddings = batch["emotion_embeddings"]
            emotion_ids = batch["emotion_ids"]
            waveform = batch["waveform"]
            pitch = batch["pitch"]

            # generator pass
            outputs = self.forward(

@@ -1644,6 +1797,7 @@ class Vits(BaseTTS):
                spec,
                spec_lens,
                waveform,
                pitch,
                aux_input={
                    "d_vectors": d_vectors,
                    "speaker_ids": speaker_ids,

@@ -1738,6 +1892,7 @@ class Vits(BaseTTS):
                    scores_disc_mp=scores_disc_mp,
                    feats_disc_mp=feats_disc_mp,
                    feats_disc_zp=feats_disc_zp,
                    pitch_loss=self.model_outputs_cache["pitch_loss"],
                )

            return self.model_outputs_cache, loss_dict

@@ -2082,6 +2237,7 @@ class Vits(BaseTTS):
            # init dataloader
            dataset = VitsDataset(
                model_args=self.args,
                config=config,
                samples=samples,
                batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                min_text_len=config.min_text_len,

@@ -2093,6 +2249,8 @@ class Vits(BaseTTS):
                verbose=verbose,
                tokenizer=self.tokenizer,
                start_by_longest=config.start_by_longest,
                compute_f0=config.get("compute_f0", False),
                f0_cache_path=config.get("f0_cache_path", None),
            )

            # wait for all the DDP processes to be ready

@@ -0,0 +1,87 @@
import glob
import os
import shutil

from trainer import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.vits_config import VitsConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")


config = VitsConfig(
    batch_size=2,
    eval_batch_size=2,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    compute_f0=True,
    f0_cache_path="tests/data/ljspeech/f0_cache/",
    test_sentences=[
        ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],
    ],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60

# activate multi-speaker d-vector mode
config.model_args.use_speaker_embedding = True
config.model_args.use_d_vector_file = False
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.speaker_embedding_channels = 128
config.model_args.d_vector_dim = 128

# prosody embedding
config.model_args.use_prosody_encoder = True
config.model_args.prosody_embedding_dim = 64

# pitch predictor
config.model_args.use_pitch = True
config.model_args.condition_dp_on_speaker = True


config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.name ljspeech_test "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
    "--coqpit.test_delay_epochs 0"
)
run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# Inference using the tts CLI
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav"
continue_speakers_path = os.path.join(continue_path, "speakers.json")


inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --gst_style {style_wav_path}"
run_cli(inference_command)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)