mirror of https://github.com/coqui-ai/TTS.git
Add deterministic decoder on VITS
commit 18dee1190e (parent 3165c55fae)
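The hunks below touch (presumably) the VITS config (`VitsConfig`), the generator loss (`VitsGeneratorLoss`), and the `Vits` model itself. The commit adds an optional deterministic "z decoder" that regresses the posterior latent `z` directly from the alignment-expanded text features plus speaker/emotion/prosody conditioning, trained with a masked L1 loss and used at inference in place of inverting the flow; it also threads a `pitch_transform` callable through the pitch predictor and stores the dataset's pitch statistics. A minimal, self-contained sketch of the core idea (not the repo's `Decoder`/FFTransformer wiring; names and shapes here are assumptions):

# Minimal sketch of the commit's idea; the real model uses the repo's feed-forward
# Decoder ("fftransformer") and richer conditioning, so treat everything here as illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyZDecoder(nn.Module):
    """Deterministic latent predictor: expanded text features (+ conditioning) -> z."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x_expanded: torch.Tensor, y_mask: torch.Tensor) -> torch.Tensor:
        return self.net(x_expanded) * y_mask


hidden = 192                                # assumed VITS hidden_channels
dec = ToyZDecoder(hidden, hidden)
x_expanded = torch.randn(2, hidden, 50)     # token features expanded to frame rate via attn
y_mask = torch.ones(2, 1, 50)               # frame-level mask
z_posterior = torch.randn(2, hidden, 50)    # target: z from the posterior encoder

z_pred = dec(x_expanded, y_mask)            # training: regress onto the posterior latent
z_decoder_loss = F.l1_loss(z_pred * y_mask, z_posterior)
# inference: z_pred feeds the waveform decoder instead of z = flow(z_p, reverse=True)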
@@ -122,6 +122,7 @@ class VitsConfig(BaseTTSConfig):
     gen_latent_loss_alpha: float = 5.0
     feat_latent_loss_alpha: float = 108.0
     pitch_loss_alpha: float = 5.0
+    z_decoder_loss_alpha: float = 45.0
 
     # data loader params
     return_wav: bool = True
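As a rough usage sketch (the value is the default introduced above), the new weight simply scales the z-decoder reconstruction term before it enters the generator loss (see the VitsGeneratorLoss hunks below):

from TTS.tts.configs.vits_config import VitsConfig  # config module path as in the upstream repo

config = VitsConfig()
config.z_decoder_loss_alpha = 45.0  # weight on the deterministic decoder's L1 reconstruction term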
@@ -597,6 +597,7 @@ class VitsGeneratorLoss(nn.Module):
         self.prosody_encoder_kl_loss_alpha = c.prosody_encoder_kl_loss_alpha
         self.feat_latent_loss_alpha = c.feat_latent_loss_alpha
         self.gen_latent_loss_alpha = c.gen_latent_loss_alpha
+        self.z_decoder_loss_alpha = c.z_decoder_loss_alpha
 
         self.stft = TorchSTFT(
             c.audio.fft_size,
@@ -682,6 +683,7 @@ class VitsGeneratorLoss(nn.Module):
         feats_disc_mp=None,
         feats_disc_zp=None,
         pitch_loss=None,
+        z_decoder_loss=None,
     ):
         """
         Shapes:
@@ -762,6 +764,11 @@ class VitsGeneratorLoss(nn.Module):
             loss += pitch_loss
             return_dict["pitch_loss"] = pitch_loss
 
+        if z_decoder_loss is not None:
+            z_decoder_loss = z_decoder_loss * self.z_decoder_loss_alpha
+            loss += z_decoder_loss
+            return_dict["z_decoder_loss"] = z_decoder_loss
+
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
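The term is only added when the model actually produced it, and the scaled value is what gets logged. A small hedged arithmetic example of the weighting (numbers invented):

z_decoder_l1 = 0.02                  # example raw masked L1 between predicted and posterior z
z_decoder_loss_alpha = 45.0          # default from VitsConfig above
z_decoder_loss = z_decoder_l1 * z_decoder_loss_alpha
print(z_decoder_loss)                # 0.9 is added to the total loss and reported as "z_decoder_loss"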
@@ -4,7 +4,7 @@ import numpy as np
 import pyworld as pw
 from dataclasses import dataclass, field, replace
 from itertools import chain
-from typing import Dict, List, Tuple, Union
+from typing import Callable, Dict, List, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -22,6 +22,7 @@ from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, F0Dataset
 from TTS.tts.layers.generic.classifier import ReversalClassifier
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
+from TTS.tts.layers.feed_forward.decoder import Decoder as ZDecoder
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE, ResNetProsodyEncoder
@@ -687,6 +688,14 @@ class VitsArgs(Coqpit):
     use_precomputed_alignments: bool = False
     alignments_cache_path: str = ""
     pitch_embedding_dim: int = 0
+    pitch_mean: float = 0.0
+    pitch_std: float = 0.0
+
+    use_z_decoder: bool = False
+    z_decoder_type: str = "fftransformer"
+    z_decoder_params: dict = field(
+        default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
+    )
 
     detach_dp_input: bool = True
     use_language_embedding: bool = False
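The new `VitsArgs` fields map directly onto `config.model_args`; a hedged configuration sketch mirroring the defaults introduced above (dataset and trainer wiring omitted):

from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
config.model_args.use_z_decoder = True
config.model_args.z_decoder_type = "fftransformer"  # presumably dispatched by the imported feed-forward Decoder
config.model_args.z_decoder_params = {
    "hidden_channels_ffn": 1024,
    "num_heads": 1,
    "num_layers": 6,
    "dropout_p": 0.1,
}
# pitch_mean / pitch_std stay at 0.0 here; get_data_loader fills them from the F0 dataset (see below).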
@@ -838,6 +847,21 @@ class Vits(BaseTTS):
                 language_emb_dim=self.embedded_language_dim,
             )
 
+        if self.args.use_z_decoder:
+            dec_extra_inp_dim = self.cond_embedding_dim
+            if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
+                dec_extra_inp_dim += self.args.emotion_embedding_dim
+
+            if self.args.use_prosody_encoder:
+                dec_extra_inp_dim += self.args.prosody_embedding_dim
+
+            self.z_decoder = ZDecoder(
+                self.args.hidden_channels,
+                self.args.hidden_channels + dec_extra_inp_dim,
+                self.args.z_decoder_type,
+                self.args.z_decoder_params,
+            )
+
         if self.args.use_pitch:
             if self.args.use_pitch_on_enc_input:
                 self.pitch_predictor_vocab_emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
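The decoder's second argument appears to be its input width: the text-encoder hidden size plus whatever conditioning gets concatenated on. A hedged arithmetic sketch (all sizes are assumptions):

hidden_channels = 192        # assumed VITS hidden size
cond_embedding_dim = 512     # assumed speaker d-vector size (self.cond_embedding_dim)
emotion_embedding_dim = 64   # counted only when emotion embeddings are enabled
prosody_embedding_dim = 0    # prosody encoder disabled in this example

dec_extra_inp_dim = cond_embedding_dim + emotion_embedding_dim + prosody_embedding_dim
z_decoder_in_channels = hidden_channels + dec_extra_inp_dim
print(z_decoder_in_channels)  # 768 = 192 + 512 + 64 + 0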
@@ -1232,6 +1256,7 @@ class Vits(BaseTTS):
         pitch: torch.FloatTensor = None,
         dr: torch.IntTensor = None,
         g_pp: torch.IntTensor = None,
+        pitch_transform: Callable=None,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
         """Pitch predictor forward pass.
 
@@ -1245,6 +1270,7 @@ class Vits(BaseTTS):
             pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
             dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
             g_pp (torch.IntTensor, optional): Speaker/prosody embedding to condition the pithc predictor. Defaults to None.
+            pitch_transform (Callable, optional): Pitch transform function. Defaults to None.
 
         Returns:
             Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction.
@@ -1267,6 +1293,9 @@ class Vits(BaseTTS):
             g=g_pp.detach() if self.args.detach_pp_input and g_pp is not None else g_pp
         )
 
+        if pitch_transform is not None:
+            pred_avg_pitch = pitch_transform(pred_avg_pitch, x_mask.sum(dim=(1,2)), self.args.pitch_mean, self.args.pitch_std)
+
         pitch_loss = None
         pred_avg_pitch_emb = None
         gt_avg_pitch_emb = None
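`pitch_transform` lets a caller reshape the predicted per-token pitch. Its signature is inferred from the call site above; the transform body below is purely illustrative (the +0.5 shift presumes the predictor works in normalized pitch space, which is what the `pitch_mean`/`pitch_std` plumbing suggests):

import torch


def raise_pitch(pred_avg_pitch: torch.Tensor, lengths: torch.Tensor, pitch_mean: float, pitch_std: float) -> torch.Tensor:
    # Shift every token's (normalized) pitch up by half a standard deviation of the training data.
    return pred_avg_pitch + 0.5


pred = torch.zeros(1, 1, 12)      # dummy per-token averaged pitch, normalized
lens = torch.tensor([12.0])       # stands in for x_mask.sum(dim=(1, 2)) from the call site
print(raise_pitch(pred, lens, 120.0, 35.0).mean().item())

At inference such a callable is passed straight through `Vits.inference(..., pitch_transform=raise_pitch)`, as the hunks below wire up.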
@@ -1500,6 +1529,30 @@ class Vits(BaseTTS):
         # expand prior
         m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
         logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
 
+        z_decoder_loss = None
+        if self.args.use_z_decoder:
+            x_expanded = torch.einsum("klmn, kjm -> kjn", [attn, x])
+            # prepare the conditional emb
+            g_dec = g
+            if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
+                if g_dec is None:
+                    g_dec = eg
+                else:
+                    g_dec = torch.cat([g_dec, eg], dim=1) # [b, h1+h2, 1]
+            if self.args.use_prosody_encoder:
+                if g_dec is None:
+                    g_dec = pros_emb
+                else:
+                    g_dec = torch.cat([g_dec, pros_emb], dim=1) # [b, h1+h2, 1]
+
+            if g_dec is not None:
+                x_expanded = torch.cat((x_expanded, g_dec.expand(-1, -1, x_expanded.size(2))), dim=1)
+
+            # decoder pass
+            z_decoder = self.z_decoder(x_expanded, y_mask, g=g_dec)
+            z_decoder_loss = torch.nn.functional.l1_loss(z_decoder * y_mask, z)
+
         if self.args.use_noise_scale_predictor:
             nsp_input = torch.transpose(m_p_expanded, 1, -1)
             if self.args.use_prosody_encoder and pros_emb is not None:
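The `torch.einsum("klmn, kjm -> kjn", [attn, x])` call spreads token-level features over output frames using the monotonic alignment, the same trick already applied to `m_p` and `logs_p` just above. A small self-contained check of what that expansion does (shapes assumed):

import torch

B, C, T_text, T_frames = 1, 4, 3, 6
attn = torch.zeros(B, 1, T_text, T_frames)   # hard alignment: [batch, 1, tokens, frames]
attn[0, 0, 0, 0:2] = 1                       # token 0 -> frames 0-1
attn[0, 0, 1, 2:3] = 1                       # token 1 -> frame 2
attn[0, 0, 2, 3:6] = 1                       # token 2 -> frames 3-5

x = torch.randn(B, C, T_text)                # token-level features
x_expanded = torch.einsum("klmn, kjm -> kjn", [attn, x])

assert x_expanded.shape == (B, C, T_frames)
assert torch.allclose(x_expanded[0, :, 0], x[0, :, 0])   # frame 0 carries token 0's features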
@@ -1575,6 +1628,7 @@ class Vits(BaseTTS):
                 "loss_text_enc_spk_rev_classifier": l_text_speaker,
                 "loss_text_enc_emo_classifier": l_text_emotion,
                 "pitch_loss": pitch_loss,
+                "z_decoder_loss": z_decoder_loss,
             }
         )
         return outputs
@@ -1598,6 +1652,7 @@ class Vits(BaseTTS):
             "emotion_ids": None,
             "style_feature": None,
         },
+        pitch_transform=None,
     ): # pylint: disable=dangerous-default-value
         """
         Note:
@@ -1682,7 +1737,7 @@ class Vits(BaseTTS):
 
         pred_avg_pitch_emb = None
         if self.args.use_pitch and self.args.use_pitch_on_enc_input:
-            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g_dp)
+            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g_dp, pitch_transform=pitch_transform)
 
         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
@@ -1720,7 +1775,7 @@ class Vits(BaseTTS):
         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
 
         if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
-            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp)
+            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp, pitch_transform=pitch_transform)
             m_p = m_p + pred_avg_pitch_emb
 
 
@@ -1745,7 +1800,28 @@ class Vits(BaseTTS):
         else:
             z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
 
-        z = self.flow(z_p, y_mask, g=g, reverse=True)
+        if self.args.use_z_decoder:
+            x_expanded = torch.matmul(attn.transpose(1, 2), x.transpose(1, 2)).transpose(1, 2)
+            # prepare the conditional emb
+            g_dec = g
+            if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
+                if g_dec is None:
+                    g_dec = eg
+                else:
+                    g_dec = torch.cat([g_dec, eg], dim=1) # [b, h1+h2, 1]
+            if self.args.use_prosody_encoder:
+                if g_dec is None:
+                    g_dec = pros_emb
+                else:
+                    g_dec = torch.cat([g_dec, pros_emb], dim=1) # [b, h1+h2, 1]
+
+            if g_dec is not None:
+                x_expanded = torch.cat((x_expanded, g_dec.expand(-1, -1, x_expanded.size(2))), dim=1)
+
+            # decoder pass
+            z = self.z_decoder(x_expanded, y_mask, g=g_dec)
+        else:
+            z = self.flow(z_p, y_mask, g=g, reverse=True)
 
         # upsampling if needed
         z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask)
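At inference the expansion uses a matmul on the (presumably 3-D) alignment produced by `generate_path` rather than the training-time einsum; both compute the same token-to-frame expansion, and the predicted `z` then replaces the flow inversion entirely. A hedged equivalence check (shapes assumed):

import torch

B, C, T_text, T_frames = 2, 8, 5, 11
attn = torch.rand(B, T_text, T_frames)       # alignment without the singleton channel dim
x = torch.randn(B, C, T_text)

via_matmul = torch.matmul(attn.transpose(1, 2), x.transpose(1, 2)).transpose(1, 2)
via_einsum = torch.einsum("kmn, kjm -> kjn", [attn, x])

assert via_matmul.shape == (B, C, T_frames)
assert torch.allclose(via_matmul, via_einsum, atol=1e-6)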
@@ -1892,7 +1968,7 @@ class Vits(BaseTTS):
                 outputs["model_outputs"].detach(),
                 outputs["waveform_seg"],
                 outputs["m_p_unexpanded"].detach(),
-                outputs["z_p_avg"].detach(),
+                outputs["z_p_avg"].detach() if outputs["z_p_avg"] is not None else None,
             )
 
             # compute loss
@@ -1940,7 +2016,7 @@ class Vits(BaseTTS):
                 self.model_outputs_cache["model_outputs"],
                 self.model_outputs_cache["waveform_seg"],
                 self.model_outputs_cache["m_p_unexpanded"],
-                self.model_outputs_cache["z_p_avg"].detach(),
+                self.model_outputs_cache["z_p_avg"].detach() if self.model_outputs_cache["z_p_avg"] is not None else None,
             )
 
             # compute losses
@@ -1970,6 +2046,7 @@ class Vits(BaseTTS):
                 feats_disc_mp=feats_disc_mp,
                 feats_disc_zp=feats_disc_zp,
                 pitch_loss=self.model_outputs_cache["pitch_loss"],
+                z_decoder_loss=self.model_outputs_cache["z_decoder_loss"],
             )
 
             return self.model_outputs_cache, loss_dict
@@ -2337,6 +2414,10 @@ class Vits(BaseTTS):
             # sort input sequences from short to long
             dataset.preprocess_samples()
 
+            if self.args.use_pitch:
+                self.args.pitch_mean = dataset.f0_dataset.mean
+                self.args.pitch_std = dataset.f0_dataset.std
+
             # get samplers
             sampler = self.get_sampler(config, dataset, num_gpus)
 
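Capturing the F0 dataset's mean and std on `self.args` keeps them with the model config, presumably so a `pitch_transform` can reason in Hz rather than in normalized units. A trivial hedged example of that mapping (numbers invented):

def denormalize_pitch(norm_pitch: float, pitch_mean: float, pitch_std: float) -> float:
    # Map a normalized pitch value back to Hz using the stored dataset statistics.
    return norm_pitch * pitch_std + pitch_mean


print(denormalize_pitch(0.5, 120.0, 35.0))  # 137.5 Hz with these assumed statistics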
@@ -31,6 +31,8 @@ config = VitsConfig(
 config.audio.do_trim_silence = True
 config.audio.trim_db = 60
 
+config.model_args.use_z_decoder = True
+
 # active multispeaker d-vec mode
 config.model_args.use_d_vector_file = True
 config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
@@ -79,6 +79,7 @@ out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
 speaker_id = "ljspeech-1"
 continue_speakers_path = os.path.join(continue_path, "speakers.json")
 
+
 inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} "
 run_cli(inference_command)
 