mirror of https://github.com/coqui-ai/TTS.git
Add pitch predictor
This commit is contained in:
parent dcd0d1f6a1
commit d94b8bac02
@@ -115,6 +115,7 @@ class VitsConfig(BaseTTSConfig):
     mel_loss_alpha: float = 45.0
     dur_loss_alpha: float = 1.0
     consistency_loss_alpha: float = 1.0
+    pitch_loss_alpha: float = 5.0

     # data loader params
     return_wav: bool = True
@@ -149,6 +150,10 @@ class VitsConfig(BaseTTSConfig):
     d_vector_file: str = None
     d_vector_dim: int = None

+    # dataset configs
+    compute_f0: bool = False
+    f0_cache_path: str = None
+
     def __post_init__(self):
         for key, val in self.model_args.items():
             if hasattr(self, key):
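Together with `pitch_loss_alpha` above, these are the config-level switches for the new feature. A minimal sketch of how a training config might enable them (values are illustrative and not taken from this commit; `use_pitch` itself is a `VitsArgs` field introduced further down):

from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig(
    compute_f0=True,            # have the dataset compute and cache per-frame F0
    f0_cache_path="f0_cache/",  # placeholder path for the cached .npy pitch files
    pitch_loss_alpha=5.0,       # weight of the new pitch loss term
)
config.model_args.use_pitch = True  # VitsArgs flag added later in this diff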
@@ -122,6 +122,7 @@ class TTSDataset(Dataset):
         self.return_wav = return_wav
         self.compute_f0 = compute_f0
         self.f0_cache_path = f0_cache_path
+        self.precompute_num_workers = precompute_num_workers
         self.min_audio_len = min_audio_len
         self.max_audio_len = max_audio_len
         self.min_text_len = min_text_len
@@ -527,6 +527,7 @@ class AlignTTSLoss(nn.Module):
 class VitsGeneratorLoss(nn.Module):
     def __init__(self, c: Coqpit):
         super().__init__()
+        self.pitch_loss_alpha = c.pitch_loss_alpha
         self.kl_loss_alpha = c.kl_loss_alpha
         self.gen_loss_alpha = c.gen_loss_alpha
         self.feat_loss_alpha = c.feat_loss_alpha
@@ -606,6 +607,7 @@ class VitsGeneratorLoss(nn.Module):
         gt_cons_emb=None,
         syn_cons_emb=None,
         loss_spk_reversal_classifier=None,
+        pitch_loss=None,
     ):
         """
         Shapes:
@@ -644,6 +646,10 @@ class VitsGeneratorLoss(nn.Module):
         if loss_spk_reversal_classifier is not None:
             loss += loss_spk_reversal_classifier
             return_dict["loss_spk_reversal_classifier"] = loss_spk_reversal_classifier
+        if pitch_loss is not None:
+            pitch_loss = pitch_loss * self.pitch_loss_alpha
+            loss += pitch_loss
+            return_dict["pitch_loss"] = pitch_loss

         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
@@ -1,5 +1,7 @@
 import math
 import os
+import numpy as np
+import pyworld as pw
 from dataclasses import dataclass, field, replace
 from itertools import chain
 from typing import Dict, List, Tuple, Union
@@ -16,7 +18,7 @@ from torch.utils.data import DataLoader
 from trainer.trainer_utils import get_optimizer, get_scheduler

 from TTS.tts.configs.shared_configs import CharactersConfig
-from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
+from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, F0Dataset
 from TTS.tts.layers.generic.classifier import ReversalClassifier
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.tacotron.gst_layers import GST
@@ -24,6 +26,8 @@ from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
+from TTS.tts.utils.data import prepare_data
+from TTS.tts.utils.helpers import average_over_durations
 from TTS.tts.utils.emotions import EmotionManager
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
 from TTS.tts.utils.languages import LanguageManager
@@ -185,16 +189,66 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
     spec = amp_to_db(spec)
     return spec


+def compute_f0(x: np.ndarray, sample_rate, hop_length, pitch_fmax=800.0) -> np.ndarray:
+    """Compute pitch (f0) of a waveform using the same parameters used for computing the melspectrogram.
+
+    Args:
+        x (np.ndarray): Waveform.
+
+    Returns:
+        np.ndarray: Pitch.
+    """
+    # align F0 length to the spectrogram length
+    if len(x) % hop_length == 0:
+        x = np.pad(x, (0, hop_length // 2), mode="reflect")
+
+    f0, t = pw.dio(
+        x.astype(np.double),
+        fs=sample_rate,
+        f0_ceil=pitch_fmax,
+        frame_period=1000 * hop_length / sample_rate,
+    )
+    f0 = pw.stonemask(x.astype(np.double), f0, t, sample_rate)
+    return f0
+
+
 ##############################
 # DATASET
 ##############################


+class VITSF0Dataset(F0Dataset):
+    def __init__(self, config, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.config = config
+
+    def compute_or_load(self, wav_file):
+        """Compute pitch and return a numpy array of pitch values."""
+        pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
+        if not os.path.exists(pitch_file):
+            pitch = self._compute_and_save_pitch(wav_file, pitch_file)
+        else:
+            pitch = np.load(pitch_file)
+        return pitch.astype(np.float32)
+
+    def _compute_and_save_pitch(self, wav_file, pitch_file=None):
+        print(wav_file, pitch_file)
+        wav, _ = load_audio(wav_file)
+        pitch = compute_f0(wav.squeeze().numpy(), self.config.audio.sample_rate, self.config.audio.hop_length)
+        if pitch_file:
+            np.save(pitch_file, pitch)
+        return pitch
+
+
 class VitsDataset(TTSDataset):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, config, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.pad_id = self.tokenizer.characters.pad_id
+
+        self.f0_dataset = VITSF0Dataset(
+            config, samples=self.samples, ap=self.ap, cache_path=self.f0_cache_path, precompute_num_workers=self.precompute_num_workers
+        )

     def __getitem__(self, idx):
         item = self.samples[idx]
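Not part of the commit: a rough usage sketch of the `compute_f0` helper above, to show the shape of what ends up in the F0 cache. The audio settings are illustrative and the helper is assumed to be in scope; pyworld emits one F0 value per `frame_period`, so roughly one value per `hop_length` of audio, with zeros on unvoiced frames.

import numpy as np

sample_rate, hop_length = 22050, 256           # illustrative audio settings
x = np.random.uniform(-1, 1, 3 * sample_rate)  # stand-in for a loaded waveform

f0 = compute_f0(x, sample_rate, hop_length)    # helper defined in the hunk above
print(f0.shape, len(x) // hop_length)          # roughly one F0 value per hop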
@@ -205,6 +259,11 @@ class VitsDataset(TTSDataset):

         token_ids = self.get_token_ids(idx, item["text"])

+        # get f0 values
+        f0 = None
+        if self.compute_f0:
+            f0 = self.get_f0(idx)["f0"]
+
         # after phonemization the text length may change
         # this is a shameful 🤭 hack to prevent longer phonemes
         # TODO: find a better fix
@@ -220,6 +279,7 @@ class VitsDataset(TTSDataset):
             "wav_file": wav_filename,
             "speaker_name": item["speaker_name"],
             "language_name": item["language"],
+            "pitch": f0,
         }

     @property
@@ -272,6 +332,14 @@ class VitsDataset(TTSDataset):

             wav = batch["wav"][i]
             wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav)

+        # format F0
+        if self.compute_f0:
+            pitch = prepare_data(batch["pitch"])
+            pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 x T
+        else:
+            pitch = None
+
         return {
             "tokens": token_padded,
@@ -280,6 +348,7 @@ class VitsDataset(TTSDataset):
             "waveform": wav_padded,  # (B x T)
             "waveform_lens": wav_lens,  # (B)
             "waveform_rel_lens": wav_rel_lens,
+            "pitch": pitch,
             "speaker_names": batch["speaker_name"],
             "language_names": batch["language_name"],
             "audio_files": batch["wav_file"],
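For intuition (not part of the diff): the collate step above pads each utterance's F0 track to the longest one in the batch and stacks them into a `B x 1 x T_max` tensor. A self-contained sketch of that padding, without relying on `prepare_data`:

import numpy as np
import torch

# hypothetical per-utterance F0 tracks of different lengths
batch_pitch = [np.ones(120, dtype=np.float32), np.ones(95, dtype=np.float32)]

max_len = max(len(p) for p in batch_pitch)
padded = np.stack([np.pad(p, (0, max_len - len(p))) for p in batch_pitch])  # (B, T_max)
pitch = torch.FloatTensor(padded)[:, None, :].contiguous()                  # (B, 1, T_max)
print(pitch.shape)  # torch.Size([2, 1, 120])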
@@ -437,6 +506,21 @@ class VitsArgs(Coqpit):
         encoder_model_path (str):
             Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".

+        use_pitch (bool):
+            Use the pitch predictor to learn the pitch. Defaults to False.
+
+        pitch_predictor_hidden_channels (int):
+            Number of hidden channels in the pitch predictor. Defaults to 256.
+
+        pitch_predictor_dropout_p (float):
+            Dropout rate for the pitch predictor. Defaults to 0.1.
+
+        pitch_predictor_kernel_size (int):
+            Kernel size of the conv layers in the pitch predictor. Defaults to 3.
+
+        pitch_embedding_kernel_size (int):
+            Kernel size of the conv layer that projects averaged pitch values into embeddings. Defaults to 3.
+
         condition_dp_on_speaker (bool):
             Condition the duration predictor on the speaker embedding. Defaults to True.

@@ -509,6 +593,14 @@ class VitsArgs(Coqpit):
     prosody_encoder_num_heads: int = 1
     prosody_encoder_num_tokens: int = 5

+    # Pitch predictor
+    use_pitch: bool = False
+    pitch_predictor_hidden_channels: int = 256
+    pitch_predictor_kernel_size: int = 3
+    pitch_predictor_dropout_p: float = 0.1
+    pitch_embedding_kernel_size: int = 3
+    detach_pp_input: bool = False
+
     detach_dp_input: bool = True
     use_language_embedding: bool = False
     embedded_language_dim: int = 4
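Not part of the diff: with these defaults in place, the predictor can be tuned entirely from a config. A hedged sketch of overriding them through `VitsArgs` (values illustrative):

from TTS.tts.models.vits import VitsArgs

model_args = VitsArgs(
    use_pitch=True,
    pitch_predictor_hidden_channels=256,  # width of the predictor's conv stack
    pitch_predictor_kernel_size=3,
    pitch_predictor_dropout_p=0.1,
    pitch_embedding_kernel_size=3,        # kernel of the Conv1d that embeds averaged pitch
    detach_pp_input=False,                # if True, the conditioning embedding is detached before the predictor
)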
@@ -639,6 +731,21 @@ class Vits(BaseTTS):
                 language_emb_dim=self.embedded_language_dim,
             )

+        if self.args.use_pitch:
+            self.pitch_predictor = DurationPredictor(
+                self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
+                self.args.pitch_predictor_hidden_channels,
+                self.args.pitch_predictor_kernel_size,
+                self.args.pitch_predictor_dropout_p,
+                cond_channels=dp_cond_embedding_dim,
+            )
+            self.pitch_emb = nn.Conv1d(
+                1,
+                self.args.hidden_channels,
+                kernel_size=self.args.pitch_embedding_kernel_size,
+                padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
+            )
+
         if self.args.use_prosody_encoder:
             self.prosody_encoder = GST(
                 num_mel=self.args.hidden_channels,
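In the hunks shown here, `pitch_emb` is only used inside the pitch loss below; in FastPitch-style models the same kind of layer is typically used to add the embedded (averaged) pitch back onto the encoder output before decoding. A hedged sketch of that pattern, with illustrative shapes not taken from this commit:

import torch
from torch import nn

B, C, T_en = 2, 192, 13                                # batch, hidden channels, token count
o_en = torch.randn(B, C, T_en)                         # stand-in for the text encoder output
avg_pitch = torch.randn(B, 1, T_en)                    # one (averaged) pitch value per token

pitch_emb = nn.Conv1d(1, C, kernel_size=3, padding=1)  # same layout as self.pitch_emb above
o_en = o_en + pitch_emb(avg_pitch)                     # FastPitch-style pitch conditioning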
@@ -878,6 +985,47 @@ class Vits(BaseTTS):
         g = speaker_ids if speaker_ids is not None else d_vectors
         return g

+    def forward_pitch_predictor(
+        self,
+        o_en: torch.FloatTensor,
+        x_mask: torch.IntTensor,
+        pitch: torch.FloatTensor = None,
+        dr: torch.IntTensor = None,
+        g_pp: torch.IntTensor = None,
+    ) -> torch.FloatTensor:
+        """Pitch predictor forward pass.
+
+        1. Predict pitch from the encoder outputs.
+        2. In training, compute average pitch values for each input character from the ground truth pitch values.
+        3. Embed the average pitch values.
+
+        Args:
+            o_en (torch.FloatTensor): Encoder output.
+            x_mask (torch.IntTensor): Input sequence mask.
+            pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
+            dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
+            g_pp (torch.IntTensor, optional): Speaker/prosody embedding to condition the pitch predictor. Defaults to None.
+
+        Returns:
+            torch.FloatTensor: Pitch loss.
+
+        Shapes:
+            - o_en: :math:`(B, C, T_{en})`
+            - x_mask: :math:`(B, 1, T_{en})`
+            - pitch: :math:`(B, 1, T_{de})`
+            - dr: :math:`(B, T_{en})`
+        """
+        o_pitch = self.pitch_predictor(
+            o_en,
+            x_mask,
+            g=g_pp.detach() if self.args.detach_pp_input and g_pp is not None else g_pp,
+        )
+        print(o_pitch.shape, pitch.shape, dr.shape)
+        avg_pitch = average_over_durations(pitch, dr.squeeze())
+        o_pitch_emb = self.pitch_emb(avg_pitch)
+        pitch_loss = torch.sum(torch.sum((o_pitch_emb - o_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
+        return pitch_loss
+
     def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
         # find the alignment path
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
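Not part of the diff: `average_over_durations` (imported near the top of this change) collapses the frame-level F0 track into one value per input token using the token durations, which is what `avg_pitch` above holds. A simplified numpy sketch of that reduction for the documented shapes (`pitch` is `(B, 1, T_de)`, `dr` is `(B, T_en)`); the library helper also treats unvoiced (zero) frames specially, which this sketch glosses over:

import numpy as np

pitch = np.array([[[100.0, 110.0, 0.0, 220.0, 230.0, 240.0]]])  # (1, 1, 6) frame-level F0
dr = np.array([[2, 1, 3]])                                      # (1, 3) frames per token

avg, start = [], 0
for end in np.cumsum(dr[0]):
    avg.append(float(pitch[0, 0, start:end].mean()))  # average the frames covered by each token
    start = end
print(avg)  # [105.0, 0.0, 230.0]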
@@ -920,6 +1068,7 @@ class Vits(BaseTTS):
         y: torch.tensor,
         y_lengths: torch.tensor,
         waveform: torch.tensor,
+        pitch: torch.tensor,
         aux_input={
             "d_vectors": None,
             "speaker_ids": None,
@@ -1011,6 +1160,9 @@ class Vits(BaseTTS):

         outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)

+        if self.args.use_pitch:
+            pitch_loss = self.forward_pitch_predictor(x, x_mask, pitch, attn.sum(3), g_dp)
+
         # expand prior
         m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
         logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
@@ -1063,6 +1215,7 @@ class Vits(BaseTTS):
                 "syn_cons_emb": syn_cons_emb,
                 "slice_ids": slice_ids,
                 "loss_spk_reversal_classifier": l_pros_speaker,
+                "pitch_loss": pitch_loss,
             }
         )
         return outputs
@@ -1281,6 +1434,7 @@ class Vits(BaseTTS):
             emotion_embeddings = batch["emotion_embeddings"]
             emotion_ids = batch["emotion_ids"]
             waveform = batch["waveform"]
+            pitch = batch["pitch"]

             # generator pass
             outputs = self.forward(
@@ -1289,6 +1443,7 @@ class Vits(BaseTTS):
                 spec,
                 spec_lens,
                 waveform,
+                pitch,
                 aux_input={
                     "d_vectors": d_vectors,
                     "speaker_ids": speaker_ids,
@@ -1358,6 +1513,7 @@ class Vits(BaseTTS):
                     gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
                     syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
                     loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"],
+                    pitch_loss=self.model_outputs_cache["pitch_loss"],
                 )

             return self.model_outputs_cache, loss_dict
@@ -1613,6 +1769,7 @@ class Vits(BaseTTS):
         else:
             # init dataloader
             dataset = VitsDataset(
+                config=config,
                 samples=samples,
                 # batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                 min_text_len=config.min_text_len,
@@ -1624,6 +1781,8 @@ class Vits(BaseTTS):
                 verbose=verbose,
                 tokenizer=self.tokenizer,
                 start_by_longest=config.start_by_longest,
+                compute_f0=config.get("compute_f0", False),
+                f0_cache_path=config.get("f0_cache_path", None),
             )

             # wait all the DDP process to be ready
@@ -0,0 +1,87 @@
+import glob
+import os
+import shutil
+
+from trainer import get_last_checkpoint
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.tts.configs.vits_config import VitsConfig
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+
+config = VitsConfig(
+    batch_size=2,
+    eval_batch_size=2,
+    num_loader_workers=0,
+    num_eval_loader_workers=0,
+    text_cleaner="english_cleaners",
+    use_phonemes=True,
+    phoneme_language="en-us",
+    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1,
+    print_step=1,
+    print_eval=True,
+    compute_f0=True,
+    f0_cache_path="tests/data/ljspeech/f0_cache/",
+    test_sentences=[
+        ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],
+    ],
+)
+# set audio config
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+
+# active multispeaker d-vec mode
+config.model_args.use_speaker_embedding = True
+config.model_args.use_d_vector_file = False
+config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
+config.model_args.speaker_embedding_channels = 128
+config.model_args.d_vector_dim = 128
+
+# prosody embedding
+config.model_args.use_prosody_encoder = True
+config.model_args.prosody_embedding_dim = 64
+
+# pitch predictor
+config.model_args.use_pitch = True
+config.model_args.condition_dp_on_speaker = True
+
+
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+    f"--coqpit.output_path {output_path} "
+    "--coqpit.datasets.0.name ljspeech_test "
+    "--coqpit.datasets.0.meta_file_train metadata.csv "
+    "--coqpit.datasets.0.meta_file_val metadata.csv "
+    "--coqpit.datasets.0.path tests/data/ljspeech "
+    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
+    "--coqpit.test_delay_epochs 0"
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# Inference using TTS API
+continue_config_path = os.path.join(continue_path, "config.json")
+continue_restore_path, _ = get_last_checkpoint(continue_path)
+out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
+speaker_id = "ljspeech-1"
+style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav"
+continue_speakers_path = os.path.join(continue_path, "speakers.json")
+
+
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --gst_style {style_wav_path}"
+run_cli(inference_command)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)