mirror of https://github.com/coqui-ai/TTS.git
Rename vars in VITS
This commit is contained in:
parent
775a6ab6ee
commit
760f045aaa
|
@ -29,6 +29,7 @@ from TTS.tts.utils.synthesis import synthesis
|
|||
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment
|
||||
from TTS.utils.generic_utils import count_parameters
|
||||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
||||
|
@ -125,6 +126,7 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
|
|||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True,
|
||||
)
|
||||
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||
|
@ -521,7 +523,7 @@ class VitsArgs(Coqpit):
|
|||
inference_noise_scale_dp: float = 1.0
|
||||
max_inference_len: int = None
|
||||
init_discriminator: bool = True
|
||||
use_spectral_norm_disriminator: bool = False
|
||||
use_spectral_norm_discriminator: bool = False
|
||||
use_speaker_embedding: bool = False
|
||||
num_speakers: int = 0
|
||||
speakers_file: str = None
|
||||
|
@ -857,21 +859,21 @@ class Vits(BaseTTS):
|
|||
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
|
||||
logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp = logp2 + logp3 + logp1 + logp4
|
||||
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t']
|
||||
mas_attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t']
|
||||
|
||||
# duration predictor
|
||||
attn_durations = attn.sum(3)
|
||||
mas_attn_durations = mas_attn.sum(3)
|
||||
if self.args.use_sdp:
|
||||
loss_duration = self.duration_predictor(
|
||||
x.detach() if self.args.detach_dp_input else x,
|
||||
x_mask,
|
||||
attn_durations,
|
||||
mas_attn_durations,
|
||||
g=g.detach() if self.args.detach_dp_input and g is not None else g,
|
||||
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
|
||||
)
|
||||
loss_duration = loss_duration / torch.sum(x_mask)
|
||||
else:
|
||||
attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
|
||||
attn_log_durations = torch.log(mas_attn_durations + 1e-6) * x_mask
|
||||
log_durations = self.duration_predictor(
|
||||
x.detach() if self.args.detach_dp_input else x,
|
||||
x_mask,
|
||||
|
@ -880,7 +882,7 @@ class Vits(BaseTTS):
|
|||
)
|
||||
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
|
||||
outputs["loss_duration"] = loss_duration
|
||||
return outputs, attn
|
||||
return outputs, mas_attn
|
||||
|
||||
def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None):
|
||||
spec_segment_size = self.spec_segment_size
|
||||
|
@ -965,11 +967,11 @@ class Vits(BaseTTS):
|
|||
z_p = self.flow(z, y_mask, g=g)
|
||||
|
||||
# duration predictor
|
||||
outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
|
||||
outputs, mas_attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
|
||||
|
||||
# expand prior
|
||||
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
||||
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
|
||||
m_p = torch.einsum("klmn, kjm -> kjn", [mas_attn, m_p])
|
||||
logs_p = torch.einsum("klmn, kjm -> kjn", [mas_attn, logs_p])
|
||||
|
||||
# select a random feature segment for the waveform decoder
|
||||
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
|
||||
|
@ -1005,7 +1007,7 @@ class Vits(BaseTTS):
|
|||
outputs.update(
|
||||
{
|
||||
"model_outputs": o,
|
||||
"alignments": attn.squeeze(1),
|
||||
"alignments": mas_attn.squeeze(1),
|
||||
"m_p": m_p,
|
||||
"logs_p": logs_p,
|
||||
"z": z,
|
||||
|
@ -1269,7 +1271,8 @@ class Vits(BaseTTS):
|
|||
|
||||
raise ValueError(" [!] Unexpected `optimizer_idx`.")
|
||||
|
||||
def _log(self, ap, batch, outputs, name_prefix="train"): # pylint: disable=unused-argument,no-self-use
|
||||
@staticmethod
|
||||
def _log(ap, outputs, name_prefix="train"):
|
||||
y_hat = outputs[1]["model_outputs"]
|
||||
y = outputs[1]["waveform_seg"]
|
||||
figures = plot_results(y_hat, y, ap, name_prefix)
|
||||
|
@ -1302,7 +1305,7 @@ class Vits(BaseTTS):
|
|||
Returns:
|
||||
Tuple[Dict, np.ndarray]: training plots and output waveform.
|
||||
"""
|
||||
figures, audios = self._log(self.ap, batch, outputs, "train")
|
||||
figures, audios = self._log(self.ap, outputs, "train")
|
||||
logger.train_figures(steps, figures)
|
||||
logger.train_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
|
@ -1311,7 +1314,7 @@ class Vits(BaseTTS):
|
|||
return self.train_step(batch, criterion, optimizer_idx)
|
||||
|
||||
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
|
||||
figures, audios = self._log(self.ap, batch, outputs, "eval")
|
||||
figures, audios = self._log(self.ap, outputs, "eval")
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
|
@ -1542,9 +1545,7 @@ class Vits(BaseTTS):
|
|||
Returns:
|
||||
List: optimizers.
|
||||
"""
|
||||
# select generator parameters
|
||||
optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
|
||||
|
||||
gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc."))
|
||||
optimizer1 = get_optimizer(
|
||||
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
|
||||
|
|
Loading…
Reference in New Issue