Remove the unusable fine-tuning model

Edresson 2021-11-22 08:19:36 -03:00 committed by Eren Gölge
parent 352aa69eca
commit 6fc3b9e679
2 changed files with 10 additions and 159 deletions

View File

@@ -599,7 +599,6 @@ class VitsGeneratorLoss(nn.Module):
         feats_disc_fake,
         feats_disc_real,
         loss_duration,
-        fine_tuning_mode=0,
         use_speaker_encoder_as_loss=False,
         gt_spk_emb=None,
         syn_spk_emb=None,
@@ -623,14 +622,9 @@ class VitsGeneratorLoss(nn.Module):
         # compute mel spectrograms from the waveforms
         mel = self.stft(waveform)
         mel_hat = self.stft(waveform_hat)

         # compute losses
-        # ignore tts model loss if fine tunning mode is on
-        if fine_tuning_mode:
-            loss_kl = 0.0
-        else:
-            loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
+        loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
         loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
         loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
         loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
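
Note on the hunk above: with the fine-tuning switch gone, the KL term is always applied. As a point of reference, this is a minimal sketch of how such a masked Gaussian KL is typically implemented in VITS-style models; the tensor names follow the call above, but the function body is an assumption, not code taken from this repository:

    import torch

    def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
        # z_p: posterior sample mapped by the flow; logs_q: posterior log-std;
        # m_p, logs_p: prior mean and log-std; z_mask: 1 for valid frames, 0 for padding.
        kl = logs_p - logs_q - 0.5
        kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
        kl = torch.sum(kl * z_mask)       # sum over channels and valid frames
        return kl / torch.sum(z_mask)     # normalize by the number of valid positions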

View File

@@ -167,11 +167,6 @@ class VitsArgs(Coqpit):
         speaker_encoder_model_path (str):
             Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".
-        fine_tuning_mode (int):
-            Fine tuning only the vocoder part of the model, while the rest will be frozen. Defaults to 0.
-            Mode 0: Disabled;
-            Mode 1: uses the distribution predicted by the encoder and It's recommended for TTS;
-            Mode 2: uses the distribution predicted by the encoder and It's recommended for voice conversion.
     """
     num_chars: int = 100
@@ -219,7 +214,6 @@ class VitsArgs(Coqpit):
     use_speaker_encoder_as_loss: bool = False
     speaker_encoder_config_path: str = ""
     speaker_encoder_model_path: str = ""
-    fine_tuning_mode: int = 0
     freeze_encoder: bool = False
     freeze_DP: bool = False
     freeze_PE: bool = False
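
With fine_tuning_mode removed, vocoder-only fine-tuning can still be approximated with the freeze_* flags that remain in VitsArgs. A hedged configuration sketch, using only the fields visible in the hunk above; the import path assumes the usual 🐸TTS package layout:

    from TTS.tts.models.vits import VitsArgs

    # Freeze the text-to-latent front end and train the remaining modules.
    model_args = VitsArgs(
        freeze_encoder=True,  # keep the text encoder fixed
        freeze_DP=True,       # keep the duration predictor fixed
        freeze_PE=True,       # keep the posterior encoder fixed
    )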
@@ -672,122 +666,6 @@ class Vits(BaseTTS):
         )
         return outputs

-    def forward_fine_tuning(
-        self,
-        x: torch.tensor,
-        x_lengths: torch.tensor,
-        y: torch.tensor,
-        y_lengths: torch.tensor,
-        aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
-        waveform=None,
-    ) -> Dict:
-        """Forward pass of the model.
-
-        Args:
-            x (torch.tensor): Batch of input character sequence IDs.
-            x_lengths (torch.tensor): Batch of input character sequence lengths.
-            y (torch.tensor): Batch of input spectrograms.
-            y_lengths (torch.tensor): Batch of input spectrogram lengths.
-            aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}.
-
-        Returns:
-            Dict: model outputs keyed by the output name.
-
-        Shapes:
-            - x: :math:`[B, T_seq]`
-            - x_lengths: :math:`[B]`
-            - y: :math:`[B, C, T_spec]`
-            - y_lengths: :math:`[B]`
-            - d_vectors: :math:`[B, C, 1]`
-            - speaker_ids: :math:`[B]`
-        """
-        with torch.no_grad():
-            outputs = {}
-            sid, g, lid = self._set_cond_input(aux_input)
-            # speaker embedding
-            if self.args.use_speaker_embedding and sid is not None and not self.use_d_vector:
-                g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
-
-            # language embedding
-            lang_emb = None
-            if self.args.use_language_embedding and lid is not None:
-                lang_emb = self.emb_l(lid).unsqueeze(-1)
-
-            x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
-
-            # posterior encoder
-            z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
-
-            # flow layers
-            z_p = self.flow(z, y_mask, g=g)
-
-            # find the alignment path
-            attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
-            with torch.no_grad():
-                o_scale = torch.exp(-2 * logs_p)
-                logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
-                logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
-                logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
-                logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
-                logp = logp2 + logp3 + logp1 + logp4
-                attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
-
-            # expand prior
-            m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
-            logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
-
-        # mode 1: like SC-GlowTTS paper; mode 2: recommended for voice conversion
-        if self.args.fine_tuning_mode == 1:
-            z_ft = m_p
-        elif self.args.fine_tuning_mode == 2:
-            z_ft = z_p
-        else:
-            raise RuntimeError(" [!] Invalid Fine Tunning Mode !")
-
-        # inverse decoder and get the output
-        z_f_pred = self.flow(z_ft, y_mask, g=g, reverse=True)
-        z_slice, slice_ids = rand_segments(z_f_pred, y_lengths, self.spec_segment_size)
-
-        o = self.waveform_decoder(z_slice, g=g)
-
-        wav_seg = segment(
-            waveform.transpose(1, 2),
-            slice_ids * self.config.audio.hop_length,
-            self.args.spec_segment_size * self.config.audio.hop_length,
-        )
-
-        if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
-            # concate generated and GT waveforms
-            wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
-
-            # resample audio to speaker encoder sample_rate
-            if self.audio_transform is not None:
-                wavs_batch = self.audio_transform(wavs_batch)
-
-            pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
-
-            # split generated and GT speaker embeddings
-            gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
-        else:
-            gt_spk_emb, syn_spk_emb = None, None
-
-        outputs.update(
-            {
-                "model_outputs": o,
-                "alignments": attn.squeeze(1),
-                "loss_duration": 0.0,
-                "z": z,
-                "z_p": z_p,
-                "m_p": m_p,
-                "logs_p": logs_p,
-                "m_q": m_q,
-                "logs_q": logs_q,
-                "waveform_seg": wav_seg,
-                "gt_spk_emb": gt_spk_emb,
-                "syn_spk_emb": syn_spk_emb,
-            }
-        )
-        return outputs
-
     def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
         """
@@ -869,15 +747,6 @@ class Vits(BaseTTS):
         if optimizer_idx not in [0, 1]:
             raise ValueError(" [!] Unexpected `optimizer_idx`.")

-        # generator pass
-        if self.args.fine_tuning_mode:
-            # ToDo: find better place fot it
-            # force eval mode
-            self.eval()
-            # restore train mode for the vocoder part
-            self.waveform_decoder.train()
-            self.disc.train()
-
         if self.args.freeze_encoder:
             for param in self.text_encoder.parameters():
                 param.requires_grad = False
@@ -913,25 +782,14 @@ class Vits(BaseTTS):
         waveform = batch["waveform"]

         # generator pass
-        if self.args.fine_tuning_mode:
-            # model forward
-            outputs = self.forward_fine_tuning(
-                text_input,
-                text_lengths,
-                linear_input.transpose(1, 2),
-                mel_lengths,
-                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
-                waveform=waveform,
-            )
-        else:
-            outputs = self.forward(
-                text_input,
-                text_lengths,
-                linear_input.transpose(1, 2),
-                mel_lengths,
-                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
-                waveform=waveform,
-            )
+        outputs = self.forward(
+            text_input,
+            text_lengths,
+            linear_input.transpose(1, 2),
+            mel_lengths,
+            aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
+            waveform=waveform,
+        )

         # cache tensors for the discriminator
         self.y_disc_cache = None
@@ -958,7 +816,6 @@ class Vits(BaseTTS):
                 feats_disc_fake=outputs["feats_disc_fake"],
                 feats_disc_real=outputs["feats_disc_real"],
                 loss_duration=outputs["loss_duration"],
-                fine_tuning_mode=self.args.fine_tuning_mode,
                 use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
                 gt_spk_emb=outputs["gt_spk_emb"],
                 syn_spk_emb=outputs["syn_spk_emb"],