mirror of https://github.com/coqui-ai/TTS.git

Rename vars in VITS

commit e5a9902e85
parent 8f21991a84

@@ -29,6 +29,7 @@ from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
+from TTS.utils.generic_utils import count_parameters
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
 
@@ -108,6 +109,7 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
         pad_mode="reflect",
         normalized=False,
         onesided=True,
+        return_complex=True,
     )
 
     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
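
Note: this hunk only adds the return_complex argument; everything else is context. With return_complex=True, torch.stft returns a complex tensor, so the (real, imag) pair has to be made explicit again (for example via torch.view_as_real) before the pow(2).sum(-1) magnitude reduction in the context line can apply. A minimal self-contained sketch of such a computation; the squeeze and hann-window details are assumptions, not part of this diff:

    import torch

    def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
        # y: [B, 1, T] waveform; returns a linear magnitude spectrogram
        # of shape [B, n_fft // 2 + 1, frames].
        spec = torch.stft(
            y.squeeze(1),
            n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=torch.hann_window(win_length).to(y.device),  # assumed window
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        # Complex STFT output: expose the (real, imag) channel before the
        # same magnitude reduction used in the context line above.
        spec = torch.view_as_real(spec)
        return torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
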
@@ -483,7 +485,7 @@ class VitsArgs(Coqpit):
     inference_noise_scale_dp: float = 1.0
     max_inference_len: int = None
     init_discriminator: bool = True
-    use_spectral_norm_disriminator: bool = False
+    use_spectral_norm_discriminator: bool = False
     use_speaker_embedding: bool = False
     num_speakers: int = 0
     speakers_file: str = None
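
Note: the renamed flag is a plain dataclass field on VitsArgs (a Coqpit config), so the fix is caller-visible: any config that set the misspelled key must switch to the corrected spelling. A small usage sketch, assuming the usual TTS.tts.models.vits import path:

    from TTS.tts.models.vits import VitsArgs

    # After this commit only the corrected spelling exists on the config.
    args = VitsArgs(use_spectral_norm_discriminator=True)
    assert args.use_spectral_norm_discriminator
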
@@ -625,7 +627,8 @@ class Vits(BaseTTS):
         )
 
         if self.args.init_discriminator:
-            self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_disriminator)
+            self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_discriminator)
+
 
     def init_multispeaker(self, config: Coqpit):
         """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
@@ -780,21 +783,21 @@ class Vits(BaseTTS):
         logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
         logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
         logp = logp2 + logp3 + logp1 + logp4
-        attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()  # [b, 1, t, t']
+        mas_attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()  # [b, 1, t, t']
 
         # duration predictor
-        attn_durations = attn.sum(3)
+        mas_attn_durations = mas_attn.sum(3)
         if self.args.use_sdp:
             loss_duration = self.duration_predictor(
                 x.detach() if self.args.detach_dp_input else x,
                 x_mask,
-                attn_durations,
+                mas_attn_durations,
                 g=g.detach() if self.args.detach_dp_input and g is not None else g,
                 lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
             )
             loss_duration = loss_duration / torch.sum(x_mask)
         else:
-            attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
+            attn_log_durations = torch.log(mas_attn_durations + 1e-6) * x_mask
             log_durations = self.duration_predictor(
                 x.detach() if self.args.detach_dp_input else x,
                 x_mask,
@@ -803,7 +806,7 @@ class Vits(BaseTTS):
             )
             loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
         outputs["loss_duration"] = loss_duration
-        return outputs, attn
+        return outputs, mas_attn
 
     def forward(  # pylint: disable=dangerous-default-value
         self,
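
Note: the logp terms are the diagonal-Gaussian log-likelihood log N(z_p; m_p, exp(logs_p)), summed over channels and expanded into four pieces so each reduces to a channel sum or an einsum; maximum_path then extracts the best monotonic path through the resulting [b, t_text, t_spec] score matrix. A toy numerical check of that factorization (illustrative shapes; the logp1/logp2 lines sit just above this hunk and are reproduced here in their standard VITS form):

    import math
    import torch

    # Toy shapes: batch k=1, channels l=2, text frames m=3, spec frames n=5.
    m_p = torch.randn(1, 2, 3)     # prior mean per channel and text frame
    logs_p = torch.randn(1, 2, 3)  # prior log-std per channel and text frame
    z_p = torch.randn(1, 2, 5)     # flow output per channel and spec frame

    o_scale = torch.exp(-2 * logs_p)
    logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)
    logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
    logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
    logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1)
    logp = logp2 + logp3 + logp1 + logp4  # [k, m, n] pairwise log-likelihoods

    # Entry (m, n) is the log-density of spec frame n under text frame m's prior.
    direct = torch.distributions.Normal(
        m_p[0, :, 0], torch.exp(logs_p[0, :, 0])
    ).log_prob(z_p[0, :, 0]).sum()
    assert torch.allclose(logp[0, 0, 0], direct, atol=1e-5)
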
@@ -871,11 +874,11 @@ class Vits(BaseTTS):
         z_p = self.flow(z, y_mask, g=g)
 
         # duration predictor
-        outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
+        outputs, mas_attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
 
         # expand prior
-        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
-        logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
+        m_p = torch.einsum("klmn, kjm -> kjn", [mas_attn, m_p])
+        logs_p = torch.einsum("klmn, kjm -> kjn", [mas_attn, logs_p])
 
         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
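
Note: the renamed expand-prior einsums broadcast each text frame's prior statistics onto the spectrogram frames it is aligned to: with mas_attn of shape [b, 1, t_text, t_spec] and m_p of shape [b, c, t_text], "klmn, kjm -> kjn" sums over the text axis and yields [b, c, t_spec]. A shape-level sketch with a hand-built monotonic alignment (toy sizes, illustrative names):

    import torch

    mas_attn = torch.zeros(1, 1, 3, 5)  # [b, 1, t_text, t_spec], hard 0/1 alignment
    mas_attn[0, 0, torch.tensor([0, 0, 1, 2, 2]), torch.arange(5)] = 1.0
    m_p = torch.randn(1, 4, 3)          # [b, c, t_text] prior means

    expanded = torch.einsum("klmn, kjm -> kjn", [mas_attn, m_p])
    print(expanded.shape)  # torch.Size([1, 4, 5]): one prior mean per spec frame
    assert torch.allclose(expanded[0, :, 1], m_p[0, :, 0])  # spec frames 0-1 map to text frame 0
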
@@ -907,7 +910,7 @@ class Vits(BaseTTS):
         outputs.update(
             {
                 "model_outputs": o,
-                "alignments": attn.squeeze(1),
+                "alignments": mas_attn.squeeze(1),
                 "m_p": m_p,
                 "logs_p": logs_p,
                 "z": z,
@@ -1152,7 +1155,8 @@ class Vits(BaseTTS):
 
         raise ValueError(" [!] Unexpected `optimizer_idx`.")
 
-    def _log(self, ap, batch, outputs, name_prefix="train"):  # pylint: disable=unused-argument,no-self-use
+    @staticmethod
+    def _log(ap, outputs, name_prefix="train"):
         y_hat = outputs[1]["model_outputs"]
         y = outputs[1]["waveform_seg"]
         figures = plot_results(y_hat, y, ap, name_prefix)
@@ -1185,7 +1189,7 @@ class Vits(BaseTTS):
         Returns:
             Tuple[Dict, np.ndarray]: training plots and output waveform.
         """
-        figures, audios = self._log(self.ap, batch, outputs, "train")
+        figures, audios = self._log(self.ap, outputs, "train")
         logger.train_figures(steps, figures)
         logger.train_audios(steps, audios, self.ap.sample_rate)
 
@@ -1194,7 +1198,7 @@ class Vits(BaseTTS):
         return self.train_step(batch, criterion, optimizer_idx)
 
     def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
-        figures, audios = self._log(self.ap, batch, outputs, "eval")
+        figures, audios = self._log(self.ap, outputs, "eval")
         logger.eval_figures(steps, figures)
         logger.eval_audios(steps, audios, self.ap.sample_rate)
 
@@ -1401,9 +1405,7 @@ class Vits(BaseTTS):
         Returns:
             List: optimizers.
         """
-        # select generator parameters
         optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
-
         gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc."))
         optimizer1 = get_optimizer(
             self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
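
Note: the comment removed here sat above the discriminator optimizer, which it did not describe; optimizer0 covers only the disc submodule, while optimizer1 takes every remaining parameter through the name-prefix filter. A minimal standalone sketch of the same two-optimizer split (plain torch.optim and a toy module, not repo code):

    import torch

    class ToyGAN(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.gen = torch.nn.Linear(8, 8)   # stands in for the VITS generator
            self.disc = torch.nn.Linear(8, 1)  # stands in for VitsDiscriminator

    model = ToyGAN()
    # Discriminator optimizer: only the disc. subtree.
    opt_d = torch.optim.AdamW(model.disc.parameters(), lr=2e-4)
    # Generator optimizer: everything else, mirroring the
    # `not k.startswith("disc.")` filter in the hunk above.
    gen_params = (p for k, p in model.named_parameters() if not k.startswith("disc."))
    opt_g = torch.optim.AdamW(gen_params, lr=2e-4)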