mirror of https://github.com/coqui-ai/TTS.git
Add voice conversion fine tuning mode
This commit is contained in:
parent
2be38aad3f
commit
cd7639ca70
|
@ -0,0 +1,63 @@
|
||||||
|
"""Find all the unique characters in a dataset"""
|
||||||
|
import argparse
|
||||||
|
from argparse import RawTextHelpFormatter
|
||||||
|
|
||||||
|
from TTS.config import load_config
|
||||||
|
from TTS.tts.datasets import load_meta_data
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
import multiprocessing
|
||||||
|
from TTS.tts.utils.text import text2phone
|
||||||
|
from tqdm.contrib.concurrent import process_map
|
||||||
|
|
||||||
|
def compute_phonemes(item):
|
||||||
|
try:
|
||||||
|
text = item[0]
|
||||||
|
language = item[-1]
|
||||||
|
ph = text2phone(text, language, use_espeak_phonemes=c.use_espeak_phonemes).split("|")
|
||||||
|
except:
|
||||||
|
return []
|
||||||
|
return list(set(ph))
|
||||||
|
|
||||||
|
def main():
|
||||||
|
global c
|
||||||
|
# pylint: disable=bad-option-value
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="""Find all the unique characters or phonemes in a dataset.\n\n"""
|
||||||
|
"""
|
||||||
|
Example runs:
|
||||||
|
|
||||||
|
python TTS/bin/find_unique_chars.py --config_path config.json
|
||||||
|
""",
|
||||||
|
formatter_class=RawTextHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
c = load_config(args.config_path)
|
||||||
|
|
||||||
|
# load all datasets
|
||||||
|
train_items, eval_items = load_meta_data(c.datasets, eval_split=True)
|
||||||
|
items = train_items + eval_items
|
||||||
|
print("Num items:", len(items))
|
||||||
|
# items = items[:1000]
|
||||||
|
|
||||||
|
phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
|
||||||
|
phones = []
|
||||||
|
for ph in phonemes:
|
||||||
|
phones.extend(ph)
|
||||||
|
phones = set(phones)
|
||||||
|
lower_phones = filter(lambda c: c.islower(), phones)
|
||||||
|
phones_force_lower = [c.lower() for c in phones]
|
||||||
|
phones_force_lower = set(phones_force_lower)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
print(f" > Number of unique phonemes: {len(phones)}")
|
||||||
|
print(f" > Unique phonemes: {''.join(sorted(phones))}")
|
||||||
|
print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}")
|
||||||
|
print(f" > Unique all forced to lower phonemes: {''.join(sorted(phones_force_lower))}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -599,7 +599,7 @@ class VitsGeneratorLoss(nn.Module):
|
||||||
feats_disc_fake,
|
feats_disc_fake,
|
||||||
feats_disc_real,
|
feats_disc_real,
|
||||||
loss_duration,
|
loss_duration,
|
||||||
fine_tuning_mode=False,
|
fine_tuning_mode=0,
|
||||||
use_speaker_encoder_as_loss=False,
|
use_speaker_encoder_as_loss=False,
|
||||||
gt_spk_emb=None,
|
gt_spk_emb=None,
|
||||||
syn_spk_emb=None
|
syn_spk_emb=None
|
||||||
|
|
|
@ -149,6 +149,28 @@ class VitsArgs(Coqpit):
|
||||||
|
|
||||||
detach_dp_input (bool):
|
detach_dp_input (bool):
|
||||||
Detach duration predictor's input from the network for stopping the gradients. Defaults to True.
|
Detach duration predictor's input from the network for stopping the gradients. Defaults to True.
|
||||||
|
|
||||||
|
use_language_embedding (bool):
|
||||||
|
Enable/Disable language embedding for multilingual models. Defaults to False.
|
||||||
|
|
||||||
|
embedded_language_dim (int):
|
||||||
|
Number of language embedding channels. Defaults to 4.
|
||||||
|
|
||||||
|
num_languages (int):
|
||||||
|
Number of languages for the language embedding layer. Defaults to 0.
|
||||||
|
|
||||||
|
use_speaker_encoder_as_loss (bool):
|
||||||
|
|
||||||
|
|
||||||
|
use_speaker_encoder_as_loss: bool = False
|
||||||
|
speaker_encoder_config_path: str = ""
|
||||||
|
speaker_encoder_model_path: str = ""
|
||||||
|
|
||||||
|
fine_tuning_mode (int):
|
||||||
|
Fine tuning only the vocoder part of the model, while the rest will be frozen. Defaults to 0.
|
||||||
|
Mode 0: disabled;
|
||||||
|
Mode 1: uses the distribution predicted by the encoder and It's recommended for TTS;
|
||||||
|
Mode 2: uses the distribution predicted by the encoder and It's recommended for voice conversion.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
num_chars: int = 100
|
num_chars: int = 100
|
||||||
|
@ -194,10 +216,10 @@ class VitsArgs(Coqpit):
|
||||||
use_language_embedding: bool = False
|
use_language_embedding: bool = False
|
||||||
embedded_language_dim: int = 4
|
embedded_language_dim: int = 4
|
||||||
num_languages: int = 0
|
num_languages: int = 0
|
||||||
fine_tuning_mode: bool = False
|
|
||||||
use_speaker_encoder_as_loss: bool = False
|
use_speaker_encoder_as_loss: bool = False
|
||||||
speaker_encoder_config_path: str = ""
|
speaker_encoder_config_path: str = ""
|
||||||
speaker_encoder_model_path: str = ""
|
speaker_encoder_model_path: str = ""
|
||||||
|
fine_tuning_mode: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -565,6 +587,7 @@ class Vits(BaseTTS):
|
||||||
y: torch.tensor,
|
y: torch.tensor,
|
||||||
y_lengths: torch.tensor,
|
y_lengths: torch.tensor,
|
||||||
aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
|
aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
|
||||||
|
waveform=None,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""Forward pass of the model.
|
"""Forward pass of the model.
|
||||||
|
|
||||||
|
@ -621,22 +644,50 @@ class Vits(BaseTTS):
|
||||||
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
||||||
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
|
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
|
||||||
|
|
||||||
# get the z after inverse decoder
|
# mode 1: like SC-GlowTTS paper; mode 2: recommended for voice conversion
|
||||||
# ToDo: test if using m_p the result is better (In the SC-GlowTTS paper we used mp instead z_p)
|
if self.args.fine_tuning_mode == 1:
|
||||||
z_f_pred = self.flow(z_p, y_mask, g=g, reverse=True)
|
z_ft = m_p
|
||||||
|
elif self.args.fine_tuning_mode == 2:
|
||||||
|
z_ft = z_p
|
||||||
|
else:
|
||||||
|
raise RuntimeError(" [!] Invalid Fine Tunning Mode !")
|
||||||
|
|
||||||
|
# inverse decoder and get the output
|
||||||
|
z_f_pred = self.flow(z_ft, y_mask, g=g, reverse=True)
|
||||||
z_slice, slice_ids = rand_segment(z_f_pred, y_lengths, self.spec_segment_size)
|
z_slice, slice_ids = rand_segment(z_f_pred, y_lengths, self.spec_segment_size)
|
||||||
|
|
||||||
o = self.waveform_decoder(z_slice, g=g)
|
o = self.waveform_decoder(z_slice, g=g)
|
||||||
|
|
||||||
|
wav_seg = segment(
|
||||||
|
waveform.transpose(1, 2),
|
||||||
|
slice_ids * self.config.audio.hop_length,
|
||||||
|
self.args.spec_segment_size * self.config.audio.hop_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.use_speaker_encoder_as_loss:
|
||||||
|
# concate generated and GT waveforms
|
||||||
|
wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
|
||||||
|
pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
|
||||||
|
|
||||||
|
# split generated and GT speaker embeddings
|
||||||
|
gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
|
||||||
|
else:
|
||||||
|
gt_spk_emb, syn_spk_emb = None, None
|
||||||
|
|
||||||
outputs.update(
|
outputs.update(
|
||||||
{
|
{
|
||||||
"model_outputs": o,
|
"model_outputs": o,
|
||||||
"alignments": attn.squeeze(1),
|
"alignments": attn.squeeze(1),
|
||||||
|
"loss_duration": 0.0,
|
||||||
"z": z,
|
"z": z,
|
||||||
"z_p": z_p,
|
"z_p": z_p,
|
||||||
"m_p": m_p,
|
"m_p": m_p,
|
||||||
"logs_p": logs_p,
|
"logs_p": logs_p,
|
||||||
"m_q": m_q,
|
"m_q": m_q,
|
||||||
"logs_q": logs_q,
|
"logs_q": logs_q,
|
||||||
|
"waveform_seg": wav_seg,
|
||||||
|
"gt_spk_emb": gt_spk_emb,
|
||||||
|
"syn_spk_emb": syn_spk_emb
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
Loading…
Reference in New Issue