Add pitch predictor

Edresson Casanova 2022-05-16 21:53:49 +00:00
parent dcd0d1f6a1
commit d94b8bac02
5 changed files with 260 additions and 2 deletions

View File

@@ -115,6 +115,7 @@ class VitsConfig(BaseTTSConfig):
    mel_loss_alpha: float = 45.0
    dur_loss_alpha: float = 1.0
    consistency_loss_alpha: float = 1.0
    pitch_loss_alpha: float = 5.0

    # data loader params
    return_wav: bool = True
@@ -149,6 +150,10 @@ class VitsConfig(BaseTTSConfig):
    d_vector_file: str = None
    d_vector_dim: int = None

    # dataset configs
    compute_f0: bool = False
    f0_cache_path: str = None

    def __post_init__(self):
        for key, val in self.model_args.items():
            if hasattr(self, key):
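
Not part of the diff: a minimal sketch of how these new config fields might be set together (field names are taken from this commit; paths and values are illustrative).

from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig(
    compute_f0=True,                  # extract F0 in the data loader
    f0_cache_path="f0_cache/",        # where precomputed pitch files are cached
    pitch_loss_alpha=5.0,             # weight of the new pitch loss term
)
config.model_args.use_pitch = True    # enable the pitch predictor in the model (see VitsArgs below)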

View File

@@ -122,6 +122,7 @@ class TTSDataset(Dataset):
        self.return_wav = return_wav
        self.compute_f0 = compute_f0
        self.f0_cache_path = f0_cache_path
        self.precompute_num_workers = precompute_num_workers
        self.min_audio_len = min_audio_len
        self.max_audio_len = max_audio_len
        self.min_text_len = min_text_len

View File

@@ -527,6 +527,7 @@ class AlignTTSLoss(nn.Module):
class VitsGeneratorLoss(nn.Module):
    def __init__(self, c: Coqpit):
        super().__init__()
        self.pitch_loss_alpha = c.pitch_loss_alpha
        self.kl_loss_alpha = c.kl_loss_alpha
        self.gen_loss_alpha = c.gen_loss_alpha
        self.feat_loss_alpha = c.feat_loss_alpha
@@ -606,6 +607,7 @@ class VitsGeneratorLoss(nn.Module):
        gt_cons_emb=None,
        syn_cons_emb=None,
        loss_spk_reversal_classifier=None,
        pitch_loss=None,
    ):
        """
        Shapes:
@@ -644,6 +646,10 @@ class VitsGeneratorLoss(nn.Module):
        if loss_spk_reversal_classifier is not None:
            loss += loss_spk_reversal_classifier
            return_dict["loss_spk_reversal_classifier"] = loss_spk_reversal_classifier
        if pitch_loss is not None:
            pitch_loss = pitch_loss * self.pitch_loss_alpha
            loss += pitch_loss
            return_dict["pitch_loss"] = pitch_loss

        # pass losses to the dict
        return_dict["loss_gen"] = loss_gen

View File

@@ -1,5 +1,7 @@
import math
import os
import numpy as np
import pyworld as pw
from dataclasses import dataclass, field, replace
from itertools import chain
from typing import Dict, List, Tuple, Union
@@ -16,7 +18,7 @@ from torch.utils.data import DataLoader
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, F0Dataset
from TTS.tts.layers.generic.classifier import ReversalClassifier
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.tacotron.gst_layers import GST
@@ -24,6 +26,8 @@ from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import prepare_data
from TTS.tts.utils.helpers import average_over_durations
from TTS.tts.utils.emotions import EmotionManager
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.languages import LanguageManager
@@ -185,16 +189,66 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
    spec = amp_to_db(spec)
    return spec

def compute_f0(x: np.ndarray, sample_rate, hop_length, pitch_fmax=800.0) -> np.ndarray:
    """Compute pitch (f0) of a waveform using the same parameters used for computing the melspectrogram.

    Args:
        x (np.ndarray): Waveform.
        sample_rate (int): Audio sample rate.
        hop_length (int): Hop length used for the melspectrogram frames.
        pitch_fmax (float): Maximum allowed pitch value. Defaults to 800.0.

    Returns:
        np.ndarray: Pitch.
    """
    # align F0 length to the spectrogram length
    if len(x) % hop_length == 0:
        x = np.pad(x, (0, hop_length // 2), mode="reflect")

    f0, t = pw.dio(
        x.astype(np.double),
        fs=sample_rate,
        f0_ceil=pitch_fmax,
        frame_period=1000 * hop_length / sample_rate,
    )
    f0 = pw.stonemask(x.astype(np.double), f0, t, sample_rate)
    return f0
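
Not part of the diff: a minimal usage sketch for the compute_f0 function above (assumes numpy and pyworld are installed and the function is in scope); it only checks that the number of F0 frames roughly matches the expected spectrogram frame count.

import numpy as np

sample_rate, hop_length = 22050, 256
wav = np.random.randn(sample_rate).astype(np.float32)   # 1 second of dummy audio

f0 = compute_f0(wav, sample_rate, hop_length)
expected_frames = len(wav) // hop_length + 1
print(f0.shape, expected_frames)                         # lengths should agree within a frame or so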

##############################
# DATASET
##############################

class VITSF0Dataset(F0Dataset):
    def __init__(self, config, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config

    def compute_or_load(self, wav_file):
        """Compute pitch and return a numpy array of pitch values."""
        pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
        if not os.path.exists(pitch_file):
            pitch = self._compute_and_save_pitch(wav_file, pitch_file)
        else:
            pitch = np.load(pitch_file)
        return pitch.astype(np.float32)

    def _compute_and_save_pitch(self, wav_file, pitch_file=None):
        wav, _ = load_audio(wav_file)
        pitch = compute_f0(wav.squeeze().numpy(), self.config.audio.sample_rate, self.config.audio.hop_length)
        if pitch_file:
            np.save(pitch_file, pitch)
        return pitch

class VitsDataset(TTSDataset):
    def __init__(self, config, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pad_id = self.tokenizer.characters.pad_id
        self.f0_dataset = VITSF0Dataset(
            config, samples=self.samples, ap=self.ap, cache_path=self.f0_cache_path, precompute_num_workers=self.precompute_num_workers
        )

    def __getitem__(self, idx):
        item = self.samples[idx]
@@ -205,6 +259,11 @@ class VitsDataset(TTSDataset):
        token_ids = self.get_token_ids(idx, item["text"])

        # get f0 values
        f0 = None
        if self.compute_f0:
            f0 = self.get_f0(idx)["f0"]

        # after phonemization the text length may change
        # this is a shameful 🤭 hack to prevent longer phonemes
        # TODO: find a better fix
@@ -220,6 +279,7 @@ class VitsDataset(TTSDataset):
            "wav_file": wav_filename,
            "speaker_name": item["speaker_name"],
            "language_name": item["language"],
            "pitch": f0,
        }

    @property
@@ -272,6 +332,14 @@ class VitsDataset(TTSDataset):
            wav = batch["wav"][i]
            wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav)

        # format F0
        if self.compute_f0:
            pitch = prepare_data(batch["pitch"])
            pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 x T
        else:
            pitch = None

        return {
            "tokens": token_padded,
@@ -280,6 +348,7 @@ class VitsDataset(TTSDataset):
            "waveform": wav_padded,  # (B x T)
            "waveform_lens": wav_lens,  # (B)
            "waveform_rel_lens": wav_rel_lens,
            "pitch": pitch,
            "speaker_names": batch["speaker_name"],
            "language_names": batch["language_name"],
            "audio_files": batch["wav_file"],
@@ -437,6 +506,21 @@ class VitsArgs(Coqpit):
        encoder_model_path (str):
            Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".

        use_pitch (bool):
            Use pitch predictor to learn the pitch. Defaults to False.

        pitch_predictor_hidden_channels (int):
            Number of hidden channels in the pitch predictor. Defaults to 256.

        pitch_predictor_dropout_p (float):
            Dropout rate for the pitch predictor. Defaults to 0.1.

        pitch_predictor_kernel_size (int):
            Kernel size of conv layers in the pitch predictor. Defaults to 3.

        pitch_embedding_kernel_size (int):
            Kernel size of the projection layer in the pitch predictor. Defaults to 3.

        condition_dp_on_speaker (bool):
            Condition the duration predictor on the speaker embedding. Defaults to True.
@@ -509,6 +593,14 @@ class VitsArgs(Coqpit):
    prosody_encoder_num_heads: int = 1
    prosody_encoder_num_tokens: int = 5

    # Pitch predictor
    use_pitch: bool = False
    pitch_predictor_hidden_channels: int = 256
    pitch_predictor_kernel_size: int = 3
    pitch_predictor_dropout_p: float = 0.1
    pitch_embedding_kernel_size: int = 3
    detach_pp_input: bool = False

    detach_dp_input: bool = True
    use_language_embedding: bool = False
    embedded_language_dim: int = 4
@@ -639,6 +731,21 @@ class Vits(BaseTTS):
                language_emb_dim=self.embedded_language_dim,
            )

        if self.args.use_pitch:
            self.pitch_predictor = DurationPredictor(
                self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
                self.args.pitch_predictor_hidden_channels,
                self.args.pitch_predictor_kernel_size,
                self.args.pitch_predictor_dropout_p,
                cond_channels=dp_cond_embedding_dim,
            )
            self.pitch_emb = nn.Conv1d(
                1,
                self.args.hidden_channels,
                kernel_size=self.args.pitch_embedding_kernel_size,
                padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
            )

        if self.args.use_prosody_encoder:
            self.prosody_encoder = GST(
                num_mel=self.args.hidden_channels,
@@ -878,6 +985,47 @@ class Vits(BaseTTS):
        g = speaker_ids if speaker_ids is not None else d_vectors
        return g
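
Not part of the diff: a shape sketch for the pitch_predictor and pitch_emb modules constructed in __init__ above (sizes are illustrative; the real ones come from VitsArgs, and no emotion/prosody/conditioning channels are used here). The predictor maps encoder features to one value per input token, and pitch_emb lifts a 1-channel pitch track back to hidden_channels.

import torch
from torch import nn
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor

hidden, T_en = 192, 37
pitch_predictor = DurationPredictor(hidden, 256, 3, 0.1)
pitch_emb = nn.Conv1d(1, hidden, kernel_size=3, padding=1)

o_en = torch.randn(2, hidden, T_en)              # encoder output (B, C, T_en)
x_mask = torch.ones(2, 1, T_en)                  # token mask (B, 1, T_en)
print(pitch_predictor(o_en, x_mask).shape)       # torch.Size([2, 1, 37])
print(pitch_emb(torch.zeros(2, 1, T_en)).shape)  # torch.Size([2, 192, 37])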

    def forward_pitch_predictor(
        self,
        o_en: torch.FloatTensor,
        x_mask: torch.IntTensor,
        pitch: torch.FloatTensor = None,
        dr: torch.IntTensor = None,
        g_pp: torch.IntTensor = None,
    ) -> torch.FloatTensor:
        """Pitch predictor forward pass.

        1. Predict pitch from the encoder outputs.
        2. In training, compute average pitch values for each input character from the ground truth pitch values.
        3. Embed the average pitch values and compute the pitch loss against the prediction.

        Args:
            o_en (torch.FloatTensor): Encoder output.
            x_mask (torch.IntTensor): Input sequence mask.
            pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
            dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
            g_pp (torch.IntTensor, optional): Speaker/prosody embedding to condition the pitch predictor. Defaults to None.

        Returns:
            torch.FloatTensor: Pitch loss.

        Shapes:
            - o_en: :math:`(B, C, T_{en})`
            - x_mask: :math:`(B, 1, T_{en})`
            - pitch: :math:`(B, 1, T_{de})`
            - dr: :math:`(B, T_{en})`
        """
        o_pitch = self.pitch_predictor(
            o_en,
            x_mask,
            g=g_pp.detach() if self.args.detach_pp_input and g_pp is not None else g_pp,
        )
        avg_pitch = average_over_durations(pitch, dr.squeeze(1))  # (B, 1, T_de) -> (B, 1, T_en)
        o_pitch_emb = self.pitch_emb(avg_pitch)
        pitch_loss = torch.sum(torch.sum((o_pitch_emb - o_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
        return pitch_loss
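
Not part of the diff: a small sketch of what average_over_durations does in step 2 above (shapes follow the docstring; the values are illustrative).

import torch
from TTS.tts.utils.helpers import average_over_durations

pitch = torch.tensor([[[100.0, 110.0, 120.0, 200.0, 210.0]]])  # (B=1, 1, T_de=5) frame-level pitch
durs = torch.tensor([[3, 2]])                                   # (B=1, T_en=2) token durations, 3 + 2 = 5
print(average_over_durations(pitch, durs))
# tensor([[[110., 205.]]]) -> one averaged pitch value per input token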

    def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
        # find the alignment path
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
@@ -920,6 +1068,7 @@ class Vits(BaseTTS):
        y: torch.tensor,
        y_lengths: torch.tensor,
        waveform: torch.tensor,
        pitch: torch.tensor,
        aux_input={
            "d_vectors": None,
            "speaker_ids": None,
@@ -1011,6 +1160,9 @@ class Vits(BaseTTS):
        outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)

        pitch_loss = None
        if self.args.use_pitch:
            pitch_loss = self.forward_pitch_predictor(x, x_mask, pitch, attn.sum(3), g_dp)

        # expand prior
        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
        logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
@@ -1063,6 +1215,7 @@ class Vits(BaseTTS):
                "syn_cons_emb": syn_cons_emb,
                "slice_ids": slice_ids,
                "loss_spk_reversal_classifier": l_pros_speaker,
                "pitch_loss": pitch_loss,
            }
        )
        return outputs
@@ -1281,6 +1434,7 @@ class Vits(BaseTTS):
        emotion_embeddings = batch["emotion_embeddings"]
        emotion_ids = batch["emotion_ids"]
        waveform = batch["waveform"]
        pitch = batch["pitch"]

        # generator pass
        outputs = self.forward(
@@ -1289,6 +1443,7 @@ class Vits(BaseTTS):
            spec,
            spec_lens,
            waveform,
            pitch,
            aux_input={
                "d_vectors": d_vectors,
                "speaker_ids": speaker_ids,
@@ -1358,6 +1513,7 @@ class Vits(BaseTTS):
            gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
            syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
            loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"],
            pitch_loss=self.model_outputs_cache["pitch_loss"],
        )
        return self.model_outputs_cache, loss_dict
@@ -1613,6 +1769,7 @@ class Vits(BaseTTS):
        else:
            # init dataloader
            dataset = VitsDataset(
                config=config,
                samples=samples,
                # batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                min_text_len=config.min_text_len,
@@ -1624,6 +1781,8 @@ class Vits(BaseTTS):
                verbose=verbose,
                tokenizer=self.tokenizer,
                start_by_longest=config.start_by_longest,
                compute_f0=config.get("compute_f0", False),
                f0_cache_path=config.get("f0_cache_path", None),
            )

            # wait all the DDP process to be ready

View File

@@ -0,0 +1,87 @@
import glob
import os
import shutil

from trainer import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.vits_config import VitsConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

config = VitsConfig(
    batch_size=2,
    eval_batch_size=2,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    compute_f0=True,
    f0_cache_path="tests/data/ljspeech/f0_cache/",
    test_sentences=[
        ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],
    ],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60

# activate multi-speaker mode with a learned speaker embedding (no d-vector file)
config.model_args.use_speaker_embedding = True
config.model_args.use_d_vector_file = False
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.speaker_embedding_channels = 128
config.model_args.d_vector_dim = 128

# prosody embedding
config.model_args.use_prosody_encoder = True
config.model_args.prosody_embedding_dim = 64

# pitch predictor
config.model_args.use_pitch = True

config.model_args.condition_dp_on_speaker = True

config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.name ljspeech_test "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
    "--coqpit.test_delay_epochs 0"
)
run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# Inference using TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav"
continue_speakers_path = os.path.join(continue_path, "speakers.json")

inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --gst_style {style_wav_path}"
run_cli(inference_command)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)