Rename vars in VITS

Eren Gölge 2022-04-04 09:45:46 +02:00 committed by Eren Gölge
parent 775a6ab6ee
commit 760f045aaa
1 changed file with 16 additions and 15 deletions


@@ -29,6 +29,7 @@ from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment
+ from TTS.utils.generic_utils import count_parameters
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
@@ -125,6 +126,7 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
pad_mode="reflect",
normalized=False,
onesided=True,
+ return_complex=True,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
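
As a side note to this hunk, here is a minimal standalone sketch of the same magnitude-spectrogram computation with `return_complex=True`. Only the `torch.stft` arguments and the final magnitude line come from the hunk; the window handling and the `torch.view_as_real` step are assumptions about the surrounding helper, which is not shown here.

```python
import torch

def wav_to_spec_sketch(y, n_fft, hop_length, win_length):
    """Magnitude-spectrogram sketch mirroring the hunk above (not the full TTS helper)."""
    window = torch.hann_window(win_length, dtype=y.dtype, device=y.device)
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=False,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,  # newer torch versions require an explicit choice here
    )
    # Split the complex bins into a trailing [real, imag] axis so the existing
    # pow(2).sum(-1) magnitude computation still applies.
    spec = torch.view_as_real(spec)
    return torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

spec = wav_to_spec_sketch(torch.randn(2, 1024), n_fft=512, hop_length=128, win_length=512)
print(spec.shape)  # torch.Size([2, 257, 5])
```
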
@@ -521,7 +523,7 @@ class VitsArgs(Coqpit):
inference_noise_scale_dp: float = 1.0
max_inference_len: int = None
init_discriminator: bool = True
- use_spectral_norm_disriminator: bool = False
+ use_spectral_norm_discriminator: bool = False
use_speaker_embedding: bool = False
num_speakers: int = 0
speakers_file: str = None
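
The renamed field only corrects the spelling of `use_spectral_norm_discriminator`. For intuition, here is a toy sketch of how a flag like this typically switches a discriminator between weight norm and spectral norm; `ToyDisc` and its layer sizes are illustrative stand-ins, not the VITS discriminator.

```python
import torch
from torch import nn
from torch.nn.utils import spectral_norm, weight_norm

class ToyDisc(nn.Module):
    """Toy discriminator: same conv stack, normalization chosen by the flag."""

    def __init__(self, use_spectral_norm_discriminator: bool = False):
        super().__init__()
        norm_f = spectral_norm if use_spectral_norm_discriminator else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(nn.Conv1d(1, 16, 15, stride=1, padding=7)),
                norm_f(nn.Conv1d(16, 64, 41, stride=4, padding=20)),
            ]
        )
        self.out = norm_f(nn.Conv1d(64, 1, 3, padding=1))

    def forward(self, x):
        for conv in self.convs:
            x = torch.nn.functional.leaky_relu(conv(x), 0.1)
        return self.out(x)

print(ToyDisc(use_spectral_norm_discriminator=True)(torch.randn(1, 1, 1024)).shape)
```
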
@@ -857,21 +859,21 @@ class Vits(BaseTTS):
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp2 + logp3 + logp1 + logp4
- attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t']
+ mas_attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t']
# duration predictor
- attn_durations = attn.sum(3)
+ mas_attn_durations = mas_attn.sum(3)
if self.args.use_sdp:
loss_duration = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x,
x_mask,
- attn_durations,
+ mas_attn_durations,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
)
loss_duration = loss_duration / torch.sum(x_mask)
else:
- attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
+ attn_log_durations = torch.log(mas_attn_durations + 1e-6) * x_mask
log_durations = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x,
x_mask,
@@ -880,7 +882,7 @@ class Vits(BaseTTS):
)
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
outputs["loss_duration"] = loss_duration
- return outputs, attn
+ return outputs, mas_attn
def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None):
spec_segment_size = self.spec_segment_size
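
Stepping back from the two hunks above: a toy example of how the renamed hard MAS alignment becomes duration targets for the non-SDP branch. The alignment matrix and sizes are made up; only the shapes (`mas_attn` as `[b, 1, t_text, t_spec]`) and the loss formula follow the diff.

```python
import torch

# Hand-written hard alignment instead of the output of maximum_path.
b, t_text, t_spec = 1, 3, 6
mas_attn = torch.zeros(b, 1, t_text, t_spec)
mas_attn[0, 0, 0, 0:2] = 1.0  # token 0 -> frames 0-1
mas_attn[0, 0, 1, 2:5] = 1.0  # token 1 -> frames 2-4
mas_attn[0, 0, 2, 5:6] = 1.0  # token 2 -> frame 5

mas_attn_durations = mas_attn.sum(3)  # [b, 1, t_text] == tensor([[[2., 3., 1.]]])
x_mask = torch.ones(b, 1, t_text)
attn_log_durations = torch.log(mas_attn_durations + 1e-6) * x_mask

# Non-SDP branch: the duration predictor is regressed onto the log durations,
# normalized by the number of valid text positions.
log_durations = torch.zeros_like(attn_log_durations)  # stand-in prediction
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
print(mas_attn_durations.squeeze(), loss_duration)
```
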
@@ -965,11 +967,11 @@ class Vits(BaseTTS):
z_p = self.flow(z, y_mask, g=g)
# duration predictor
- outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
+ outputs, mas_attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
# expand prior
- m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
- logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
+ m_p = torch.einsum("klmn, kjm -> kjn", [mas_attn, m_p])
+ logs_p = torch.einsum("klmn, kjm -> kjn", [mas_attn, logs_p])
# select a random feature segment for the waveform decoder
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
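
A quick sanity check of what the prior-expansion einsum does with the renamed `mas_attn`: each text token's prior statistics are copied onto the spectrogram frames that token is aligned to. The alignment and values below are illustrative only.

```python
import torch

b, c, t_text, t_spec = 1, 2, 3, 6
mas_attn = torch.zeros(b, 1, t_text, t_spec)
mas_attn[0, 0, 0, 0:2] = 1.0  # token 0 -> frames 0-1
mas_attn[0, 0, 1, 2:5] = 1.0  # token 1 -> frames 2-4
mas_attn[0, 0, 2, 5:6] = 1.0  # token 2 -> frame 5

m_p = torch.arange(b * c * t_text, dtype=torch.float32).reshape(b, c, t_text)  # [b, c, t_text]
m_p_frames = torch.einsum("klmn, kjm -> kjn", [mas_attn, m_p])                 # [b, c, t_spec]
print(m_p_frames[0, 0])  # tensor([0., 0., 1., 1., 1., 2.]) -- token means repeated per frame
```
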
@@ -1005,7 +1007,7 @@ class Vits(BaseTTS):
outputs.update(
{
"model_outputs": o,
"alignments": attn.squeeze(1),
"alignments": mas_attn.squeeze(1),
"m_p": m_p,
"logs_p": logs_p,
"z": z,
@@ -1269,7 +1271,8 @@ class Vits(BaseTTS):
raise ValueError(" [!] Unexpected `optimizer_idx`.")
- def _log(self, ap, batch, outputs, name_prefix="train"): # pylint: disable=unused-argument,no-self-use
+ @staticmethod
+ def _log(ap, outputs, name_prefix="train"):
y_hat = outputs[1]["model_outputs"]
y = outputs[1]["waveform_seg"]
figures = plot_results(y_hat, y, ap, name_prefix)
@@ -1302,7 +1305,7 @@ class Vits(BaseTTS):
Returns:
Tuple[Dict, np.ndarray]: training plots and output waveform.
"""
figures, audios = self._log(self.ap, batch, outputs, "train")
figures, audios = self._log(self.ap, outputs, "train")
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, self.ap.sample_rate)
@@ -1311,7 +1314,7 @@ class Vits(BaseTTS):
return self.train_step(batch, criterion, optimizer_idx)
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
figures, audios = self._log(self.ap, batch, outputs, "eval")
figures, audios = self._log(self.ap, outputs, "eval")
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)
@@ -1542,9 +1545,7 @@ class Vits(BaseTTS):
Returns:
List: optimizers.
"""
# select generator parameters
optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc."))
optimizer1 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
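
The truncated hunk above sets up one optimizer per GAN branch. For reference, a plain-torch sketch of the same split: `ToyGanModel`, `AdamW`, and the learning rates are illustrative stand-ins, and the real code goes through the trainer's `get_optimizer` with `lr_disc`/`lr_gen` from the config.

```python
import itertools

import torch

class ToyGanModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.disc = torch.nn.Linear(4, 1)  # discriminator branch ("disc." prefix)
        self.gen = torch.nn.Linear(4, 4)   # everything else counts as generator

model = ToyGanModel()
# The discriminator optimizer takes the `disc` submodule directly.
optimizer0 = torch.optim.AdamW(model.disc.parameters(), lr=2e-4)
# The generator optimizer collects every parameter whose name does not start with "disc.".
gen_parameters = itertools.chain(
    params for k, params in model.named_parameters() if not k.startswith("disc.")
)
optimizer1 = torch.optim.AdamW(gen_parameters, lr=2e-4)
print(sum(p.numel() for g in optimizer1.param_groups for p in g["params"]))  # 20 (gen weight + bias)
```
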