Add pitch predictor

This commit is contained in:
Edresson Casanova 2022-05-16 21:53:49 +00:00
parent 5859e6474c
commit 6a573065f4
6 changed files with 261 additions and 2 deletions

View File

@ -74,6 +74,7 @@ def extract_aligments(
np.save(align_file_path, alignment)
def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined
global meta_data, speaker_manager

View File

@ -121,6 +121,7 @@ class VitsConfig(BaseTTSConfig):
disc_latent_loss_alpha: float = 5.0
gen_latent_loss_alpha: float = 5.0
feat_latent_loss_alpha: float = 108.0
pitch_loss_alpha: float = 5.0
# data loader params
return_wav: bool = True
@ -155,6 +156,10 @@ class VitsConfig(BaseTTSConfig):
d_vector_file: str = None
d_vector_dim: int = None
# dataset configs
compute_f0: bool = False
f0_cache_path: str = None
def __post_init__(self):
for key, val in self.model_args.items():
if hasattr(self, key):

View File

@ -122,6 +122,7 @@ class TTSDataset(Dataset):
self.return_wav = return_wav
self.compute_f0 = compute_f0
self.f0_cache_path = f0_cache_path
self.precompute_num_workers = precompute_num_workers
self.min_audio_len = min_audio_len
self.max_audio_len = max_audio_len
self.min_text_len = min_text_len

View File

@ -585,6 +585,7 @@ class AlignTTSLoss(nn.Module):
class VitsGeneratorLoss(nn.Module):
def __init__(self, c: Coqpit):
super().__init__()
self.pitch_loss_alpha = c.pitch_loss_alpha
self.kl_loss_alpha = c.kl_loss_alpha
self.gen_loss_alpha = c.gen_loss_alpha
self.feat_loss_alpha = c.feat_loss_alpha
@ -680,6 +681,7 @@ class VitsGeneratorLoss(nn.Module):
scores_disc_mp=None,
feats_disc_mp=None,
feats_disc_zp=None,
pitch_loss=None,
):
"""
Shapes:
@ -755,6 +757,11 @@ class VitsGeneratorLoss(nn.Module):
loss += kl_vae_loss
return_dict["loss_kl_vae"] = kl_vae_loss
if pitch_loss is not None:
pitch_loss = pitch_loss * self.pitch_loss_alpha
loss += pitch_loss
return_dict["pitch_loss"] = pitch_loss
# pass losses to the dict
return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl

View File

@ -1,6 +1,7 @@
import math
import os
import numpy as np
import pyworld as pw
from dataclasses import dataclass, field, replace
from itertools import chain
from typing import Dict, List, Tuple, Union
@ -17,7 +18,7 @@ from torch.utils.data import DataLoader
from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, F0Dataset
from TTS.tts.layers.generic.classifier import ReversalClassifier
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
@ -26,6 +27,8 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE, ResNetProsodyEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import prepare_data
from TTS.tts.utils.helpers import average_over_durations
from TTS.tts.utils.emotions import EmotionManager
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.languages import LanguageManager
@ -204,18 +207,68 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
spec = amp_to_db(spec)
return spec
def compute_f0(x: np.ndarray, sample_rate, hop_length, pitch_fmax=800.0) -> np.ndarray:
"""Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
Args:
x (np.ndarray): Waveform.
Returns:
np.ndarray: Pitch.
"""
# align F0 length to the spectrogram length
if len(x) % hop_length == 0:
x = np.pad(x, (0, hop_length // 2), mode="reflect")
f0, t = pw.dio(
x.astype(np.double),
fs=sample_rate,
f0_ceil=pitch_fmax,
frame_period=1000 * hop_length / sample_rate,
)
f0 = pw.stonemask(x.astype(np.double), f0, t, sample_rate)
return f0
##############################
# DATASET
##############################
class VITSF0Dataset(F0Dataset):
def __init__(self, config, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config = config
def compute_or_load(self, wav_file):
"""
compute pitch and return a numpy array of pitch values
"""
pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
if not os.path.exists(pitch_file):
pitch = self._compute_and_save_pitch(wav_file, pitch_file)
else:
pitch = np.load(pitch_file)
return pitch.astype(np.float32)
def _compute_and_save_pitch(self, wav_file, pitch_file=None):
print(wav_file, pitch_file)
wav, _ = load_audio(wav_file)
pitch = compute_f0(wav.squeeze().numpy(), self.config.audio.sample_rate, self.config.audio.hop_length)
if pitch_file:
np.save(pitch_file, pitch)
return pitch
class VitsDataset(TTSDataset):
def __init__(self, model_args, *args, **kwargs):
def __init__(self, model_args, config, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pad_id = self.tokenizer.characters.pad_id
self.model_args = model_args
self.f0_dataset = VITSF0Dataset(config,
samples=self.samples, ap=self.ap, cache_path=self.f0_cache_path, precompute_num_workers=self.precompute_num_workers
)
def __getitem__(self, idx):
item = self.samples[idx]
raw_text = item["text"]
@ -229,6 +282,11 @@ class VitsDataset(TTSDataset):
token_ids = self.get_token_ids(idx, item["text"])
# get f0 values
f0 = None
if self.compute_f0:
f0 = self.get_f0(idx)["f0"]
# after phonemization the text length may change
# this is a shameful 🤭 hack to prevent longer phonemes
# TODO: find a better fix
@ -244,6 +302,7 @@ class VitsDataset(TTSDataset):
"wav_file": wav_filename,
"speaker_name": item["speaker_name"],
"language_name": item["language"],
"pitch": f0,
}
@property
@ -297,6 +356,14 @@ class VitsDataset(TTSDataset):
wav = batch["wav"][i]
wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav)
# format F0
if self.compute_f0:
pitch = prepare_data(batch["pitch"])
pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT
else:
pitch = None
return {
"tokens": token_padded,
"token_lens": token_lens,
@ -304,6 +371,7 @@ class VitsDataset(TTSDataset):
"waveform": wav_padded, # (B x T)
"waveform_lens": wav_lens, # (B)
"waveform_rel_lens": wav_rel_lens,
"pitch": pitch,
"speaker_names": batch["speaker_name"],
"language_names": batch["language_name"],
"audio_files": batch["wav_file"],
@ -464,6 +532,21 @@ class VitsArgs(Coqpit):
encoder_model_path (str):
Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".
use_pitch (bool):
Use pitch predictor to learn the pitch. Defaults to False.
pitch_predictor_hidden_channels (int):
Number of hidden channels in the pitch predictor. Defaults to 256.
pitch_predictor_dropout_p (float):
Dropout rate for the pitch predictor. Defaults to 0.1.
pitch_predictor_kernel_size (int):
Kernel size of conv layers in the pitch predictor. Defaults to 3.
pitch_embedding_kernel_size (int):
Kernel size of the projection layer in the pitch predictor. Defaults to 3.
condition_dp_on_speaker (bool):
Condition the duration predictor on the speaker embedding. Defaults to True.
@ -571,6 +654,14 @@ class VitsArgs(Coqpit):
use_latent_discriminator: bool = False
use_avg_feature_on_latent_discriminator: bool = False
# Pitch predictor
use_pitch: bool = False
pitch_predictor_hidden_channels: int = 256
pitch_predictor_kernel_size: int = 3
pitch_predictor_dropout_p: float = 0.1
pitch_embedding_kernel_size: int = 3
detach_pp_input: bool = False
detach_dp_input: bool = True
use_language_embedding: bool = False
embedded_language_dim: int = 4
@ -717,6 +808,21 @@ class Vits(BaseTTS):
language_emb_dim=self.embedded_language_dim,
)
if self.args.use_pitch:
self.pitch_predictor = DurationPredictor(
self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
self.args.pitch_predictor_hidden_channels,
self.args.pitch_predictor_kernel_size,
self.args.pitch_predictor_dropout_p,
cond_channels=dp_cond_embedding_dim,
)
self.pitch_emb = nn.Conv1d(
1,
self.args.hidden_channels,
kernel_size=self.args.pitch_embedding_kernel_size,
padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
)
if self.args.use_prosody_encoder:
if self.args.use_pros_enc_input_as_pros_emb:
self.prosody_embedding_squeezer = nn.Linear(
@ -1085,6 +1191,47 @@ class Vits(BaseTTS):
g = speaker_ids if speaker_ids is not None else d_vectors
return g
def forward_pitch_predictor(
self,
o_en: torch.FloatTensor,
x_mask: torch.IntTensor,
pitch: torch.FloatTensor = None,
dr: torch.IntTensor = None,
g_pp: torch.IntTensor = None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""Pitch predictor forward pass.
1. Predict pitch from encoder outputs.
2. In training - Compute average pitch values for each input character from the ground truth pitch values.
3. Embed average pitch values.
Args:
o_en (torch.FloatTensor): Encoder output.
x_mask (torch.IntTensor): Input sequence mask.
pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
g_pp (torch.IntTensor, optional): Speaker/prosody embedding to condition the pithc predictor. Defaults to None.
Returns:
Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction.
Shapes:
- o_en: :math:`(B, C, T_{en})`
- x_mask: :math:`(B, 1, T_{en})`
- pitch: :math:`(B, 1, T_{de})`
- dr: :math:`(B, T_{en})`
"""
o_pitch = self.pitch_predictor(
o_en,
x_mask,
g=g_pp.detach() if self.args.detach_pp_input and g_pp is not None else g_pp
)
print(o_pitch.shape, pitch.shape, dr.shape)
avg_pitch = average_over_durations(pitch, dr.squeeze())
o_pitch_emb = self.pitch_emb(avg_pitch)
pitch_loss = torch.sum(torch.sum((o_pitch_emb - o_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
return pitch_loss
def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
# find the alignment path
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
@ -1144,6 +1291,7 @@ class Vits(BaseTTS):
y: torch.tensor,
y_lengths: torch.tensor,
waveform: torch.tensor,
pitch: torch.tensor,
aux_input={
"d_vectors": None,
"speaker_ids": None,
@ -1284,6 +1432,9 @@ class Vits(BaseTTS):
outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)
if self.args.use_pitch:
pitch_loss = self.forward_pitch_predictor(x, x_mask, pitch, attn.sum(3), g_dp)
# expand prior
m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
@ -1359,6 +1510,7 @@ class Vits(BaseTTS):
"loss_prosody_enc_emo_classifier": l_pros_emotion,
"loss_text_enc_spk_rev_classifier": l_text_speaker,
"loss_text_enc_emo_classifier": l_text_emotion,
"pitch_loss": pitch_loss,
}
)
return outputs
@ -1636,6 +1788,7 @@ class Vits(BaseTTS):
emotion_embeddings = batch["emotion_embeddings"]
emotion_ids = batch["emotion_ids"]
waveform = batch["waveform"]
pitch = batch["pitch"]
# generator pass
outputs = self.forward(
@ -1644,6 +1797,7 @@ class Vits(BaseTTS):
spec,
spec_lens,
waveform,
pitch,
aux_input={
"d_vectors": d_vectors,
"speaker_ids": speaker_ids,
@ -1738,6 +1892,7 @@ class Vits(BaseTTS):
scores_disc_mp=scores_disc_mp,
feats_disc_mp=feats_disc_mp,
feats_disc_zp=feats_disc_zp,
pitch_loss=self.model_outputs_cache["pitch_loss"],
)
return self.model_outputs_cache, loss_dict
@ -2082,6 +2237,7 @@ class Vits(BaseTTS):
# init dataloader
dataset = VitsDataset(
model_args=self.args,
config=config,
samples=samples,
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
min_text_len=config.min_text_len,
@ -2093,6 +2249,8 @@ class Vits(BaseTTS):
verbose=verbose,
tokenizer=self.tokenizer,
start_by_longest=config.start_by_longest,
compute_f0=config.get("compute_f0", False),
f0_cache_path=config.get("f0_cache_path", None),
)
# wait all the DDP process to be ready

View File

@ -0,0 +1,87 @@
import glob
import os
import shutil
from trainer import get_last_checkpoint
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.vits_config import VitsConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = VitsConfig(
batch_size=2,
eval_batch_size=2,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
compute_f0=True,
f0_cache_path="tests/data/ljspeech/f0_cache/",
test_sentences=[
["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],
],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60
# active multispeaker d-vec mode
config.model_args.use_speaker_embedding = True
config.model_args.use_d_vector_file = False
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.speaker_embedding_channels = 128
config.model_args.d_vector_dim = 128
# prosody embedding
config.model_args.use_prosody_encoder = True
config.model_args.prosody_embedding_dim = 64
# pitch predictor
config.model_args.use_pitch = True
config.model_args.condition_dp_on_speaker = True
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.name ljspeech_test "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# Inference using TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav"
continue_speakers_path = os.path.join(continue_path, "speakers.json")
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --gst_style {style_wav_path}"
run_cli(inference_command)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)